stormlog.jax.context_profiler

JAX Context Profiling.

Provides module-level convenience functions, a high-level JAXProfiler class, and a ProfiledFunction wrapper analogous to the ProfiledModule/ProfiledLayer classes in the PyTorch and TensorFlow context profilers.

Because JAX follows a functional paradigm (no nn.Module hierarchy), ProfiledFunction wraps arbitrary callables rather than layer objects.

Functions

clear_global_profiler()

Reset and discard the global profiler.

clear_profiles()

Reset profiling data without discarding the global profiler.

get_global_profiler()

Get or create the global JAXMemoryProfiler instance.

get_profile_summaries([limit])

Return aggregated profiling summaries from the global profiler.

profile_context([name, profiler])

Context manager for profiling a block of code.

profile_function([func, name, profiler])

Decorator to profile a function's JAX device-memory usage.

set_global_profiler(profiler)

Replace the global JAXMemoryProfiler instance.

Classes

JAXProfiler([device_index])

High-level JAX profiling interface.

ProfiledFunction(func[, profiler, name])

Wrapper that automatically profiles every call to a function.

stormlog.jax.context_profiler.get_global_profiler()

Get or create the global JAXMemoryProfiler instance.

Delegates to the singleton managed in stormlog.jax.profiler.

Returns:

The global JAXMemoryProfiler.

Return type:

JAXMemoryProfiler

stormlog.jax.context_profiler.set_global_profiler(profiler)

Replace the global JAXMemoryProfiler instance.

Parameters:

profiler (JAXMemoryProfiler) – New profiler instance to install as the global singleton.

Return type:

None

stormlog.jax.context_profiler.clear_global_profiler()

Reset and discard the global profiler.

After this call, get_global_profiler() will create a fresh instance on the next invocation.

Return type:

None

stormlog.jax.context_profiler.clear_profiles()

Reset profiling data without discarding the global profiler.

Return type:

None

stormlog.jax.context_profiler.get_profile_summaries(limit=None)

Return aggregated profiling summaries from the global profiler.

Parameters:

limit (int | None) – Maximum number of summaries to return.

Returns:

A list of dicts, one per profiled function/context block, sorted by peak memory (descending).

Return type:

List[Dict[str, Any]]
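One plausible shape for this aggregation is sketched below, under stated assumptions. Each raw record is a hypothetical (name, peak_bytes) tuple, the summary keys name, calls, and peak_bytes are illustrative only (the real dict keys are not listed here), and treating limit=None as "no cap" is an assumption. summarize is a hypothetical helper, not part of stormlog.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Tuple


def summarize(
    records: Iterable[Tuple[str, int]], limit: Optional[int] = None
) -> List[Dict[str, Any]]:
    """Aggregate (name, peak_bytes) records per name, largest peak first."""
    calls: Dict[str, int] = defaultdict(int)
    peaks: Dict[str, int] = defaultdict(int)
    for name, peak in records:
        calls[name] += 1
        peaks[name] = max(peaks[name], peak)
    summaries = [
        {"name": n, "calls": calls[n], "peak_bytes": peaks[n]} for n in calls
    ]
    # Sort by peak memory, descending, matching the documented ordering.
    summaries.sort(key=lambda s: s["peak_bytes"], reverse=True)
    # Assumption: limit=None means "return every summary".
    return summaries if limit is None else summaries[:limit]
```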

stormlog.jax.context_profiler.profile_function(func=None, *, name=None, profiler=None)

Decorator to profile a function’s JAX device-memory usage.

Can be used bare (@profile_function) or with keyword arguments (@profile_function(name="custom")).

Parameters:
  • func (F | None) – Function to profile (when used as @profile_function).

  • name (str | None) – Custom name for the profiled function. Defaults to func.__name__.

  • profiler (JAXMemoryProfiler | None) – Explicit JAXMemoryProfiler to use. Falls back to the global profiler when None.

Returns:

Decorated callable (or decorator factory when called with keyword arguments).

Return type:

Callable[[F], F] | F
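Supporting both @profile_function and @profile_function(name=...) relies on a common Python pattern: when the function arrives positionally it is wrapped immediately; otherwise a decorator is returned. The minimal sketch below shows that pattern only. The module-level calls list is a hypothetical stand-in for the profiler's bookkeeping, and no device memory is measured.

```python
import functools
from typing import Any, Callable, List, Optional, TypeVar

F = TypeVar("F", bound=Callable[..., Any])

calls: List[str] = []  # hypothetical stand-in for the profiler's records


def profile_function(func: Optional[F] = None, *, name: Optional[str] = None) -> Any:
    def decorator(fn: F) -> F:
        label = name or fn.__name__

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            calls.append(label)  # a real profiler would record memory here
            return fn(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    if func is not None:
        # Bare use: @profile_function receives the function itself.
        return decorator(func)
    # Keyword use: @profile_function(name=...) must return a decorator.
    return decorator


@profile_function
def add(x, y):
    return x + y


@profile_function(name="custom")
def mul(x, y):
    return x * y
```

functools.wraps keeps the wrapped function's __name__ and docstring, so decorated functions still look like the originals to introspection tools.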

stormlog.jax.context_profiler.profile_context(name='context', profiler=None)

Context manager for profiling a block of code.

Parameters:
  • name (str) – Label for the profiled block.

  • profiler (JAXMemoryProfiler | None) – Explicit profiler. Falls back to the global profiler.

Yields:

The JAXMemoryProfiler being used.

Return type:

Iterator[JAXMemoryProfiler]

Example:

import jax.numpy as jnp
a = b = jnp.ones((512, 512))
with profile_context("matmul") as prof:
    result = jnp.dot(a, b).block_until_ready()  # wait for async dispatch

class stormlog.jax.context_profiler.ProfiledFunction(func, profiler=None, name=None)

Bases: object

Wrapper that automatically profiles every call to a function.

JAX does not use an nn.Module class hierarchy, so instead of wrapping a layer or module, this class wraps any callable (pure functions, closures, jax.jit-compiled functions, etc.).

Parameters:
  • func (Callable[..., Any]) – The callable to profile.

  • profiler (JAXMemoryProfiler | None) – Explicit JAXMemoryProfiler. Falls back to the global profiler when None.

  • name (str | None) – Label used in profiling output. Defaults to the callable’s __name__, or to its __class__.__name__ when the callable has no __name__ (for example, an instance of a class that defines __call__).

Example:

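import jax.numpy as jnp
def forward_fn(params, batch): return batch @ params  # any callable can be wrapped
profiled_forward = ProfiledFunction(forward_fn, name="forward")
output = profiled_forward(jnp.ones((4, 2)), jnp.ones((8, 4)))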
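A wrapper of this kind is typically a small class with a __call__ method. The sketch below mirrors the documented name default (__name__, else __class__.__name__) and simply counts calls. ProfiledFunctionSketch and its call_count attribute are hypothetical, not part of the documented API, and no device memory is measured.

```python
from typing import Any, Callable, Optional


class ProfiledFunctionSketch:
    def __init__(self, func: Callable[..., Any], name: Optional[str] = None) -> None:
        self.func = func
        # Plain functions have __name__; callable instances usually do not,
        # so fall back to the class name.
        self.name = name or getattr(func, "__name__", type(func).__name__)
        self.call_count = 0

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        self.call_count += 1  # a real profiler would record memory here
        return self.func(*args, **kwargs)


class Scale:
    """A callable object with no __name__ of its own."""

    def __call__(self, x: float) -> float:
        return 2 * x
```

Because the wrapper forwards *args and **kwargs unchanged, it works the same for plain functions, closures, and callable objects.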

class stormlog.jax.context_profiler.JAXProfiler(device_index=0)

Bases: object

High-level JAX profiling interface.

Provides convenience methods for profiling training loops and inference passes, analogous to stormlog.context_profiler.MemoryProfiler (PyTorch) and stormlog.tensorflow.context_profiler.TensorFlowProfiler.

Parameters:

device_index (int) – Index of the JAX device to monitor (default 0).

Example:

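import jax.numpy as jnp
dataset = [jnp.ones((32, 8)) for _ in range(4)]  # a list is re-iterable
train_step = lambda batch: jnp.sum(batch ** 2)  # placeholder training step
jp = JAXProfiler()
jp.profile_training(train_step, dataset, epochs=3)
result = jp.get_results()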

profile_training(train_step_fn, dataset, epochs=1, steps_per_epoch=None)

Profile a JAX training loop.

The caller supplies a train_step_fn that is invoked once per batch. train_step_fn should accept the batch as its first positional argument; any other inputs it needs (such as parameters or optimizer state) should be closed over or bound in advance, for example with functools.partial.

Parameters:
  • train_step_fn (Callable[..., Any]) – A callable (batch) -> Any that executes a single training step.

  • dataset (Any) – An iterable of batches. It must be re-iterable (for example, a list) when training for more than one epoch, because a generator is exhausted after the first epoch. Each epoch iterates over the full dataset, or over at most steps_per_epoch batches when that cap is set.

  • epochs (int) – Number of epochs to profile.

  • steps_per_epoch (int | None) – Optional cap on the number of steps per epoch.

Return type:

None
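The per-epoch iteration described above can be sketched as follows. run_epochs is a hypothetical helper, not part of stormlog: it caps each epoch with itertools.islice, binds extra state with functools.partial so the step function takes only the batch, and shows why a generator dataset yields no batches after the first epoch. No memory is measured.

```python
import functools
import itertools
from typing import Any, Callable, Iterable, List, Optional


def run_epochs(
    train_step_fn: Callable[[Any], Any],
    dataset: Iterable[Any],
    epochs: int = 1,
    steps_per_epoch: Optional[int] = None,
) -> List[int]:
    """Call train_step_fn once per batch; return the step count of each epoch."""
    steps_done = []
    for _ in range(epochs):
        # iter() on a list restarts from the beginning; on a generator it does not.
        batches = iter(dataset)
        if steps_per_epoch is not None:
            batches = itertools.islice(batches, steps_per_epoch)
        count = 0
        for batch in batches:
            train_step_fn(batch)
            count += 1
        steps_done.append(count)
    return steps_done


def train_step(params: float, batch: List[float]) -> float:
    # Toy "step"; the extra params argument is bound below with partial.
    return params * sum(batch)


step = functools.partial(train_step, 0.5)
```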

profile_inference(inference_fn, data, batch_size=32)

Profile a JAX inference pass.

If data is an iterable of batches, it is consumed directly; otherwise, it is treated as a single array-like and sliced into batches of batch_size along its leading dimension.

Parameters:
  • inference_fn (Callable[..., Any]) – A callable (batch) -> Any that runs inference on a single batch.

  • data (Any) – Input data – either an iterable of batches or a single array-like with a leading batch dimension.

  • batch_size (int) – Batch size used when data must be sliced.

Return type:

None
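How data might be split into batches is sketched below. The hasattr(data, "shape") test for "array-like" is an assumption for illustration (the real check is not documented here), and iter_batches is a hypothetical helper. NumPy arrays stand in for JAX arrays, which expose the same shape attribute and slicing.

```python
from typing import Any, Iterator

import numpy as np


def iter_batches(data: Any, batch_size: int = 32) -> Iterator[Any]:
    if hasattr(data, "shape"):
        # Single array-like: slice along the leading (batch) axis.
        # The final batch may be smaller than batch_size.
        for start in range(0, data.shape[0], batch_size):
            yield data[start : start + batch_size]
    else:
        # Already an iterable of batches: pass each through unchanged.
        yield from data


x = np.arange(70).reshape(70, 1)
```

A ragged final batch is usually acceptable for inference, although under jax.jit each new batch shape triggers a separate compilation.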

get_results()

Return the aggregated ProfileResult.

Return type:

ProfileResult

reset()

Clear all captured profiling data.

Return type:

None