# Source code for squadds.simulations.drivenmodal.artifacts

"""Checkpoint and progress-tracking helpers for driven-modal runs."""

from __future__ import annotations

import json
from copy import deepcopy
from pathlib import Path
from typing import Any

# Names of the stages recorded under the "stages" key of a run manifest;
# a freshly initialized manifest lists every one of them as "pending".
# NOTE(review): the list order presumably mirrors the order in which a run
# passes through the stages — confirm against the code that drives a run.
STAGE_NAMES: list[str] = [
    "prepared",
    "rendered",
    "setup_created",
    "sweep_completed",
    "artifacts_exported",
    "postprocessed",
    "serialized",
]


def _manifest_file(manifest_path_or_dir: str | Path) -> Path:
    """Resolve a run directory or a manifest path to the manifest file path.

    An existing directory maps to the ``manifest.json`` inside it. Any other
    value (including a path that does not exist yet) is returned as a ``Path``
    without further checks.
    """
    candidate = Path(manifest_path_or_dir)
    return candidate / "manifest.json" if candidate.is_dir() else candidate


def _write_manifest(manifest: dict[str, Any], manifest_path: Path) -> dict[str, Any]:
    """Persist ``manifest`` as JSON at ``manifest_path`` and return it.

    The JSON text (2-space indent, sorted keys) is first written to a sibling
    ``<name>.tmp`` file and then moved over the target. The manifest is a
    checkpoint, so an interruption in the middle of a write must not leave a
    truncated file behind: with the staged write the previous manifest stays
    intact until the new one is complete.

    Args:
        manifest: JSON-serializable manifest dictionary.
        manifest_path: Destination file path.

    Returns:
        The same ``manifest`` object that was passed in.

    Raises:
        TypeError, ValueError: If ``manifest`` cannot be serialized; nothing is
            written in that case.
        OSError: If the temporary file cannot be written or moved into place.
    """
    # Serialize before touching the filesystem so a bad payload changes nothing.
    serialized = json.dumps(manifest, indent=2, sort_keys=True)
    staging_path = manifest_path.with_name(f"{manifest_path.name}.tmp")
    try:
        staging_path.write_text(serialized)
        staging_path.replace(manifest_path)
    finally:
        # After a successful replace the staging file no longer exists; this
        # only removes leftovers from a failed write or move.
        if staging_path.exists():
            staging_path.unlink()
    return manifest


def initialize_run_manifest(
    root_dir: str | Path,
    run_id: str,
    request_payload: dict[str, Any],
) -> dict[str, Any]:
    """Create the run directory ``root_dir/run_id`` and write its first manifest.

    The directory (and any missing parents) is created if needed; an existing
    directory is reused and its ``manifest.json`` is written afresh. The new
    manifest records the run id, the run directory as a string, the request
    payload, every stage in ``STAGE_NAMES`` with status ``"pending"``, and an
    empty sweep-point list.

    Returns:
        The manifest dictionary that was written to disk.
    """
    run_directory = Path(root_dir, run_id)
    run_directory.mkdir(parents=True, exist_ok=True)

    # Each stage gets its own status dict so later updates stay independent.
    pending_stages = {}
    for name in STAGE_NAMES:
        pending_stages[name] = {"status": "pending"}

    initial_manifest = {
        "run_id": run_id,
        "run_dir": str(run_directory),
        "request_payload": request_payload,
        "stages": pending_stages,
        "sweep_points": [],
    }
    return _write_manifest(initial_manifest, run_directory / "manifest.json")
def load_run_manifest(manifest_path_or_dir: str | Path) -> dict[str, Any]:
    """Read a run manifest from disk and return the parsed JSON.

    Accepts either the manifest file itself or a run directory containing
    ``manifest.json``.
    """
    raw_text = _manifest_file(manifest_path_or_dir).read_text()
    return json.loads(raw_text)
def mark_stage_complete(manifest_path_or_dir: str | Path, stage_name: str) -> dict[str, Any]:
    """Set one stage's status to ``"complete"`` and save the manifest.

    Args:
        manifest_path_or_dir: Manifest file, or run directory containing it.
        stage_name: Key of the stage inside the manifest's ``"stages"`` mapping.

    Returns:
        The updated manifest as written back to disk.

    Raises:
        ValueError: If ``stage_name`` is not a stage recorded in the manifest.
    """
    target_path = _manifest_file(manifest_path_or_dir)
    current = load_run_manifest(target_path)

    stages = current["stages"]
    if stage_name not in stages:
        raise ValueError(f"Unknown stage_name: {stage_name}")

    stages[stage_name]["status"] = "complete"
    return _write_manifest(current, target_path)
