#!/usr/bin/env python3
"""Summarize finite-beta RHSMode=1 pitch-resolution diagnostics.

This script is intentionally a reducer: it reads payloads generated by
``owned_finite_beta_sfincs_jax_profile_current_audit.py`` for a sequence of
``Nxi`` values and writes a compact JSON/figure artifact.  It does not run
SFINCS-JAX itself.  The goal is to make pitch-space truncation and even/odd
Legendre closure behavior visible before any finite-beta bootstrap-current
parity claim is promoted.
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any

ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
# Make the repository root (for the ``examples.*`` import below) and ``src``
# importable when this script is run directly.  Each directory is prepended
# only if absent; when both are added, ``src`` ends up ahead of the root.
for _import_dir in (ROOT, SRC):
    if str(_import_dir) not in sys.path:
        sys.path.insert(0, str(_import_dir))
del _import_dir

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402

from examples.owned_finite_beta_bootstrap_comparison import _to_jsonable  # noqa: E402

# Path prefix for the generated artifacts; the writers derive the
# ``.json``/``.png``/``.pdf`` files from it via ``with_suffix``.
OUTPUT_PREFIX = ROOT.joinpath(
    "docs",
    "_static",
    "owned_finite_beta_sfincs_jax_profile_current_resolution_audit",
)
# Default directory expected to contain ``Nxi_*/payload.json`` files.
INPUT_ROOT = ROOT.joinpath(
    "examples",
    "outputs",
    "sfincs_jax_v1p1_profile_current",
    "nxi_ladder_17x21_nx5",
)
# Floor applied to denominators of relative differences.
EPS = 1e-30


def _relative_difference(a: float, b: float) -> float:
    """Return ``|a - b| / max(|a|, EPS)``.

    The difference is measured relative to ``a``; ``EPS`` keeps the
    denominator away from zero.
    """
    x, y = float(a), float(b)
    denominator = max(abs(x), EPS)
    return float(abs(x - y) / denominator)


def _payload_paths(input_root: Path) -> list[Path]:
    """List ``Nxi_*/payload.json`` files under ``input_root``, ordered by Nxi.

    Ordering uses ``_path_nxi``; directories it cannot parse sort last.

    Raises:
        FileNotFoundError: if no matching payload exists.
    """
    found = list(input_root.glob("Nxi_*/payload.json"))
    found.sort(key=_path_nxi)
    if found:
        return found
    raise FileNotFoundError(f"no Nxi payloads found under {input_root}")


def _path_nxi(path: Path) -> int:
    try:
        return int(path.parent.name.split("_", 1)[1])
    except Exception:
        return 10**9


def _first_deck(payload: dict[str, Any]) -> dict[str, Any]:
    decks = payload.get("decks", [])
    if not decks:
        raise ValueError("profile-current payload has no decks")
    return dict(decks[0])


def _mapping_or_empty(value: Any) -> dict[str, Any]:
    """Return ``value`` when it is a dict and an empty dict otherwise.

    Optional payload sections may be absent, ``null`` or malformed; treating
    them as empty keeps one bad section from aborting the whole reduction.
    """
    return value if isinstance(value, dict) else {}


def _optional_float(value: Any) -> float | None:
    """Return a JSON number as ``float``; anything else (including bools) as ``None``.

    ``bool`` is rejected explicitly because it subclasses ``int``: a stray
    JSON ``true`` would otherwise be reported as the number 1.0.
    """
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    return float(value)


def _int_or_default(value: Any, default: int) -> int:
    """Return ``int(value)``, or ``default`` when ``value`` is ``None``/missing."""
    return default if value is None else int(value)


def _row_from_payload(path: Path) -> dict[str, Any]:
    """Flatten one ``Nxi_*/payload.json`` file into a summary row.

    Only the first deck of the payload is used.  Grid sizes come from
    ``inputs.grid``; a missing/``null`` ``n_xi`` falls back to the value parsed
    from the directory name and the other sizes fall back to ``-1``.  Missing
    ``rho``/``nu_n`` become NaN, non-numeric currents become ``None``, and the
    relative-error / residual fields are passed through unchanged.

    Raises:
        ValueError: if the payload has no decks.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    payload = json.loads(path.read_text(encoding="utf-8"))
    deck = _first_deck(payload)
    current_summary = _mapping_or_empty(deck.get("current_summary"))
    comparison = _mapping_or_empty(current_summary.get("comparison"))
    solver = _mapping_or_empty(current_summary.get("solver"))
    inputs = _mapping_or_empty(payload.get("inputs"))
    grid = _mapping_or_empty(inputs.get("grid"))
    rho = deck.get("rho")
    nu_n = deck.get("nu_n")
    return {
        # Prefer the grid recorded in the payload; fall back to the directory name.
        "n_xi": _int_or_default(grid.get("n_xi"), _path_nxi(path)),
        "n_theta": _int_or_default(grid.get("n_theta"), -1),
        "n_zeta": _int_or_default(grid.get("n_zeta"), -1),
        "nx": _int_or_default(grid.get("nx"), -1),
        "rho": np.nan if rho is None else float(rho),
        "nu_n": np.nan if nu_n is None else float(nu_n),
        "collision_operator": inputs.get("collision_operator"),
        "status": deck.get("status"),
        "current_over_root_fsab2_am2": _optional_float(
            current_summary.get("current_over_root_fsab2_am2")
        ),
        "redl_current_over_root_fsab2": _optional_float(
            comparison.get("redl_current_over_root_fsab2")
        ),
        "ntx_neopax_current_over_root_fsab2": _optional_float(
            comparison.get("ntx_neopax_current_over_root_fsab2")
        ),
        "sfincs_jax_relative_error_vs_redl": comparison.get(
            "sfincs_jax_relative_error_vs_redl"
        ),
        "sfincs_jax_relative_error_vs_ntx_neopax": comparison.get(
            "sfincs_jax_relative_error_vs_ntx_neopax"
        ),
        "ntx_neopax_relative_error_vs_redl": comparison.get(
            "ntx_neopax_relative_error_vs_redl"
        ),
        "solver_method": solver.get("linearSolverMethod"),
        "solver_gate_pass": bool(solver.get("true_residual_gate_pass", False)),
        "solver_true_residual_over_target": solver.get("true_residual_over_target"),
        "source_payload": str(path),
    }


