#!/usr/bin/env python3
"""Generate owned finite-beta SFINCS-JAX profile-current diagnostics.

The finite-beta coefficient ladder already compares NTX and SFINCS-JAX on
RHSMode=3 monoenergetic transport coefficients.  This script writes a separate
RHSMode=1 profile-current diagnostic on the same owned VMEC/profile contract so
that bootstrap-current residuals can be separated from coefficient
normalization, radial interpolation, and reduced closure assumptions.

The output is deliberately a diagnostic, not a parity claim: profile-current
SFINCS-JAX runs need their own pitch/velocity convergence ladder before they can
be used as a promoted current reference.
"""

from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any

ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
# Make the repository root (for ``examples.*``) and ``src`` (for ``ntx``)
# importable when this file is executed directly as a script.  Each entry is
# prepended, so when both are inserted ``src`` ends up ahead of the root.
for _import_root in (ROOT, SRC):
    if str(_import_root) not in sys.path:
        sys.path.insert(0, str(_import_root))
del _import_root

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
from scipy.constants import elementary_charge, proton_mass  # noqa: E402

from examples.owned_finite_beta_bootstrap_comparison import (  # noqa: E402
    DEFAULT_CASE,
    ProfileContract,
    _interp,
    _profile_values,
    _to_jsonable,
)
from examples.owned_geometry_neopax_dataset import (  # noqa: E402
    OwnedJaxGeometryCase,
    discover_owned_case_specs,
)
from ntx import GridSpec  # noqa: E402

# Published artifacts (JSON + PNG/PDF figure) share this prefix.
OUTPUT_PREFIX = ROOT.joinpath(
    "docs", "_static", "owned_finite_beta_sfincs_jax_profile_current_audit"
)
# Per-deck SFINCS-JAX inputs/outputs are written below this directory.
WORKDIR = ROOT.joinpath(
    "examples", "outputs", "owned_finite_beta_sfincs_jax_profile_current_audit"
)
# Redl / NTX+NEOPAX current targets produced by the bootstrap comparison.
BOOTSTRAP_JSON = ROOT.joinpath(
    "docs", "_static", "owned_finite_beta_bootstrap_comparison.json"
)
# NOTE(review): the fallback is a machine-specific checkout path; set
# NTX_SFINCS_JAX_ROOT to point at another SFINCS-JAX source tree.
SFINCS_JAX_ROOT = Path(
    os.environ.get("NTX_SFINCS_JAX_ROOT", "/Users/rogeriojorge/local/tests/sfincs_jax")
)

# Default resolution: (Ntheta, Nzeta, Nxi) plus Nx speed points.
DEFAULT_GRID = GridSpec(13, 15, 8)
DEFAULT_NX = 5
DEFAULT_RHO = (1.0 / 7.0, 0.5, 13.0 / 14.0)
DEFAULT_NU_N = (8.31565e-3,)
# Scale from SFINCS-normalized current (e * nBar * vBar with nBar = 1e20 m^-3
# and vBar = sqrt(2 * 1 keV / m_p)) to A m^-2.
SFINCS_JHAT_TO_AM2 = (
    elementary_charge * 1.0e20 * np.sqrt(2.0 * 1.0e3 * elementary_charge / proton_mass)
)
# Species masses in units of the proton mass (ion presumably deuterium).
ION_MHAT = 2.0
ELECTRON_MHAT = 1.0 / 1836.15267343
# Floor for relative-error denominators.
EPS = 1.0e-30


@dataclass(frozen=True)
class SfincsProfileCurrentDeck:
    """Immutable record of one generated SFINCS-JAX RHSMode=1 deck and its outcome.

    ``as_payload`` converts an instance into a JSON-friendly dictionary.
    """

    # Identity of the owned geometry case the deck was generated for.
    case_id: str
    case_label: str
    family: str
    # Radial location (rho) and the matching s value supplied by the caller.
    rho: float
    s: float
    # Collisionality and normalized profile values/gradients written to the deck.
    nu_n: float
    n_hat: float
    t_hat: float
    dn_hat_dr_n: float
    dt_hat_dr_n: float
    # Deck files and the VMEC equilibrium file referenced by the deck.
    input_path: Path
    output_path: Path
    solver_trace_path: Path
    wout_path: Path
    # Outcome: status label, wall time, error text and parsed current summary.
    status: str
    seconds: float | None = None
    error: str | None = None
    current_summary: dict[str, object] | None = None

    def as_payload(self) -> dict[str, object]:
        """Return a deep-copied field dictionary with all path fields stringified."""
        path_fields = ("input_path", "output_path", "solver_trace_path", "wout_path")
        record = asdict(self)
        record.update({name: str(record[name]) for name in path_fields})
        return record


def _safe_float_label(value: float) -> str:
    return f"{value:.6g}".replace("-", "m").replace("+", "").replace(".", "p")


def _select_case(
    case_specs: tuple[OwnedJaxGeometryCase, ...] | None,
    case_id: str,
) -> OwnedJaxGeometryCase:
    cases = list(case_specs if case_specs is not None else discover_owned_case_specs())
    by_id = {case.id: case for case in cases}
    if case_id not in by_id:
        raise ValueError(f"owned finite-beta case {case_id!r} was not found")
    return by_id[case_id]


