JAX Production Recipes
This guide covers operational recipes for monitoring and troubleshooting JAX workloads with Stormlog.
Profiling jax.jit functions
JAX compiles jax.jit functions with XLA, and reusing cached compilations is critical for performance and memory efficiency. You can profile jax.jit functions the same way as ordinary JAX operations; Stormlog tracks the underlying XLA allocations.
import jax
import jax.numpy as jnp
from stormlog.jax import JAXMemoryProfiler
profiler = JAXMemoryProfiler()
@jax.jit
def fast_training_step(x):
    return jnp.dot(x, x)
with profiler.profile_context("jitted_step"):
    x = jnp.ones((1000, 1000))
    y = fast_training_step(x)
    y.block_until_ready()
results = profiler.get_results()
print(f"Peak memory: {results.peak_memory_mb:.2f} MB")
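The caching behavior mentioned above can be observed without any profiler. The following is a minimal sketch that uses only public JAX APIs; `trace_count` and `scaled_dot` are illustrative names and are not part of Stormlog. Python side effects inside a jitted function run only while JAX traces it, so the counter increases once per distinct input shape and dtype and stays flat on cache hits.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # illustrative counter; only changes while JAX traces the function


@jax.jit
def scaled_dot(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jnp.dot(x, x)


a = jnp.ones((64, 64))
scaled_dot(a).block_until_ready()  # first call: trace and compile
scaled_dot(a).block_until_ready()  # same shape and dtype: cached executable is reused
traces_same_shape = trace_count

b = jnp.ones((32, 32))
scaled_dot(b).block_until_ready()  # new shape: triggers a retrace
traces_new_shape = trace_count

print(traces_same_shape, traces_new_shape)
```

If a profiled step retraces on every call, its measurements may include compilation overhead. Warming the function up outside the profiled region is one way to separate the two.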
Wrapping functions for telemetry tracking
For complex architectures or library code where context managers are intrusive, you can use the profile_function decorator to instrument a JAX function globally.
from stormlog.jax import profile_function
import jax.numpy as jnp
@profile_function(name="custom_matmul")
def custom_matmul(a, b):
    # This block will be transparently profiled
    res = jnp.dot(a, b)
    res.block_until_ready()
    return res
Hardware and Device Placement
Stormlog attributes tracked memory to the JAX device it was allocated on. If you are running on a multi-GPU or multi-TPU setup with jax.sharding or jax.pmap, Stormlog aggregates memory profiles across the requested device scopes.
Ensure that the tracking target matches your runtime:
CUDA: Requires jax[cuda12]
TPU: Requires jax[tpu]
CPU: Standard jax installation (used by jaxmemprof automatically if no accelerators are present)
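Before profiling, it can help to confirm which backend the installed JAX build actually selected. The sketch below relies only on public JAX calls (`jax.default_backend`, `jax.devices`, and the device `memory_stats` method) and makes no assumptions about Stormlog itself. Which keys `memory_stats` reports varies by platform, and it may return nothing at all on CPU.

```python
import jax

backend = jax.default_backend()  # name of the default platform, e.g. "cpu"
devices = jax.devices()          # devices belonging to the default backend

print(f"Default backend: {backend}")
for d in devices:
    # memory_stats() may return None on platforms without allocator statistics
    stats = d.memory_stats()
    in_use = stats.get("bytes_in_use") if stats else None
    print(f"  {d.platform}:{d.id} bytes_in_use={in_use}")
```

If the reported backend is not the one you expected, the installed jax extra probably does not match the hardware.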
Advanced memory analytics
If you have exported a jax_track.json log using the CLI, you can load it into the Python API for offline heuristics such as fragmentation checks or leak detection.
from stormlog.jax.analyzer import MemoryAnalyzer
from stormlog.telemetry import TelemetryEventV2
# Assuming you loaded tracking events from a JSON log
events = []  # placeholder: replace with the events loaded from the JSON log
analyzer = MemoryAnalyzer()
findings = analyzer.analyze_memory_gaps(events)
for finding in findings:
    print(f"Gap detected: {finding.severity}")
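The empty `events` list above still has to be filled from the exported file. The following is a rough sketch of that loading step using only the standard library. The schema of jax_track.json is not described here, so the helper `load_raw_events`, the `"events"` key, and the sample field names are all assumptions made for illustration; adjust them to the real export format.

```python
import json
import tempfile
from pathlib import Path


def load_raw_events(path):
    """Return the list of raw event dicts stored in a JSON log.

    Assumption: the file holds either a top-level list of events or an
    object with an "events" list. The real export schema may differ.
    """
    with Path(path).open(encoding="utf-8") as fh:
        payload = json.load(fh)
    if isinstance(payload, dict):
        payload = payload.get("events", [])
    if not isinstance(payload, list):
        raise ValueError(f"Unexpected JSON structure in {path}")
    return payload


# Demo with a made-up file; these field names are placeholders, not Stormlog's schema.
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "jax_track.json"
    sample.write_text(
        json.dumps({"events": [{"name": "step_0"}, {"name": "step_1"}]}),
        encoding="utf-8",
    )
    raw_events = load_raw_events(sample)

print(len(raw_events))
```

The raw dictionaries would still need converting into whatever event type `analyze_memory_gaps` expects (the example imports `TelemetryEventV2`). That conversion is left out because the constructor is not documented in this section.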