JAX Testing Guide
This guide covers the current JAX workflow in Stormlog: profiling JAX code directly, tracking JAX memory usage from the CLI, and exporting artifacts for later review.
Before you start
Validate the environment:
jaxmemprof info
If you are bringing up an accelerator runtime (CUDA or TPU), start with a basic JAX array operation before attempting complex tracking. These checks work on CPU-backed JAX installs as well.
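A minimal bring-up check along those lines might look like the sketch below. It uses only standard JAX calls (jax.default_backend, jax.devices, jnp.dot). The helper name smoke_check and the layout of the returned dictionary are illustrative and are not part of Stormlog.

```python
"""Minimal JAX bring-up check (illustrative; not part of Stormlog)."""


def smoke_check(size: int = 64) -> dict:
    """Run one small array operation and report the backend JAX selected."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as exc:
        # JAX itself is missing or broken, so memory tracking cannot work yet.
        return {"ok": False, "error": str(exc)}

    x = jnp.ones((size, size))
    y = jnp.dot(x, x)
    # JAX dispatches work asynchronously, so wait for the result first.
    y.block_until_ready()

    return {
        "ok": True,
        "backend": jax.default_backend(),  # e.g. "cpu", "gpu", or "tpu"
        "devices": [str(d) for d in jax.devices()],
        # Every entry of y equals `size`, so the sum should be size ** 3.
        "checksum": float(y.sum()),
    }


if __name__ == "__main__":
    print(smoke_check())
```

If the reported backend is not the one you expected, resolve that before moving on to profiling or CLI tracking.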
Daily workflow: ML engineer
Use JAXMemoryProfiler when you want snapshots and aggregate results around a real JAX workload.
import jax.numpy as jnp
from stormlog.jax import JAXMemoryProfiler
profiler = JAXMemoryProfiler()
with profiler.profile_context("training"):
    x = jnp.ones((1000, 1000))
    y = jnp.dot(x, x)
    # JAX dispatches work asynchronously, so block until the result is
    # ready before leaving the profiled region.
    y.block_until_ready()
results = profiler.get_results()
print(f"Peak memory: {results.peak_memory_mb:.2f} MB")
print(f"Snapshots captured: {len(results.snapshots)}")
Daily workflow: investigate sustained growth
The JAX CLI is the simplest way to capture longer-running telemetry:
jaxmemprof monitor --interval 0.5 --duration 30 --output jax_monitor.json
jaxmemprof track --interval 0.5 --output jax_track.json
jaxmemprof analyze --input jax_monitor.json --detect-leaks --optimize --report jax_report.txt
jaxmemprof diagnose --duration 0 --output ./jax_diag
On CPU-backed JAX installs, or when the accelerator backend is unavailable, jaxmemprof automatically falls back to CPU mode. To force CPU mode explicitly, pass --device cpu.
Recommended validation sequence
Use this when you need a compact JAX confidence pass:
jaxmemprof info
jaxmemprof monitor --interval 0.5 --duration 15 --output jax_monitor.json
jaxmemprof analyze --input jax_monitor.json --detect-leaks --optimize --report jax_report.txt
jaxmemprof diagnose --duration 0 --output ./jax_diag
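For repeated use, the sequence above could be wrapped in a small script that stops at the first failing step. The sketch below is one way to do that with the Python standard library. The commands are copied from the list above; the function name run_sequence, the injectable runner parameter, and the assumption that jaxmemprof exits with a non-zero status on failure are illustrative and are not taken from Stormlog.

```python
"""Run the compact validation sequence, stopping at the first failure."""
import subprocess
import sys

# Commands copied from the validation sequence in this guide.
VALIDATION_COMMANDS = [
    ["jaxmemprof", "info"],
    ["jaxmemprof", "monitor", "--interval", "0.5", "--duration", "15",
     "--output", "jax_monitor.json"],
    ["jaxmemprof", "analyze", "--input", "jax_monitor.json",
     "--detect-leaks", "--optimize", "--report", "jax_report.txt"],
    ["jaxmemprof", "diagnose", "--duration", "0", "--output", "./jax_diag"],
]


def run_sequence(commands, runner=subprocess.run):
    """Run each command in order and return the ones that completed.

    Assumes a failing step exits with a non-zero status (not verified
    against jaxmemprof). `runner` is injectable so the flow can be
    exercised without the CLI installed.
    """
    completed = []
    for cmd in commands:
        result = runner(cmd)
        if result.returncode != 0:
            print(f"step failed: {' '.join(cmd)}", file=sys.stderr)
            break
        completed.append(cmd)
    return completed
```

Calling run_sequence(VALIDATION_COMMANDS) and comparing the length of the returned list with the length of the full list gives a simple pass/fail signal for a CI job.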
Common issues
jaxmemprof runs on CPU when I expected GPU/TPU
Run:
jaxmemprof info
If the CLI reports that JAX is running on CPU, install the JAX variant that matches your hardware (e.g., jax[cuda12] or jax[tpu]).
Plot export fails
Install the visualization extra:
pip install "stormlog[viz]"