def _profile_hat_values(
    rho: float,
    contract: ProfileContract,
) -> tuple[float, float, float, float]:
    """Evaluate the analytic profile contract at one ``rho`` in SFINCS hat units.

    Returns ``(nHat, THat, dNHat/drN, dTHat/drN)``.  Densities and their
    gradient are divided by 1e20 and temperatures and their gradient by 1e3.
    The profiles are requested with ``a_b=1.0``.  Presumably ``_profile_values``
    returns m^-3 and eV, making these 1e20 m^-3 and keV; confirm against
    ``_profile_values``.
    """
    profiles = _profile_values(np.asarray([float(rho)], dtype=float), contract, a_b=1.0)

    def first_scaled(name: str, scale: float) -> float:
        # Single-point evaluation: take element 0 and rescale.
        return float(np.asarray(profiles[name], dtype=float)[0] / scale)

    return (
        first_scaled("density", 1.0e20),
        first_scaled("temperature", 1.0e3),
        first_scaled("d_density_dr", 1.0e20),
        first_scaled("d_temperature_dr", 1.0e3),
    )


def _sfincs_profile_input_text(
    *,
    wout_path: Path,
    rho: float,
    nu_n: float,
    n_hat: float,
    t_hat: float,
    dn_hat_dr_n: float,
    dt_hat_dr_n: float,
    grid: GridSpec,
    nx: int,
    solver_tolerance: float,
    min_bmn_to_load: float,
    collision_operator: int,
) -> str:
    """Render the SFINCS namelist text for one RHSMode=1 profile-current deck.

    The deck uses VMEC geometry (``geometryScheme = 5``) with ``rN`` as the
    radial coordinate for both the surface and the gradients.  It describes
    two species with equal density, temperature and gradients, zero radial
    electric field, and the requested resolution.  Floats are written with 17
    significant digits.
    """

    def num(value: float) -> str:
        # 17 significant digits round-trip an IEEE double exactly.
        return f"{value:.17g}"

    def pair(first: float, second: float) -> str:
        return f"{num(first)} {num(second)}"

    def namelist(name: str, *entries: tuple[str, str]) -> str:
        body = "".join(f"  {key} = {value}\n" for key, value in entries)
        return f"&{name}\n{body}/\n"

    header = (
        "! Owned finite-beta SFINCS-JAX profile-current input generated by NTX.\n"
        "! This RHSMode=1 deck uses the same analytic profile contract as the NTX+NEOPAX audit.\n"
    )
    # Use ion/electron ordering to match the archived profile-current SFINCS
    # convention.  The current observable is a charge-weighted species sum, so
    # this is a convention choice rather than a species-current fit.
    sections = [
        namelist("general", ("RHSMode", "1")),
        namelist(
            "geometryParameters",
            ("geometryScheme", "5"),
            ("equilibriumFile", f'"{wout_path}"'),
            ("inputRadialCoordinate", "3"),
            ("inputRadialCoordinateForGradients", "3"),
            ("rN_wish", num(rho)),
            ("VMECRadialOption", "0"),
            ("min_Bmn_to_load", num(min_bmn_to_load)),
        ),
        namelist(
            "speciesParameters",
            ("Zs", "1.0 -1.0"),
            ("mHats", pair(ION_MHAT, ELECTRON_MHAT)),
            ("nHats", pair(n_hat, n_hat)),
            ("THats", pair(t_hat, t_hat)),
            ("dNHatdrNs", pair(dn_hat_dr_n, dn_hat_dr_n)),
            ("dTHatdrNs", pair(dt_hat_dr_n, dt_hat_dr_n)),
        ),
        namelist(
            "physicsParameters",
            ("nu_n", num(nu_n)),
            ("collisionOperator", str(int(collision_operator))),
            ("includeXDotTerm", ".true."),
            ("includeElectricFieldTermInXiDot", ".true."),
            ("useDKESExBDrift", ".false."),
            ("includePhi1", ".false."),
            ("dPhiHatdrN", "0.0"),
        ),
        namelist(
            "resolutionParameters",
            ("Ntheta", str(int(grid.n_theta))),
            ("Nzeta", str(int(grid.n_zeta))),
            ("Nxi", str(int(grid.n_xi))),
            ("Nx", str(int(nx))),
            ("solverTolerance", num(solver_tolerance)),
        ),
        namelist("otherNumericalParameters", ("Nxi_for_x_option", "0")),
        namelist("preconditionerOptions"),
        namelist("export_f"),
    ]
    # A blank line separates the header and every pair of namelists.
    return header + "\n" + "\n".join(sections)