def _tail_mean(rows: list[dict[str, Any]], *, parity: int, count: int = 3) -> float | None:
    candidates = [
        row
        for row in rows
        if row["n_xi"] % 2 == parity
        and row.get("solver_gate_pass")
        and isinstance(row.get("current_over_root_fsab2_am2"), (int, float))
    ]
    candidates = sorted(candidates, key=lambda row: int(row["n_xi"]))[-count:]
    if not candidates:
        return None
    return float(np.mean([float(row["current_over_root_fsab2_am2"]) for row in candidates]))


def _numeric_values(rows: list[dict[str, Any]], key: str) -> list[float]:
    """Return ``float(row[key])`` for every row whose ``key`` holds an int/float."""
    return [float(row[key]) for row in rows if isinstance(row.get(key), (int, float))]


def build_payload(input_root: Path = INPUT_ROOT) -> dict[str, Any]:
    """Reduce every ``Nxi_*/payload.json`` under ``input_root`` into one artifact.

    Rows are sorted by ``n_xi``.  A row counts as converged when its solver
    residual gate passed and it carries a numeric current; the best/max
    summary metrics are taken over converged rows only.  The even/odd tail
    means (``_tail_mean``) and their relative gap quantify the pitch-parity
    split at the highest Nxi values.

    The ``conclusion`` is derived from the data: the acceptance statement is
    emitted only when every row converged and the tail even/odd gap is below
    the documented 1.5e-1 stress tolerance; otherwise it names the failed
    check(s).  ``summary_metrics["tail_even_odd_gate_pass"]`` records the
    parity check outcome.

    Raises:
        FileNotFoundError: if no ``Nxi_*/payload.json`` exists under ``input_root``.
    """
    # Relative-difference gate vs Redl and NTX+NEOPAX (the figure draws the
    # same 1e-1 line).
    reference_gate = 1.0e-1
    # Even/odd tail-split tolerance quoted in the acceptance conclusion.
    parity_stress_tolerance = 1.5e-1

    input_root = input_root.resolve()
    rows = sorted(
        (_row_from_payload(path) for path in _payload_paths(input_root)),
        key=lambda row: int(row["n_xi"]),
    )
    converged = [
        row
        for row in rows
        if row.get("solver_gate_pass")
        and isinstance(row.get("current_over_root_fsab2_am2"), (int, float))
    ]
    redl_errors = _numeric_values(converged, "sfincs_jax_relative_error_vs_redl")
    ntx_errors = _numeric_values(converged, "sfincs_jax_relative_error_vs_ntx_neopax")
    residual_ratios = _numeric_values(converged, "solver_true_residual_over_target")

    even_tail = _tail_mean(rows, parity=0)
    odd_tail = _tail_mean(rows, parity=1)
    parity_gap = None
    if even_tail is not None and odd_tail is not None:
        parity_gap = abs(even_tail - odd_tail) / max(abs(even_tail), abs(odd_tail), EPS)

    all_rows_converged = len(converged) == len(rows)
    # ``<`` on NaN is False, so a NaN gap correctly fails the gate.
    parity_gate_pass = parity_gap is not None and parity_gap < parity_stress_tolerance

    if all_rows_converged and parity_gate_pass:
        conclusion = (
            "The updated sparse-PC RHSMode=1 solver closes the residual lane, "
            "and the finite-beta profile-current pitch-resolution lane is "
            "accepted as a reduced-closure stress diagnostic: the high-Nxi "
            "even/odd terminal-mode split is below the documented 1.5e-1 "
            "stress tolerance."
        )
    else:
        # Do not publish an acceptance claim the data does not support.
        failures = []
        if not all_rows_converged:
            failures.append(
                f"{len(rows) - len(converged)} of {len(rows)} Nxi rows failed "
                "the solver residual gate or lack a numeric current"
            )
        if parity_gap is None:
            failures.append("no converged even/odd Nxi tail pair is available")
        elif not parity_gate_pass:
            failures.append(
                f"the high-Nxi even/odd terminal-mode split {parity_gap:.3e} "
                "is not below the documented 1.5e-1 stress tolerance"
            )
        conclusion = (
            "The finite-beta profile-current pitch-resolution lane is NOT "
            "accepted as a reduced-closure stress diagnostic: "
            + "; ".join(failures)
            + "."
        )

    payload = {
        "benchmark": "owned_finite_beta_sfincs_jax_profile_current_resolution_audit",
        "classification": "finite-beta RHSMode=1 pitch-resolution convergence diagnostic",
        "claim_scope": (
            "Tracks SFINCS-JAX profile-current sensitivity to the Legendre "
            "pitch truncation Nxi on the same owned finite-beta VMEC/profile "
            "contract.  This is a numerical convergence diagnostic, not a "
            "bootstrap-current parity claim."
        ),
        "input_root": str(input_root),
        "rows": rows,
        "summary_metrics": {
            "row_count": int(len(rows)),
            "solver_converged_count": int(len(converged)),
            "best_sfincs_jax_relative_error_vs_redl": (
                float(np.min(redl_errors)) if redl_errors else None
            ),
            "best_sfincs_jax_relative_error_vs_ntx_neopax": (
                float(np.min(ntx_errors)) if ntx_errors else None
            ),
            "max_solver_true_residual_over_target": (
                float(np.max(residual_ratios)) if residual_ratios else None
            ),
            "even_tail_current_over_root_fsab2_am2": even_tail,
            "odd_tail_current_over_root_fsab2_am2": odd_tail,
            "tail_even_odd_relative_gap": (
                float(parity_gap) if parity_gap is not None else None
            ),
            "tail_even_odd_gate_pass": bool(parity_gate_pass),
            "redl_gate_pass_count": int(
                sum(error <= reference_gate for error in redl_errors)
            ),
            "ntx_neopax_gate_pass_count": int(
                sum(error <= reference_gate for error in ntx_errors)
            ),
        },
        "conclusion": conclusion,
        "figure_png": str(OUTPUT_PREFIX.with_suffix(".png").relative_to(ROOT)),
        "figure_pdf": str(OUTPUT_PREFIX.with_suffix(".pdf").relative_to(ROOT)),
    }
    return _to_jsonable(payload)


