
JAX Production Recipes

This guide covers operational recipes for monitoring and troubleshooting JAX workloads with Stormlog.

Profiling jax.jit functions

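JAX compiles jax.jit functions with XLA and caches the compiled executable, keyed on input shapes and dtypes. The first call pays for tracing and compilation, and later calls with matching inputs reuse the cached executable, which is why caching is critical for performance and memory efficiency. You profile a jax.jit function the same way as any other JAX operation, and Stormlog tracks the underlying XLA allocations.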

import jax
import jax.numpy as jnp
from stormlog.jax import JAXMemoryProfiler

profiler = JAXMemoryProfiler()

@jax.jit
def fast_training_step(x):
    return jnp.dot(x, x)

with profiler.profile_context("jitted_step"):
    x = jnp.ones((1000, 1000))
    y = fast_training_step(x)
    # JAX dispatches work asynchronously; wait here so it finishes inside the profiled context
    y.block_until_ready()

results = profiler.get_results()
print(f"Peak memory: {results.peak_memory_mb:.2f} MB")
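
Because compiled executables are cached per input shape and dtype, it helps to know when a call triggers a fresh trace, since that call also includes compilation work. The sketch below uses only plain JAX (no Stormlog calls) and counts traces with a Python side effect, which runs during tracing and not on cached calls. The counter name trace_count is an illustrative assumption.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # illustrative counter, not part of any library


@jax.jit
def fast_training_step(x):
    global trace_count
    # Python side effects run only while JAX traces the function,
    # so this increments once per new input shape/dtype signature.
    trace_count += 1
    return jnp.dot(x, x)


x = jnp.ones((1000, 1000))
fast_training_step(x).block_until_ready()  # first call: trace + compile
fast_training_step(x).block_until_ready()  # same shape: cached executable
fast_training_step(jnp.ones((500, 500))).block_until_ready()  # new shape: retrace

print(trace_count)  # 2 traces: one per distinct input shape
```

If a profiled region looks unexpectedly expensive, one thing worth checking is whether it contains a first call or a call with a new input shape.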

Wrapping functions for telemetry tracking

For complex architectures or library code where context managers are intrusive, you can apply the profile_function decorator where a JAX function is defined. Every call to that function is then profiled without any changes at the call sites.

from stormlog.jax import profile_function
import jax.numpy as jnp

@profile_function(name="custom_matmul")
def custom_matmul(a, b):
    # This block will be transparently profiled
    res = jnp.dot(a, b)
    res.block_until_ready()
    return res
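
To illustrate why a decorator leaves call sites untouched, here is a minimal stand-in written with the standard library only. It is not Stormlog's implementation: timed_function and CALL_LOG are hypothetical names, and the sketch records wall-clock time rather than device memory.

```python
import functools
import time

# Hypothetical registry for this sketch: name -> list of elapsed seconds
CALL_LOG = {}


def timed_function(name):
    """Stand-in decorator: records one timing entry per call under `name`."""
    def decorator(fn):
        @functools.wraps(fn)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return fn(*args, **kwargs)
            finally:
                # Record the call even if fn raised
                CALL_LOG.setdefault(name, []).append(time.perf_counter() - start)
        return wrapper
    return decorator


@timed_function(name="scaled_sum")
def scaled_sum(values, scale):
    return scale * sum(values)


# The call site is unchanged; instrumentation happens inside the wrapper
result = scaled_sum([1, 2, 3], scale=2)
print(result, len(CALL_LOG["scaled_sum"]))
```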

Hardware and Device Placement

Stormlog attributes tracked memory to the JAX devices that hold it. On a multi-GPU or multi-TPU setup that uses jax.sharding or jax.pmap, Stormlog aggregates memory profiles across the requested device scopes.

Ensure that the tracking target matches your runtime:

  • CUDA: Requires jax[cuda12]

  • TPU: Requires jax[tpu]

  • CPU: Standard jax installation (used by jaxmemprof automatically if no accelerators are present)
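
Before profiling, it can be useful to confirm which backend JAX actually selected. The sketch below uses only plain JAX calls (jax.default_backend, jax.devices, and the per-device memory_stats method) and makes no assumptions about Stormlog itself. Note that memory_stats may return None on backends that do not report allocator statistics, such as CPU.

```python
import jax

backend = jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"
devices = jax.devices()          # devices of the default backend

print(f"Default backend: {backend}")
for d in devices:
    # memory_stats() may be None when the backend exposes no allocator stats
    stats = d.memory_stats()
    in_use = stats.get("bytes_in_use") if stats else None
    print(f"  {d.platform}:{d.id} bytes_in_use={in_use}")
```

If the backend printed here is "cpu" on a machine with accelerators, the installed JAX build probably does not match the hardware.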

Advanced memory analytics

If you have exported a jax_track.json log using the CLI, you can load it into the Python API for offline analysis, such as fragmentation checks or leak detection.

from stormlog.jax.analyzer import MemoryAnalyzer
from stormlog.telemetry import TelemetryEventV2

# Placeholder: replace this empty list with the tracking events loaded from your JSON log
events = []

analyzer = MemoryAnalyzer()
findings = analyzer.analyze_memory_gaps(events)

for finding in findings:
    print(f"Gap detected: {finding.severity}")
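
One possible way to fill the events placeholder is sketched below, using only the standard json module. It assumes the export is a top-level JSON array of event objects, which this page does not confirm. The helper names parse_events and load_events are hypothetical, and whether the analyzer accepts plain dictionaries or requires TelemetryEventV2 instances is also not specified here.

```python
import json


def parse_events(text):
    """Parse JSON text into a list of event dicts.

    Assumption: the log is a top-level JSON array of objects.
    """
    data = json.loads(text)
    if not isinstance(data, list):
        raise ValueError("expected a top-level JSON array of events")
    return data


def load_events(path):
    """Read a JSON log file from disk and return its events."""
    with open(path, encoding="utf-8") as f:
        return parse_events(f.read())


# Hypothetical usage, assuming the export sits in the working directory:
# events = load_events("jax_track.json")
```

If the analyzer expects typed events, the dictionaries returned here would still need converting before being passed to analyze_memory_gaps.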