def _run_sfincs_jax_profile(
    input_path: Path,
    output_path: Path,
    solver_trace_path: Path,
    *,
    timeout_s: int,
    solve_method: str | None = None,
) -> tuple[str, float | None, str | None]:
    """Solve one deck with ``python -m sfincs_jax write-output`` in a subprocess.

    The subprocess runs in the deck directory with ``JAX_ENABLE_X64=True``
    unless the caller's environment already sets it.  When ``SFINCS_JAX_ROOT``
    exists, that checkout is prepended to ``PYTHONPATH``.  Output and
    solver-trace files left by a previous run are deleted first, so a failed
    run can never be mistaken for a fresh result.

    Args:
        input_path: namelist to solve; its parent directory is the working directory.
        output_path: HDF5 output written by SFINCS-JAX.
        solver_trace_path: JSON solver-trace sidecar written by SFINCS-JAX.
        timeout_s: wall-clock limit passed to ``subprocess.run``.
        solve_method: optional ``--solve-method`` override.

    Returns:
        ``(status, seconds, error)``.  ``status`` is ``"complete"`` or
        ``"failed"``, ``seconds`` is the elapsed wall time, and ``error`` is a
        short failure description (``None`` on success).
    """
    env = os.environ.copy()
    env.setdefault("JAX_ENABLE_X64", "True")
    env.setdefault("SFINCS_JAX_GMRES_DISTRIBUTED", "0")
    env.setdefault("SFINCS_JAX_MATVEC_SHARD_AXIS", "off")
    # Keep the optimized SFINCS-JAX RHSMode=1 policy as the default.  The
    # dense-PAS cutoff remains opt-out for constrained local reruns.
    if os.environ.get("NTX_SFINCS_JAX_DISABLE_DENSE_PAS") == "1":
        env.setdefault("SFINCS_JAX_RHSMODE1_DENSE_PAS_MAX", "0")
    if SFINCS_JAX_ROOT.exists():
        # Skip empty components: a trailing separator would add an empty
        # PYTHONPATH entry, which Python interprets as the current directory.
        env["PYTHONPATH"] = os.pathsep.join(
            entry for entry in (str(SFINCS_JAX_ROOT), env.get("PYTHONPATH", "")) if entry
        )
    command = [
        sys.executable,
        "-m",
        "sfincs_jax",
        "write-output",
        "--input",
        str(input_path),
        "--out",
        str(output_path),
        "--compute-solution",
        "--solver-trace",
        str(solver_trace_path),
        "--quiet",
    ]
    if solve_method:
        command.extend(["--solve-method", str(solve_method)])
    max_detail_chars = 2000
    start = time.perf_counter()
    try:
        # Remove stale artifacts so only this run's output can be summarized.
        for stale in (output_path, solver_trace_path):
            stale.unlink(missing_ok=True)
        subprocess.run(
            command,
            check=True,
            cwd=input_path.parent,
            env=env,
            timeout=timeout_s,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as exc:  # pragma: no cover - optional runtime.
        details = (exc.stderr or exc.stdout or str(exc)).strip()
        if len(details) > max_detail_chars:
            # Keep the tail: a Python traceback ends with the actual exception.
            details = "..." + details[-max_detail_chars:]
        return (
            "failed",
            time.perf_counter() - start,
            f"CalledProcessError({exc.returncode}): {details}",
        )
    except subprocess.TimeoutExpired as exc:  # pragma: no cover - optional runtime.
        return "failed", time.perf_counter() - start, f"TimeoutExpired: {exc}"
    except Exception as exc:  # pragma: no cover - optional runtime.
        return "failed", time.perf_counter() - start, f"{type(exc).__name__}: {exc}"
    return "complete", time.perf_counter() - start, None


def _read_h5_scalar(handle: Any, key: str) -> object:
    raw = np.asarray(handle[key])
    if raw.shape == ():
        value = raw.item()
    else:
        value = raw.reshape(-1)[-1]
    if isinstance(value, bytes):
        return value.decode("utf-8")
    if isinstance(value, np.bytes_):
        return bytes(value).decode("utf-8")
    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, np.floating):
        return float(value)
    return value


def _last_scalar(values: np.ndarray) -> float:
    array = np.asarray(values, dtype=float)
    return float(array.reshape(-1)[-1])