def write_payload(payload: dict[str, Any], output_prefix: Path = OUTPUT_PREFIX) -> None:
    """Write ``payload`` as indented JSON to ``<output_prefix>.json``.

    The ``figure_png``/``figure_pdf`` entries are retargeted to
    ``output_prefix`` with ``.png``/``.pdf`` suffixes, stored relative to
    ``ROOT`` when the path lies under it and as the path string otherwise.
    A shallow copy is written, so the caller's dict is not modified.

    Note: ``Path.with_suffix`` replaces an existing suffix, so a prefix whose
    final component contains a dot loses the text after that dot.
    """
    output_prefix.parent.mkdir(parents=True, exist_ok=True)
    document = {**payload}
    for suffix, key in ((".png", "figure_png"), (".pdf", "figure_pdf")):
        figure_path = output_prefix.with_suffix(suffix)
        try:
            document[key] = str(figure_path.relative_to(ROOT))
        except ValueError:
            # Not under ROOT (or a relative prefix): keep the path as given.
            document[key] = str(figure_path)
    text = json.dumps(document, indent=2)
    output_prefix.with_suffix(".json").write_text(f"{text}\n")


def _nan_column(rows: list[dict[str, Any]], key: str) -> np.ndarray:
    """Return ``float(row[key])`` per row as a float array, with missing/``None`` as NaN."""
    return np.asarray(
        [np.nan if row.get(key) is None else float(row[key]) for row in rows],
        dtype=float,
    )


