GPU Runs
NTX uses the same JAX solver path on CPU and GPU.
GPU Test Targets
tests/test_gpu_smoke.py
scripts/run_gpu_regression.py
scripts/sh_office_gpu_smoke.sh
Typical Session
sh office
cd /path/to/NTX
python -m pip install -e ".[dev,docs,io]"
export XLA_PYTHON_CLIENT_PREALLOCATE=false
scripts/sh_office_gpu_smoke.sh
What The Regression Script Reports
backend and visible devices
compile-plus-first-run timing
steady-state timing
solved coefficients
max relative error against NTX-owned smoke references
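The two timing figures and the error metric in this list can be illustrated with a short sketch. It is not the code in scripts/run_gpu_regression.py; the helper names time_first_and_steady and max_relative_error are hypothetical and chosen only for this example. For a jitted JAX function the first call includes tracing and compilation, while later calls reuse the compiled executable.

```python
import time


def time_first_and_steady(fn, *args, repeats=5):
    """Return (first_call_seconds, steady_state_seconds, last_result).

    Hypothetical helper. With JAX, `fn` should block until the result is
    ready (for example by calling .block_until_ready() on the output),
    because dispatch is asynchronous.
    """
    # First call: compile-plus-first-run time for a jitted function.
    start = time.perf_counter()
    result = fn(*args)
    first = time.perf_counter() - start

    # Later calls: steady-state time, reported as the best of `repeats` runs.
    steady = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(*args)
        steady = min(steady, time.perf_counter() - start)
    return first, steady, result


def max_relative_error(values, references, floor=1e-300):
    """Largest |value - reference| / |reference| over paired entries.

    `floor` guards against division by zero for zero-valued references.
    """
    return max(
        abs(v - r) / max(abs(r), floor)
        for v, r in zip(values, references)
    )
```

Reporting the first call separately from the steady-state figure keeps one-off compilation cost from being mistaken for solver cost.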
Device-Parallel Scans
For larger scans, NTX also exposes a device-parallel scan path through
solve_monoenergetic_parallel_scan(...) and the profiling helper:
python scripts/profile_parallel_runtime.py --output-json parallel-runtime.json
This is intended for multi-device CPU or GPU jobs when scan throughput matters
more than single-case latency.
For CI or quick local smoke checks, pass
--num-cases 2 --grid 5,5,4 to the profiling helper. This keeps the serial/device-parallel equality check
fast; the default settings remain the ones to use for real measurements.
The helper now performs an NTX smoke check on local devices before using them. If a visible device fails that check, it is excluded from the parallel solve instead of silently returning bad coefficients.
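The guard described above can be sketched as a filter over the visible devices. This is an assumption-laden illustration, not NTX's implementation: healthy_devices, run_smoke_case, and the rtol default are hypothetical, and the actual smoke case and tolerance used by the helper are not shown here. With JAX, run_smoke_case would typically place a small problem on the given device and return the solved coefficients.

```python
import math


def healthy_devices(devices, run_smoke_case, reference, rtol=1e-8):
    """Split `devices` into (healthy, rejected) using a small smoke solve.

    Hypothetical sketch: `run_smoke_case(device)` returns a sequence of
    floats to compare against `reference`.
    """
    healthy, rejected = [], []
    for device in devices:
        try:
            values = list(run_smoke_case(device))
        except Exception as exc:  # e.g. a solver library error on one device
            rejected.append((device, repr(exc)))
            continue

        # Reject NaN/inf explicitly; max() alone can hide a NaN.
        if not all(math.isfinite(v) for v in values):
            rejected.append((device, "non-finite coefficients"))
            continue

        worst = max(
            abs(v - r) / max(abs(r), 1e-300)
            for v, r in zip(values, reference)
        )
        if worst <= rtol:
            healthy.append(device)
        else:
            rejected.append((device, f"max relative error {worst:.3e}"))
    return healthy, rejected
```

Returning the rejected devices together with a reason makes the exclusion visible in logs, so a degraded device is not dropped silently.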
NTX also provides a separate multiprocess path:
python scripts/profile_multiprocess_runtime.py --backend gpu --workers 2
That path runs one Python worker per GPU with process-local
CUDA_VISIBLE_DEVICES pinning. It is the current robust route for office
hardware because it avoids the single-process cuSolver failure mode seen on
cuda:1.
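Process-local pinning can be sketched with the standard library alone. The function run_pinned_workers below is hypothetical and is not the code in scripts/profile_multiprocess_runtime.py; it only shows the general pattern of starting one fresh Python process per GPU with CUDA_VISIBLE_DEVICES set in that process's environment before any CUDA-backed library is imported.

```python
import os
import subprocess
import sys


def run_pinned_workers(gpu_ids, worker_code):
    """Run `worker_code` in one fresh Python process per GPU id.

    Hypothetical sketch. Returns the stripped stdout of each worker, in
    the same order as `gpu_ids`.
    """
    procs = []
    for gpu_id in gpu_ids:
        env = dict(os.environ)
        # Each child sees only its own GPU, which it addresses as device 0.
        env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        procs.append(
            subprocess.Popen(
                [sys.executable, "-c", worker_code],
                env=env,
                stdout=subprocess.PIPE,
                text=True,
            )
        )

    # All workers are already running; collect their results in order.
    outputs = []
    for proc in procs:
        out, _ = proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(f"worker exited with code {proc.returncode}")
        outputs.append(out.strip())
    return outputs
```

Because the variable is set per child rather than in the parent, the parent process never initializes CUDA itself, and a failure on one device stays confined to its own worker.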
Current Hardware Interpretation
The current GPU lane is numerically stable and validated on office hardware. For the small repository smoke cases, CPU remains faster in steady-state wall time. That is expected: these grids are small enough that GPU launch and transfer overheads dominate.
For the single-process profiler on office:
JAX sees two GPUs
only one passes the NTX dense-solve smoke check under the current stack
the guarded parallel path therefore runs on the healthy subset and preserves correct coefficients
For the multiprocess profiler on office:
both GPUs execute correctly when pinned to separate worker processes
coefficient deltas are zero at the repository smoke-case tolerance
wall time is still worse than the serial batched solve for the small smoke grids because process launch and per-worker compilation dominate
So the current guidance is:
use the serial batched JAX scan for small and medium studies
use the guarded single-process path only when all visible devices are healthy
use the multiprocess path for larger multi-GPU throughput workloads or for platforms that need strict one-process-per-GPU isolation
The current scaling figures and JSON payloads are documented in
Performance.