def _summarize_profile_output(output_path: Path) -> dict[str, object] | None:
    if not output_path.exists():
        return None
    try:
        import h5py
    except ModuleNotFoundError:  # pragma: no cover - h5py is installed in CI.
        return {
            "status": "unreadable",
            "reason": "h5py is required to inspect SFINCS-JAX HDF5 outputs",
        }
    with h5py.File(output_path, "r") as handle:
        required = ("FSABjHat", "FSABjHatOverRootFSAB2")
        if not all(key in handle for key in required):
            return {"status": "missing_current_observable"}
        fsab_jhat = _last_scalar(np.asarray(handle["FSABjHat"]))
        fsab_jhat_over_root = _last_scalar(np.asarray(handle["FSABjHatOverRootFSAB2"]))
        scalars: dict[str, float | int] = {}
        for key in (
            "RHSMode",
            "collisionOperator",
            "constraintScheme",
            "Ntheta",
            "Nzeta",
            "Nxi",
            "Nx",
            "nu_n",
            "Delta",
            "alpha",
            "B0OverBBar",
            "FSABHat2",
            "rN",
            "psiN",
        ):
            if key in handle:
                raw = _read_h5_scalar(handle, key)
                if isinstance(raw, (int, float)):
                    scalars[key] = int(raw) if str(key).startswith("N") or key in {
                        "RHSMode",
                        "collisionOperator",
                        "constraintScheme",
                    } else float(raw)
        solver: dict[str, object] = {}
        for key in (
            "linearSolverMethod",
            "linearSolverAcceptanceCriterion",
        ):
            if key in handle:
                solver[key] = str(_read_h5_scalar(handle, key))
        for key in (
            "linearSolverConverged",
            "linearSolverTrueResidualConverged",
            "linearSolverAccepted",
            "linearSolverIterations",
            "linearSolverInfoCode",
            "linearSolverLeastSquaresConverged",
        ):
            if key in handle:
                raw = _read_h5_scalar(handle, key)
                if isinstance(raw, (int, float)):
                    solver[key] = int(raw)
        for key in (
            "linearSolverResidualNorm",
            "linearSolverResidualTarget",
            "linearSolverResidualTargetRatio",
            "linearSolverReportedResidualNorm",
        ):
            if key in handle:
                raw = _read_h5_scalar(handle, key)
                if isinstance(raw, (int, float)):
                    solver[key] = float(raw)
        residual = solver.get("linearSolverResidualNorm")
        target = solver.get("linearSolverResidualTarget")
        if isinstance(residual, float) and isinstance(target, float) and target > 0.0:
            solver["true_residual_over_target"] = float(residual / target)
        residual_ratio = solver.get("true_residual_over_target")
        accepted = solver.get("linearSolverAccepted")
        converged = solver.get("linearSolverTrueResidualConverged")
        if converged is None:
            converged = solver.get("linearSolverConverged")
        solver["true_residual_gate_pass"] = bool(
            accepted == 1
            and converged == 1
            and (
                not isinstance(residual_ratio, float)
                or residual_ratio <= 1.0 + 1.0e-12
            )
        )
        return {
            "status": "complete",
            "fsab_jhat": float(fsab_jhat),
            "fsab_jhat_over_root_fsab2": float(fsab_jhat_over_root),
            "current_over_root_fsab2_am2": float(fsab_jhat_over_root * SFINCS_JHAT_TO_AM2),
            "jhat_to_am2_scale": float(SFINCS_JHAT_TO_AM2),
            "scalars": scalars,
            "solver": solver,
        }


def _load_bootstrap_targets(path: Path) -> dict[str, Any] | None:
    if not path.exists():
        return None
    return json.loads(path.read_text())


def _comparison_summary(
    bootstrap_payload: dict[str, Any] | None,
    *,
    rho: float,
    current_over_root: float | None,
) -> dict[str, float] | None:
    if bootstrap_payload is None or current_over_root is None:
        return None
    comparison = bootstrap_payload.get("comparison", {})
    source_rho = np.asarray(comparison.get("rho", []), dtype=float)
    if source_rho.size == 0:
        return None
    redl = _interp(
        source_rho,
        np.asarray(comparison["redl_current_over_root_fsab2"], dtype=float),
        np.asarray([rho], dtype=float),
    )[0]
    ntx = _interp(
        source_rho,
        np.asarray(comparison["ntx_neopax_total_over_root_fsab2"], dtype=float),
        np.asarray([rho], dtype=float),
    )[0]
    return {
        "redl_current_over_root_fsab2": float(redl),
        "ntx_neopax_current_over_root_fsab2": float(ntx),
        "sfincs_jax_relative_error_vs_redl": float(
            abs(float(current_over_root) - float(redl)) / max(abs(float(redl)), EPS)
        ),
        "sfincs_jax_relative_error_vs_ntx_neopax": float(
            abs(float(current_over_root) - float(ntx)) / max(abs(float(ntx)), EPS)
        ),
        "ntx_neopax_relative_error_vs_redl": float(
            abs(float(ntx) - float(redl)) / max(abs(float(redl)), EPS)
        ),
    }


def _completed_numeric_values(
    completed: list[SfincsProfileCurrentDeck],
    section: str,
    key: str,
) -> list[float]:
    """Collect ``current_summary[section][key]`` from completed decks.

    Decks whose summary lacks ``section`` (for example when no bootstrap
    comparison was available) or whose value is not numeric are skipped.
    """
    values: list[float] = []
    for deck in completed:
        section_payload = (deck.current_summary or {}).get(section)
        if not isinstance(section_payload, dict):
            continue
        value = section_payload.get(key)
        if isinstance(value, (int, float)):
            values.append(float(value))
    return values


def _max_or_none(values: list[float]) -> float | None:
    """Return the NumPy maximum of ``values`` (NaN propagates), or None if empty."""
    return float(np.max(values)) if values else None