def build_figure(payload: dict[str, Any], output_prefix: Path = OUTPUT_PREFIX) -> None:
    """Render the two-panel Nxi audit figure to ``<output_prefix>.png`` and ``.pdf``.

    Panel (a) plots the SFINCS-JAX current (scaled by 1e-6, labelled
    MA m^-2) against Nxi as separate even/odd series for rows whose solver
    gate passed. Gate-failed rows are shown as crosses. The first finite
    Redl and NTX+NEOPAX values are drawn as horizontal reference lines.
    Panel (b) shows the relative differences on a log axis together with
    the 1e-1 gate line.

    Styling is applied through temporary style/rc contexts, so calling this
    function leaves the process-wide matplotlib settings untouched. The
    figure is closed even if saving fails.
    """
    rows = payload["rows"]
    n_xi = np.asarray([int(row["n_xi"]) for row in rows], dtype=int)
    current = _nan_column(rows, "current_over_root_fsab2_am2")
    redl = _nan_column(rows, "redl_current_over_root_fsab2")
    ntx = _nan_column(rows, "ntx_neopax_current_over_root_fsab2")
    error_redl = _nan_column(rows, "sfincs_jax_relative_error_vs_redl")
    error_ntx = _nan_column(rows, "sfincs_jax_relative_error_vs_ntx_neopax")
    converged = np.asarray([bool(row.get("solver_gate_pass")) for row in rows], dtype=bool)
    even = n_xi % 2 == 0

    style_overrides = {
        "figure.dpi": 220,
        "font.size": 10.0,
        "axes.grid": True,
        "grid.alpha": 0.24,
        "axes.spines.top": False,
        "axes.spines.right": False,
        "legend.frameon": False,
    }
    # Scoped styling: ``plt.style.use`` / ``rcParams.update`` would leak into
    # every later figure created in the same process.
    with plt.style.context("default"), plt.rc_context(style_overrides):
        fig, (ax_current, ax_error) = plt.subplots(
            1, 2, figsize=(12.0, 4.4), constrained_layout=True
        )
        try:
            for mask, label, color, marker in (
                (even & converged, "SFINCS-JAX even Nxi", "#0072b2", "o"),
                ((~even) & converged, "SFINCS-JAX odd Nxi", "#d55e00", "s"),
            ):
                if np.any(mask):
                    ax_current.plot(
                        n_xi[mask],
                        current[mask] / 1.0e6,
                        marker=marker,
                        lw=1.8,
                        color=color,
                        label=label,
                    )
            if np.any(~converged):
                ax_current.scatter(
                    n_xi[~converged],
                    current[~converged] / 1.0e6,
                    marker="x",
                    s=48,
                    color="0.25",
                    label="solver gate failed",
                )
            if np.any(np.isfinite(redl)):
                ax_current.axhline(
                    float(redl[np.isfinite(redl)][0]) / 1.0e6,
                    color="#009e73",
                    lw=1.6,
                    ls="--",
                    label="Redl target",
                )
            if np.any(np.isfinite(ntx)):
                ax_current.axhline(
                    float(ntx[np.isfinite(ntx)][0]) / 1.0e6,
                    color="#cc79a7",
                    lw=1.6,
                    ls=":",
                    label="NTX+NEOPAX",
                )
            ax_current.set_xlabel(r"$N_\xi$")
            ax_current.set_ylabel(
                r"$\langle J\cdot B\rangle/\sqrt{\langle B^2\rangle}$ [MA m$^{-2}$]"
            )
            ax_current.set_title("(a) Pitch truncation")
            ax_current.legend(fontsize=8.0)

            ax_error.semilogy(
                n_xi, error_redl, marker="o", lw=1.7, color="#009e73", label="vs Redl"
            )
            ax_error.semilogy(
                n_xi,
                error_ntx,
                marker="s",
                lw=1.7,
                color="#0072b2",
                label="vs NTX+NEOPAX",
            )
            ax_error.axhline(1.0e-1, color="0.25", lw=1.0, ls="--", label="1e-1 gate")
            ax_error.set_xlabel(r"$N_\xi$")
            ax_error.set_ylabel("relative difference")
            ax_error.set_title("(b) Reference differences")
            ax_error.legend(fontsize=8.0)
            fig.suptitle(
                "Finite-beta RHSMode=1 profile-current pitch-resolution audit", fontsize=13
            )
            output_prefix.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(output_prefix.with_suffix(".png"), dpi=220, bbox_inches="tight")
            fig.savefig(output_prefix.with_suffix(".pdf"), bbox_inches="tight")
        finally:
            # Release the figure even when drawing or saving raises.
            plt.close(fig)


def main(argv: list[str] | None = None) -> None:
    """Reduce the Nxi ladder, write the JSON artifact and figure, print metrics.

    Args:
        argv: Command-line arguments without the program name.  ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so running the
            script directly behaves as before; passing a list lets other code
            or tests drive the reducer without touching ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's paragraph layout in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--input-root",
        type=Path,
        default=INPUT_ROOT,
        help="directory containing Nxi_*/payload.json files (default: %(default)s)",
    )
    parser.add_argument(
        "--output-prefix",
        type=Path,
        default=OUTPUT_PREFIX,
        help=(
            "path prefix for the .json/.png/.pdf outputs, applied via "
            "Path.with_suffix (default: %(default)s)"
        ),
    )
    args = parser.parse_args(argv)
    payload = build_payload(args.input_root)
    write_payload(payload, args.output_prefix)
    build_figure(payload, args.output_prefix)
    print(json.dumps(payload["summary_metrics"], indent=2))


if __name__ == "__main__":
    # Run the reducer only when executed as a script; importing this module
    # (e.g. to reuse build_payload) does not trigger the reduction.
    main()
