stormlog.jax.context_profiler
JAX Context Profiling.
Provides module-level convenience functions, a high-level JAXProfiler
class, and a ProfiledFunction wrapper analogous to the
ProfiledModule/ProfiledLayer classes in the PyTorch and TensorFlow
context profilers.
Because JAX follows a functional paradigm (no nn.Module hierarchy),
ProfiledFunction wraps arbitrary callables rather than layer objects.
Functions
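| clear_global_profiler | Reset and discard the global profiler. |
| clear_profiles | Reset profiling data without discarding the global profiler. |
| get_global_profiler | Get or create the global JAXMemoryProfiler instance. |
| get_profile_summaries | Return aggregated profiling summaries from the global profiler. |
| profile_context | Context manager for profiling a block of code. |
| profile_function | Decorator to profile a function's JAX device-memory usage. |
| set_global_profiler | Replace the global JAXMemoryProfiler instance. |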
Classes
| JAXProfiler | High-level JAX profiling interface. |
| ProfiledFunction | Wrapper that automatically profiles every call to a function. |
- stormlog.jax.context_profiler.get_global_profiler()
Get or create the global JAXMemoryProfiler instance.
Delegates to the singleton managed in stormlog.jax.profiler.
- Returns:
The global JAXMemoryProfiler.
- Return type:
JAXMemoryProfiler
- stormlog.jax.context_profiler.set_global_profiler(profiler)
Replace the global JAXMemoryProfiler instance.
- Parameters:
profiler (JAXMemoryProfiler) – New profiler instance to install as the global singleton.
- Return type:
None
- stormlog.jax.context_profiler.clear_global_profiler()
Reset and discard the global profiler.
After this call, get_global_profiler() will create a fresh instance on the next invocation.
- Return type:
None
- stormlog.jax.context_profiler.clear_profiles()
Reset profiling data without discarding the global profiler.
- Return type:
None
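The four functions above manage one module-level singleton. The following is a minimal sketch of that get-or-create, replace, reset, and discard pattern. It is not stormlog's implementation: `StubProfiler` and its `records` and `reset` members are hypothetical stand-ins for `JAXMemoryProfiler`, and the functions are re-implemented here purely for illustration.

```python
from typing import List, Optional


class StubProfiler:
    """Hypothetical stand-in for JAXMemoryProfiler; not the real class."""

    def __init__(self) -> None:
        self.records: List[str] = []

    def reset(self) -> None:
        # Drop collected data but keep the instance alive.
        self.records.clear()


_global_profiler: Optional[StubProfiler] = None


def get_global_profiler() -> StubProfiler:
    """Return the singleton, creating it lazily on first use."""
    global _global_profiler
    if _global_profiler is None:
        _global_profiler = StubProfiler()
    return _global_profiler


def set_global_profiler(profiler: StubProfiler) -> None:
    """Install a caller-supplied instance as the singleton."""
    global _global_profiler
    _global_profiler = profiler


def clear_profiles() -> None:
    """Reset the data; the same instance stays installed."""
    get_global_profiler().reset()


def clear_global_profiler() -> None:
    """Reset the data and discard the instance."""
    global _global_profiler
    if _global_profiler is not None:
        _global_profiler.reset()
    _global_profiler = None


first = get_global_profiler()
first.records.append("matmul")
clear_profiles()
same = get_global_profiler()   # still the same object, now empty
clear_global_profiler()
fresh = get_global_profiler()  # a new object, created lazily
```

The practical difference is object identity: after `clear_profiles()` any reference a caller already holds remains the active profiler, while after `clear_global_profiler()` such a reference points at a discarded instance.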
- stormlog.jax.context_profiler.get_profile_summaries(limit=None)
Return aggregated profiling summaries from the global profiler.
- Parameters:
limit (int | None) – Maximum number of summaries to return.
- Returns:
A list of dicts, one per profiled function/context block, sorted by peak memory (descending).
- Return type:
List[Dict[str, Any]]
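The ordering and the effect of limit can be illustrated with plain dicts. This is a sketch under assumptions: the keys `name` and `peak_memory` and the helper `summarize` are hypothetical, since this reference does not list the actual fields of each summary dict.

```python
from typing import Any, Dict, List, Optional


def summarize(
    records: List[Dict[str, Any]], limit: Optional[int] = None
) -> List[Dict[str, Any]]:
    """Sort by peak memory, largest first, then apply the optional cap."""
    ordered = sorted(records, key=lambda r: r["peak_memory"], reverse=True)
    return ordered if limit is None else ordered[:limit]


# Hypothetical records; the real dict keys are an assumption.
records = [
    {"name": "forward", "peak_memory": 512},
    {"name": "matmul", "peak_memory": 2048},
    {"name": "loss", "peak_memory": 128},
]

top_two = summarize(records, limit=2)  # matmul, then forward
everything = summarize(records)        # all three, largest first
```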
- stormlog.jax.context_profiler.profile_function(func=None, *, name=None, profiler=None)
Decorator to profile a function’s JAX device-memory usage.
Can be used bare (@profile_function) or with keyword arguments (@profile_function(name="custom")).
- Parameters:
func (F | None) – Function to profile (when used as @profile_function).
name (str | None) – Custom name for the profiled function. Defaults to func.__name__.
profiler (JAXMemoryProfiler | None) – Explicit JAXMemoryProfiler to use. Falls back to the global profiler when None.
- Returns:
Decorated callable (or decorator factory when called with keyword arguments).
- Return type:
Callable[[F], F] | F
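The `func=None, *, name=None` signature is what allows both the bare and the keyword form. The sketch below shows one common way to write such a dual-mode decorator. It is an illustration only, not stormlog's code: the `calls` list is a hypothetical stand-in for a profiler's record store, and no memory is measured.

```python
import functools
from typing import Any, Callable, List, Optional, TypeVar

F = TypeVar("F", bound=Callable[..., Any])

calls: List[str] = []  # hypothetical stand-in for a profiler's records


def profile_function(func: Optional[F] = None, *, name: Optional[str] = None):
    """Sketch of a decorator usable bare or with keyword arguments."""

    def decorate(target: F) -> F:
        label = name or target.__name__

        @functools.wraps(target)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            calls.append(label)  # a real profiler would measure memory here
            return target(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    # Bare use: Python passes the decorated function directly as `func`.
    if func is not None:
        return decorate(func)
    # Keyword use: return the decorator so Python applies it next.
    return decorate


@profile_function
def double(x: int) -> int:
    return 2 * x


@profile_function(name="custom")
def triple(x: int) -> int:
    return 3 * x


results = (double(2), triple(2))
```

Making `name` keyword-only is what keeps the two forms unambiguous: a positional argument can only ever be the function being decorated.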
- stormlog.jax.context_profiler.profile_context(name='context', profiler=None)
Context manager for profiling a block of code.
- Parameters:
name (str) – Label for the profiled block.
profiler (JAXMemoryProfiler | None) – Explicit profiler. Falls back to the global profiler.
- Yields:
The JAXMemoryProfiler being used.
- Return type:
Iterator[JAXMemoryProfiler]
Example:
    import jax
    from stormlog.jax.context_profiler import profile_context

    with profile_context("matmul") as prof:
        result = jax.numpy.dot(a, b)
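For readers unfamiliar with generator-based context managers, the sketch below shows the general shape of a profiling block that yields its profiler and records a measurement on exit. It rests on several assumptions: `StubProfiler` is a hypothetical stand-in, the measurement uses the standard library's `tracemalloc` (host memory) in place of JAX device memory, and the fallback creates a fresh stub instead of a global profiler.

```python
import tracemalloc
from contextlib import contextmanager
from typing import Dict, Iterator, List, Optional


class StubProfiler:
    """Hypothetical stand-in; stores host-memory peaks, not device memory."""

    def __init__(self) -> None:
        self.blocks: List[Dict[str, object]] = []


@contextmanager
def profile_context(
    name: str = "context", profiler: Optional[StubProfiler] = None
) -> Iterator[StubProfiler]:
    # Assumption: fall back to a new stub; the real function uses the global.
    prof = profiler if profiler is not None else StubProfiler()
    tracemalloc.start()
    try:
        yield prof  # the body of the `with` block runs here
    finally:
        # Record the peak even if the block raised. Nesting is not handled.
        _, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        prof.blocks.append({"name": name, "peak_bytes": peak})


with profile_context("alloc") as prof:
    buffer = bytearray(1_000_000)
```

Recording inside `finally` means a block that raises still produces an entry, which is usually what a profiler wants.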
- class stormlog.jax.context_profiler.ProfiledFunction(func, profiler=None, name=None)
Bases: object
Wrapper that automatically profiles every call to a function.
JAX does not use an nn.Module class hierarchy, so instead of wrapping a layer/module this class wraps any callable (pure functions, closures, jax.jit-compiled functions, etc.).
- Parameters:
func (Callable[..., Any]) – The callable to profile.
profiler (Optional['stormlog.jax.profiler.JAXMemoryProfiler']) – Explicit JAXMemoryProfiler. Falls back to the global profiler when None.
name (Optional[str]) – Label used in profiling output. Defaults to the callable’s __name__ or __class__.__name__.
Example:
    from stormlog.jax.context_profiler import ProfiledFunction

    profiled_forward = ProfiledFunction(forward_fn, name="forward")
    output = profiled_forward(params, batch)
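The name fallback matters because not every callable has a `__name__`. The sketch below illustrates a callable wrapper with that fallback. `ProfiledFunctionSketch` and its `call_log` attribute are hypothetical and only mimic the documented defaulting rule; they do not profile memory.

```python
import functools
from typing import Any, Callable, List, Optional


class ProfiledFunctionSketch:
    """Illustrative wrapper; not the real ProfiledFunction."""

    def __init__(self, func: Callable[..., Any], name: Optional[str] = None) -> None:
        self.func = func
        # Plain functions expose __name__; functools.partial objects and
        # callable instances may not, so fall back to the class name.
        self.name = name or getattr(func, "__name__", type(func).__name__)
        self.call_log: List[str] = []  # stand-in for profiler records

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        self.call_log.append(self.name)  # a real wrapper would profile here
        return self.func(*args, **kwargs)


def forward(params: int, batch: List[int]) -> List[int]:
    return [params * x for x in batch]


profiled_forward = ProfiledFunctionSketch(forward)
profiled_partial = ProfiledFunctionSketch(functools.partial(forward, 2))

output = profiled_forward(3, [1, 2])
partial_output = profiled_partial([1, 2])
```

Passing an explicit name is the safer choice for partials and lambdas, whose default labels (`partial`, `<lambda>`) are not distinctive in profiling output.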
- class stormlog.jax.context_profiler.JAXProfiler(device_index=0)
Bases: object
High-level JAX profiling interface.
Provides convenience methods for profiling training loops and inference passes, analogous to stormlog.context_profiler.MemoryProfiler (PyTorch) and stormlog.tensorflow.context_profiler.TensorFlowProfiler.
- Parameters:
device_index (int) – Index of the JAX device to monitor (default 0).
Example:
    from stormlog.jax.context_profiler import JAXProfiler

    jp = JAXProfiler()
    jp.profile_training(train_step, dataset, epochs=3)
    result = jp.get_results()
- profile_training(train_step_fn, dataset, epochs=1, steps_per_epoch=None)
Profile a JAX training loop.
The caller supplies a train_step_fn that is invoked once per batch. train_step_fn should accept a single batch as its first positional argument (additional arguments can be closed over or passed through the function itself).
- Parameters:
train_step_fn (Callable[[...], Any]) – A callable (batch) -> Any that executes a single training step.
dataset (Any) – An iterable of batches (must be re-iterable for multi-epoch training; generators are exhausted after epoch 0). Each epoch iterates over the full dataset (or up to steps_per_epoch batches).
epochs (int) – Number of epochs to profile.
steps_per_epoch (int | None) – Optional cap on the number of steps per epoch.
- Return type:
None
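The re-iterability requirement and the steps_per_epoch cap are easiest to see in a bare loop. The sketch below is an assumption about the loop's shape, not stormlog's code: `run_epochs` is a hypothetical helper that only counts steps and does no profiling.

```python
from itertools import islice
from typing import Any, Callable, Iterable, List, Optional


def run_epochs(
    train_step_fn: Callable[[Any], Any],
    dataset: Iterable[Any],
    epochs: int = 1,
    steps_per_epoch: Optional[int] = None,
) -> List[int]:
    """Run the loop and return the number of steps executed per epoch."""
    steps_run = []
    for _ in range(epochs):
        count = 0
        # islice with a stop of None iterates the whole dataset.
        for batch in islice(dataset, steps_per_epoch):
            train_step_fn(batch)
            count += 1
        steps_run.append(count)
    return steps_run


def train_step(batch: List[int]) -> int:
    return sum(batch)


batches = [[1, 2], [3, 4], [5, 6]]

# A list can be iterated again in every epoch.
list_steps = run_epochs(train_step, batches, epochs=2)
# The cap stops each epoch early.
capped_steps = run_epochs(train_step, batches, epochs=2, steps_per_epoch=2)
# A generator is consumed in epoch 0, so epoch 1 sees no batches.
generator_steps = run_epochs(train_step, (b for b in batches), epochs=2)
```

Here `list_steps` is `[3, 3]`, `capped_steps` is `[2, 2]`, and `generator_steps` is `[3, 0]`, which is the silent failure the dataset note warns about.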
- profile_inference(inference_fn, data, batch_size=32)
Profile a JAX inference pass.
If data is an iterable of batches it is consumed directly; otherwise it is treated as a single array-like and sliced into batches of batch_size.
- Parameters:
inference_fn (Callable[[...], Any]) – A callable (batch) -> Any that runs inference on a single batch.
data (Any) – Input data – either an iterable of batches or a single array-like with a leading batch dimension.
batch_size (int) – Batch size used when data must be sliced.
- Return type:
None
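Slicing a single array-like along its leading dimension can be sketched as follows. `slice_into_batches` is a hypothetical helper, and keeping the trailing partial batch is an assumption; this reference does not say how stormlog treats a final batch smaller than batch_size.

```python
from typing import Any, List, Sequence


def slice_into_batches(data: Sequence[Any], batch_size: int = 32) -> List[Sequence[Any]]:
    """Cut `data` into consecutive chunks along its first dimension."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # The last chunk is shorter when len(data) is not a multiple of batch_size.
    return [data[start:start + batch_size] for start in range(0, len(data), batch_size)]


samples = list(range(10))
batches = slice_into_batches(samples, batch_size=4)
# batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

The same slicing expression works on any object that supports `len()` and leading-axis slices, which includes NumPy and JAX arrays.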