def register_sweep_points(
    manifest_path_or_dir: str | Path,
    sweep_points: list[dict[str, Any]],
) -> dict[str, Any]:
    """Add sweep points to a manifest, skipping ids that are already present.

    Each point must carry a ``"point_id"``; its optional ``"params"`` are
    deep-copied into the manifest. New entries start as ``"pending"`` with
    zero attempts and empty error/artifact fields. A point id repeated within
    ``sweep_points`` is registered only once (first occurrence wins).

    Returns:
        The updated manifest as written back to disk.
    """
    target_path = _manifest_file(manifest_path_or_dir)
    current = load_run_manifest(target_path)

    registered = current["sweep_points"]
    seen_ids = {entry["point_id"] for entry in registered}

    for incoming in sweep_points:
        incoming_id = incoming["point_id"]
        if incoming_id not in seen_ids:
            fresh_entry = {
                "point_id": incoming_id,
                "params": deepcopy(incoming.get("params", {})),
                "status": "pending",
                "attempt_count": 0,
                "error_message": "",
                "artifact_uri": "",
            }
            registered.append(fresh_entry)
            # Track the id right away so duplicates later in this batch are skipped.
            seen_ids.add(incoming_id)

    return _write_manifest(current, target_path)
def select_sweep_points(manifest: dict[str, Any], statuses: set[str]) -> list[dict[str, Any]]:
    """Return the manifest's sweep points whose status is one of ``statuses``.

    Works purely on the in-memory manifest; nothing is read from or written
    to disk. The returned list keeps manifest order and holds the same point
    dictionaries (not copies).
    """
    selected = []
    for entry in manifest["sweep_points"]:
        if entry["status"] in statuses:
            selected.append(entry)
    return selected
def _update_sweep_point(
    manifest_path_or_dir: str | Path,
    point_id: str,
    *,
    status: str,
    artifact_uri: str = "",
    error_message: str = "",
    increment_attempt: bool = False,
) -> dict[str, Any]:
    """Rewrite the status fields of one sweep point and save the manifest.

    The first sweep point whose ``"point_id"`` matches has its status,
    artifact URI and error message overwritten with the given values (the
    latter two reset to ``""`` when not supplied) and, if requested, its
    attempt counter bumped by one.

    Returns:
        The updated manifest as written back to disk.

    Raises:
        ValueError: If no sweep point with ``point_id`` exists; the manifest
            file is left untouched in that case.
    """
    target_path = _manifest_file(manifest_path_or_dir)
    current = load_run_manifest(target_path)

    matched = next(
        (entry for entry in current["sweep_points"] if entry["point_id"] == point_id),
        None,
    )
    if matched is None:
        raise ValueError(f"Unknown point_id: {point_id}")

    matched.update(
        status=status,
        artifact_uri=artifact_uri,
        error_message=error_message,
    )
    if increment_attempt:
        matched["attempt_count"] += 1
    return _write_manifest(current, target_path)
def mark_sweep_point_running(manifest_path_or_dir: str | Path, point_id: str) -> dict[str, Any]:
    """Flag a sweep point as ``"running"`` and count a new attempt.

    The point's artifact URI and error message are reset to empty strings.
    Returns the updated manifest; raises ``ValueError`` for an unknown
    ``point_id``.
    """
    updated_manifest = _update_sweep_point(
        manifest_path_or_dir,
        point_id,
        increment_attempt=True,
        status="running",
    )
    return updated_manifest
def mark_sweep_point_success(
    manifest_path_or_dir: str | Path,
    point_id: str,
    artifact_uri: str,
) -> dict[str, Any]:
    """Flag a sweep point as ``"success"`` and record where its artifact lives.

    The point's error message is reset to an empty string and its attempt
    counter is left unchanged. Returns the updated manifest; raises
    ``ValueError`` for an unknown ``point_id``.
    """
    updated_manifest = _update_sweep_point(
        manifest_path_or_dir,
        point_id,
        artifact_uri=artifact_uri,
        status="success",
    )
    return updated_manifest
def mark_sweep_point_failed(
    manifest_path_or_dir: str | Path,
    point_id: str,
    error_message: str,
) -> dict[str, Any]:
    """Flag a sweep point as ``"failed"`` and store the most recent error text.

    The point's artifact URI is reset to an empty string and its attempt
    counter is left unchanged. Returns the updated manifest; raises
    ``ValueError`` for an unknown ``point_id``.
    """
    updated_manifest = _update_sweep_point(
        manifest_path_or_dir,
        point_id,
        error_message=error_message,
        status="failed",
    )
    return updated_manifest