def build_payload(
    *,
    case_id: str = DEFAULT_CASE,
    case_specs: tuple[OwnedJaxGeometryCase, ...] | None = None,
    rho: tuple[float, ...] = DEFAULT_RHO,
    nu_n: tuple[float, ...] = DEFAULT_NU_N,
    contract: ProfileContract | None = None,
    grid: GridSpec = DEFAULT_GRID,
    nx: int = DEFAULT_NX,
    solver_tolerance: float = 1.0e-7,
    min_bmn_to_load: float = 1.0e-5,
    collision_operator: int = 1,
    solve_method: str | None = None,
    output_dir: Path = WORKDIR,
    run_sfincs_jax: bool = False,
    timeout_s: int = 300,
    bootstrap_json: Path = BOOTSTRAP_JSON,
) -> dict[str, Any]:
    """Write (and optionally run) one SFINCS-JAX deck per (rho, nu_n) pair.

    Each deck is written to
    ``<output_dir>/<case>/rho_*/nu_n_*/input.namelist``.

    - With ``run_sfincs_jax=True`` every deck is solved in a subprocess.
    - Otherwise an existing ``sfincsOutput.h5`` is only summarized.
    - Completed currents are compared with the Redl and NTX+NEOPAX values in
      ``bootstrap_json`` when that artifact exists.

    A deck whose run failed is never summarized.  Otherwise an output left by
    an earlier run could be reported as this run's "complete" result.

    Returns:
        The JSON-ready audit payload.  It holds the per-deck records, the
        input/normalization contracts, and aggregate summary metrics.
    """
    if contract is None:
        contract = ProfileContract()
    case = _select_case(case_specs, case_id)
    output_dir = output_dir if output_dir.is_absolute() else (Path.cwd() / output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    bootstrap_payload = _load_bootstrap_targets(bootstrap_json)

    decks: list[SfincsProfileCurrentDeck] = []
    for rho_value in rho:
        # The profile contract depends only on rho: evaluate it once per radius.
        n_hat, t_hat, dn_hat_dr_n, dt_hat_dr_n = _profile_hat_values(rho_value, contract)
        for nu_n_value in nu_n:
            deck_dir = (
                output_dir
                / case.id
                / f"rho_{_safe_float_label(float(rho_value))}"
                / f"nu_n_{_safe_float_label(float(nu_n_value))}"
            )
            deck_dir.mkdir(parents=True, exist_ok=True)
            input_path = deck_dir / "input.namelist"
            output_path = deck_dir / "sfincsOutput.h5"
            solver_trace_path = deck_dir / "sfincsOutput.solver_trace.json"
            input_path.write_text(
                _sfincs_profile_input_text(
                    wout_path=case.wout_path,
                    rho=float(rho_value),
                    nu_n=float(nu_n_value),
                    n_hat=n_hat,
                    t_hat=t_hat,
                    dn_hat_dr_n=dn_hat_dr_n,
                    dt_hat_dr_n=dt_hat_dr_n,
                    grid=grid,
                    nx=int(nx),
                    solver_tolerance=float(solver_tolerance),
                    min_bmn_to_load=float(min_bmn_to_load),
                    collision_operator=int(collision_operator),
                )
            )
            status = "input_written"
            seconds = None
            error = None
            if run_sfincs_jax:
                status, seconds, error = _run_sfincs_jax_profile(
                    input_path,
                    output_path,
                    solver_trace_path,
                    timeout_s=int(timeout_s),
                    solve_method=solve_method,
                )
            elif output_path.exists():
                status = "output_found"
            # A failed run must not be upgraded to "complete" by an output file
            # that an earlier run (or a partial write) left behind.
            current_summary = (
                None if status == "failed" else _summarize_profile_output(output_path)
            )
            if current_summary and current_summary.get("status") == "complete":
                status = "complete"
                comparison = _comparison_summary(
                    bootstrap_payload,
                    rho=float(rho_value),
                    current_over_root=float(
                        current_summary["current_over_root_fsab2_am2"]
                    ),
                )
                if comparison is not None:
                    current_summary["comparison"] = comparison
            decks.append(
                SfincsProfileCurrentDeck(
                    case_id=case.id,
                    case_label=case.label,
                    family=case.family,
                    rho=float(rho_value),
                    s=float(rho_value) ** 2,
                    nu_n=float(nu_n_value),
                    n_hat=n_hat,
                    t_hat=t_hat,
                    dn_hat_dr_n=dn_hat_dr_n,
                    dt_hat_dr_n=dt_hat_dr_n,
                    input_path=input_path,
                    output_path=output_path,
                    solver_trace_path=solver_trace_path,
                    wout_path=case.wout_path,
                    status=status,
                    seconds=seconds,
                    error=error,
                    current_summary=current_summary,
                )
            )

    completed = [
        deck
        for deck in decks
        if deck.current_summary is not None
        and deck.current_summary.get("status") == "complete"
    ]
    redl_errors = _completed_numeric_values(
        completed, "comparison", "sfincs_jax_relative_error_vs_redl"
    )
    ntx_errors = _completed_numeric_values(
        completed, "comparison", "sfincs_jax_relative_error_vs_ntx_neopax"
    )
    ntx_redl_errors = _completed_numeric_values(
        completed, "comparison", "ntx_neopax_relative_error_vs_redl"
    )
    solver_residual_ratios = _completed_numeric_values(
        completed, "solver", "true_residual_over_target"
    )
    solver_gate_passes = [
        bool(deck.current_summary["solver"].get("true_residual_gate_pass"))
        for deck in completed
        if isinstance(deck.current_summary.get("solver"), dict)
        and "true_residual_gate_pass" in deck.current_summary["solver"]
    ]
    solver_methods = sorted(
        {
            str(deck.current_summary["solver"]["linearSolverMethod"])
            for deck in completed
            if isinstance(deck.current_summary.get("solver"), dict)
            and "linearSolverMethod" in deck.current_summary["solver"]
        }
    )
    payload = {
        "benchmark": "owned_finite_beta_sfincs_jax_profile_current_audit",
        "classification": "owned finite-beta RHSMode=1 profile-current diagnostic",
        "claim_scope": (
            "Writes and optionally runs SFINCS-JAX RHSMode=1 profile-current "
            "decks on the same owned finite-beta VMEC wout and analytic "
            "profiles used by the NTX+NEOPAX and Redl current audits.  This is "
            "a diagnostic for profile-current normalization and convergence; "
            "it is not a promoted parity claim until pitch/velocity/radial "
            "ladders pass."
        ),
        "case": case.as_payload(),
        "profile_contract": contract.as_payload(),
        "normalization_contract": {
            "rho_to_s": "s=rho^2",
            "species_order": "ion,electron to match the archived SFINCS profile-current convention",
            "profile_gradients": (
                "dNHatdrNs and dTHatdrNs are derivatives with respect to rN=r/a=rho; "
                "SFINCS-JAX converts them internally to d/d psiHat."
            ),
            "current_observable": "FSABjHatOverRootFSAB2 * e * 1e20 * sqrt(2*1keV/m_p)",
            "current_scale_am2": float(SFINCS_JHAT_TO_AM2),
            "trajectory": (
                "includeXDotTerm=.true., includeElectricFieldTermInXiDot=.true., "
                "useDKESExBDrift=.false., includePhi1=.false."
            ),
        },
        "sfincs_jax_contract": {
            "root": str(SFINCS_JAX_ROOT),
            "jax_enable_x64": "True unless explicitly overridden in the subprocess environment",
            "rhs_mode_1_solve_policy": (
                "write-output --compute-solution with SFINCS-JAX auto solver "
                "selection and a JSON solver-trace sidecar"
            ),
            "solver_convergence_gate": (
                "linearSolverAccepted=1, linearSolverTrueResidualConverged=1, "
                "and linearSolverResidualNorm <= linearSolverResidualTarget"
            ),
        },
        "inputs": {
            "rho": [float(value) for value in rho],
            "nu_n": [float(value) for value in nu_n],
            "grid": {
                "n_theta": int(grid.n_theta),
                "n_zeta": int(grid.n_zeta),
                "n_xi": int(grid.n_xi),
                "nx": int(nx),
            },
            "solver_tolerance": float(solver_tolerance),
            "min_bmn_to_load": float(min_bmn_to_load),
            "collision_operator": int(collision_operator),
            "solve_method": solve_method,
            "bootstrap_artifact": str(bootstrap_json),
        },
        "decks": [deck.as_payload() for deck in decks],
        "summary_metrics": {
            "deck_count": int(len(decks)),
            "completed_current_count": int(len(completed)),
            "failed_run_count": int(sum(deck.status == "failed" for deck in decks)),
            "input_written_count": int(sum(deck.status == "input_written" for deck in decks)),
            "max_sfincs_jax_relative_error_vs_redl": _max_or_none(redl_errors),
            "max_sfincs_jax_relative_error_vs_ntx_neopax": _max_or_none(ntx_errors),
            "max_ntx_neopax_relative_error_vs_redl": _max_or_none(ntx_redl_errors),
            "solver_methods": solver_methods,
            "completed_solver_converged_count": int(sum(solver_gate_passes)),
            "max_solver_true_residual_over_target": _max_or_none(solver_residual_ratios),
            "all_completed_solver_converged": (
                bool(solver_gate_passes and all(solver_gate_passes))
                if completed
                else None
            ),
        },
        "conclusion": (
            "RHSMode=1 profile-current output is now generated from the same "
            "owned finite-beta profile and geometry contract.  It remains a "
            "diagnostic until resolution and collisionality-normalization "
            "ladders are complete; the existing finite-beta NTX+NEOPAX current "
            "stress is therefore not hidden as a promoted parity claim."
        ),
        "open_work": [
            (
                "run a pitch/velocity/radial convergence ladder before using "
                "RHSMode=1 SFINCS-JAX current as a finite-beta reference"
            ),
            (
                "align the profile-current collisionality normalization with "
                "the Redl and NTX+NEOPAX profile contract before comparing "
                "absolute current amplitudes"
            ),
            (
                "only promote a finite-beta current figure if RHSMode=1, Redl, "
                "and NTX+NEOPAX share the same geometry, profile, normalization, "
                "and converged numerical contract"
            ),
        ],
        "figure_png": str(OUTPUT_PREFIX.with_suffix(".png").relative_to(ROOT)),
        "figure_pdf": str(OUTPUT_PREFIX.with_suffix(".pdf").relative_to(ROOT)),
    }
    return _to_jsonable(payload)


def write_payload(payload: dict[str, Any], output_prefix: Path = OUTPUT_PREFIX) -> None:
    """Write ``payload`` to ``<output_prefix>.json`` with figure paths re-pointed.

    The ``figure_png``/``figure_pdf`` entries are rewritten to match
    ``output_prefix``.  They are shown relative to ``ROOT`` when possible and as
    the full path otherwise.  The caller's dict is not modified.
    """
    output_prefix.parent.mkdir(parents=True, exist_ok=True)

    def display_path(path: Path) -> str:
        try:
            return str(path.relative_to(ROOT))
        except ValueError:
            return str(path)

    # Re-assigning existing keys keeps their original position in the JSON.
    updated = {
        **payload,
        "figure_png": display_path(output_prefix.with_suffix(".png")),
        "figure_pdf": display_path(output_prefix.with_suffix(".pdf")),
    }
    output_prefix.with_suffix(".json").write_text(json.dumps(updated, indent=2) + "\n")


def _figure_comparison_value(deck: dict[str, Any], key: str) -> float:
    """Return ``deck["current_summary"]["comparison"][key]``, or NaN when absent."""
    comparison = deck.get("current_summary", {}).get("comparison")
    if isinstance(comparison, dict):
        return float(comparison[key])
    return np.nan


def _plot_completed_currents(
    ax_current: Any,
    ax_error: Any,
    completed: list[dict[str, Any]],
) -> None:
    """Draw current profiles and relative gaps for decks with a complete current.

    SFINCS-JAX curves are drawn once per ``nu_n`` value, each sorted by rho, so
    collisionality scans do not zigzag across radius.  The Redl and NTX+NEOPAX
    targets depend only on rho and are drawn once.
    """
    nu_values = sorted({float(deck["nu_n"]) for deck in completed})
    label_nu = len(nu_values) > 1
    linestyles = ("-", "--", ":", "-.")
    has_comparison = any(
        isinstance(deck["current_summary"].get("comparison"), dict) for deck in completed
    )
    for index, nu_value in enumerate(nu_values):
        group = sorted(
            (deck for deck in completed if float(deck["nu_n"]) == nu_value),
            key=lambda deck: float(deck["rho"]),
        )
        rho_values = np.asarray([float(deck["rho"]) for deck in group], dtype=float)
        sfincs = np.asarray(
            [float(deck["current_summary"]["current_over_root_fsab2_am2"]) for deck in group],
            dtype=float,
        )
        suffix = rf" ($\nu_n$={nu_value:.3g})" if label_nu else ""
        linestyle = linestyles[index % len(linestyles)]
        ax_current.plot(
            rho_values,
            sfincs / 1.0e6,
            marker="o",
            lw=2.0,
            ls=linestyle,
            color="#d55e00",
            label="SFINCS-JAX RHSMode=1" + suffix,
        )
        if has_comparison:
            ax_error.semilogy(
                rho_values,
                [_figure_comparison_value(d, "sfincs_jax_relative_error_vs_redl") for d in group],
                marker="o",
                lw=1.8,
                ls=linestyle,
                color="#d55e00",
                label="SFINCS-JAX vs Redl" + suffix,
            )
            ax_error.semilogy(
                rho_values,
                [
                    _figure_comparison_value(d, "sfincs_jax_relative_error_vs_ntx_neopax")
                    for d in group
                ],
                marker="s",
                lw=1.6,
                ls=linestyle,
                color="#0072b2",
                label="SFINCS-JAX vs NTX+NEOPAX" + suffix,
            )
    if has_comparison:
        # One reference point per radius: the targets are interpolated at rho only.
        reference: dict[float, dict[str, Any]] = {}
        for deck in completed:
            comparison = deck["current_summary"].get("comparison")
            if isinstance(comparison, dict):
                reference.setdefault(float(deck["rho"]), comparison)
        reference_rho = sorted(reference)
        reference_rho_array = np.asarray(reference_rho, dtype=float)
        redl = np.asarray(
            [float(reference[r]["redl_current_over_root_fsab2"]) for r in reference_rho],
            dtype=float,
        )
        ntx = np.asarray(
            [float(reference[r]["ntx_neopax_current_over_root_fsab2"]) for r in reference_rho],
            dtype=float,
        )
        ntx_vs_redl = [
            float(reference[r]["ntx_neopax_relative_error_vs_redl"]) for r in reference_rho
        ]
        ax_current.plot(
            reference_rho_array,
            redl / 1.0e6,
            marker="s",
            lw=1.7,
            color="#009e73",
            label="Redl target",
        )
        ax_current.plot(
            reference_rho_array,
            ntx / 1.0e6,
            marker="^",
            lw=1.7,
            color="#0072b2",
            label="NTX+NEOPAX",
        )
        ax_error.semilogy(
            reference_rho_array,
            ntx_vs_redl,
            marker="^",
            lw=1.6,
            color="#009e73",
            label="NTX+NEOPAX vs Redl",
        )
        ax_error.axhline(1.0e-1, color="0.25", lw=1.0, ls="--", label="1e-1 gate")
        ax_error.set_ylabel("relative difference")
        ax_error.legend(fontsize=8.0)
    else:
        ax_error.text(
            0.5,
            0.5,
            "completed SFINCS-JAX outputs found;\ncomparison artifact unavailable",
            ha="center",
            va="center",
            transform=ax_error.transAxes,
        )
    ax_current.axhline(0.0, color="0.35", lw=0.8)
    ax_current.legend(fontsize=8.0)


def _plot_deck_status_counts(
    ax_current: Any,
    ax_error: Any,
    decks: list[dict[str, Any]],
) -> None:
    """Annotate the empty current panel and draw a bar chart of deck statuses."""
    ax_current.text(
        0.5,
        0.5,
        "Run with --run-sfincs-jax to populate current outputs",
        ha="center",
        va="center",
        transform=ax_current.transAxes,
    )
    status_counts: dict[str, int] = {}
    for deck in decks:
        status = str(deck["status"])
        status_counts[status] = status_counts.get(status, 0) + 1
    labels = list(status_counts)
    ax_error.bar(labels, [status_counts[label] for label in labels], color="#0072b2")
    ax_error.set_ylabel("deck count")
    ax_error.tick_params(axis="x", rotation=25)


def build_figure(payload: dict[str, Any], output_prefix: Path = OUTPUT_PREFIX) -> None:
    """Render the two-panel diagnostic figure to ``<output_prefix>.png`` and ``.pdf``.

    Panel (a) shows the profile-current observable versus rho.  Panel (b)
    shows the relative gaps when completed outputs exist, or a histogram of
    deck statuses when none do.  The styling is applied through matplotlib
    style/rc contexts, so the caller's global rcParams are left untouched.
    The figure is closed even if saving fails.
    """
    decks = payload["decks"]
    completed = [
        deck
        for deck in decks
        if isinstance(deck.get("current_summary"), dict)
        and deck["current_summary"].get("status") == "complete"
    ]
    style = {
        "figure.dpi": 220,
        "font.size": 10.0,
        "axes.grid": True,
        "grid.alpha": 0.24,
        "axes.spines.top": False,
        "axes.spines.right": False,
        "legend.frameon": False,
    }
    with plt.style.context("default"), plt.rc_context(style):
        fig, axes = plt.subplots(1, 2, figsize=(12.0, 4.4), constrained_layout=True)
        try:
            ax_current, ax_error = axes
            if completed:
                _plot_completed_currents(ax_current, ax_error, completed)
            else:
                _plot_deck_status_counts(ax_current, ax_error, decks)
            ax_current.set_xlabel(r"$\rho$")
            ax_current.set_ylabel(
                r"$\langle J\cdot B\rangle/\sqrt{\langle B^2\rangle}$ [MA m$^{-2}$]"
            )
            ax_current.set_title("(a) Profile-current observable")
            ax_error.set_xlabel(r"$\rho$")
            ax_error.set_title("(b) Diagnostic current gaps")

            fig.suptitle("Owned finite-beta SFINCS-JAX profile-current diagnostic", fontsize=13)
            output_prefix.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(output_prefix.with_suffix(".png"), dpi=220, bbox_inches="tight")
            fig.savefig(output_prefix.with_suffix(".pdf"), bbox_inches="tight")
        finally:
            plt.close(fig)


def main() -> None:
    """Command-line entry point: parse options, build the payload, and write artifacts.

    The payload JSON and the figure are written under ``--output-prefix``, and
    the summary metrics are echoed to stdout as indented JSON.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    add = parser.add_argument
    add("--case", default=DEFAULT_CASE)
    add("--rho", nargs="+", type=float, default=list(DEFAULT_RHO))
    add("--nu-n", nargs="+", type=float, default=list(DEFAULT_NU_N))
    add("--n-theta", type=int, default=DEFAULT_GRID.n_theta)
    add("--n-zeta", type=int, default=DEFAULT_GRID.n_zeta)
    add("--n-xi", type=int, default=DEFAULT_GRID.n_xi)
    add("--nx", type=int, default=DEFAULT_NX)
    add("--solver-tolerance", type=float, default=1.0e-7)
    add("--min-bmn-to-load", type=float, default=1.0e-5)
    add("--collision-operator", type=int, choices=(0, 1), default=1)
    add(
        "--solve-method",
        default=None,
        help=(
            "Optional SFINCS-JAX RHSMode=1 solve method override, e.g. "
            "sparse_pc_gmres. Defaults to the SFINCS-JAX auto policy."
        ),
    )
    add("--run-sfincs-jax", action="store_true")
    add("--timeout-s", type=int, default=300)
    add("--bootstrap-json", type=Path, default=BOOTSTRAP_JSON)
    add("--output-prefix", type=Path, default=OUTPUT_PREFIX)
    add("--output-dir", type=Path, default=WORKDIR)
    args = parser.parse_args()

    grid = GridSpec(int(args.n_theta), int(args.n_zeta), int(args.n_xi))
    # An empty --solve-method string means "use the SFINCS-JAX auto policy".
    solve_method = str(args.solve_method) if args.solve_method else None
    payload = build_payload(
        case_id=str(args.case),
        rho=tuple(map(float, args.rho)),
        nu_n=tuple(map(float, args.nu_n)),
        grid=grid,
        nx=int(args.nx),
        solver_tolerance=float(args.solver_tolerance),
        min_bmn_to_load=float(args.min_bmn_to_load),
        collision_operator=int(args.collision_operator),
        solve_method=solve_method,
        output_dir=args.output_dir,
        run_sfincs_jax=bool(args.run_sfincs_jax),
        timeout_s=int(args.timeout_s),
        bootstrap_json=args.bootstrap_json,
    )
    write_payload(payload, args.output_prefix)
    build_figure(payload, args.output_prefix)
    print(json.dumps(payload["summary_metrics"], indent=2))


if __name__ == "__main__":
    main()
