stormlog.jax.profiler

JAX Memory Profiler.

Provides snapshot-based memory profiling, function/context profiling, and a global profiler singleton for JAX workloads.

Functions

clear_global_profiler()

Reset and discard the global profiler.

clear_profiles()

Reset the global profiler without discarding it.

get_global_profiler()

Get or create the global JAXMemoryProfiler instance.

get_profile_summaries([limit])

Return aggregated profile summaries from the global profiler.

set_global_profiler(profiler)

Replace the global profiler instance.

Classes

JAXMemoryProfiler([device_index])

JAX memory profiler with snapshot capture and function profiling.

MemorySnapshot(timestamp, name, ...[, ...])

Point-in-time JAX memory snapshot.

ProfileResult(start_time, end_time, ...[, ...])

Aggregated profiling results for a JAX session.

class stormlog.jax.profiler.MemorySnapshot(timestamp, name, device_memory_bytes, cpu_memory_bytes, device_id, device_memory_reserved_bytes=0, memory_stats=<factory>, operation_name=None)

Bases: object

Point-in-time JAX memory snapshot.

Parameters:
  • timestamp (float)

  • name (str)

  • device_memory_bytes (int)

  • cpu_memory_bytes (int)

  • device_id (int)

  • device_memory_reserved_bytes (int)

  • memory_stats (Dict[str, Any])

  • operation_name (str | None)

timestamp: float
name: str
device_memory_bytes: int
cpu_memory_bytes: int
device_id: int
device_memory_reserved_bytes: int = 0
memory_stats: Dict[str, Any]
operation_name: str | None = None
property device_memory_mb: float
property device_memory_reserved_mb: float
property cpu_memory_mb: float
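The `<factory>` in the signature is how autodoc renders a dataclass field built with `default_factory`. Each snapshot therefore gets its own fresh `memory_stats` dict. The standalone sketch below mirrors the documented fields so that behaviour is visible. `SnapshotSketch` is a hypothetical name, not a stormlog class. The megabyte properties assume MB means bytes / 1024**2; stormlog may divide by 10**6 instead.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

# Hypothetical stand-in for MemorySnapshot, for illustration only.
# ASSUMPTION: "MB" = mebibytes (bytes / 1024**2); the real properties may use 10**6.
BYTES_PER_MB = 1024 * 1024


@dataclass
class SnapshotSketch:
    timestamp: float
    name: str
    device_memory_bytes: int
    cpu_memory_bytes: int
    device_id: int
    device_memory_reserved_bytes: int = 0
    # Mutable defaults need default_factory; autodoc shows this as "<factory>".
    memory_stats: Dict[str, Any] = field(default_factory=dict)
    operation_name: Optional[str] = None

    @property
    def device_memory_mb(self) -> float:
        return self.device_memory_bytes / BYTES_PER_MB

    @property
    def device_memory_reserved_mb(self) -> float:
        return self.device_memory_reserved_bytes / BYTES_PER_MB

    @property
    def cpu_memory_mb(self) -> float:
        return self.cpu_memory_bytes / BYTES_PER_MB


snap = SnapshotSketch(timestamp=0.0, name="after_init",
                      device_memory_bytes=256 * BYTES_PER_MB,
                      cpu_memory_bytes=64 * BYTES_PER_MB, device_id=0)
print(snap.device_memory_mb, snap.cpu_memory_mb)  # 256.0 64.0
```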
class stormlog.jax.profiler.ProfileResult(start_time, end_time, peak_memory_bytes, average_memory_bytes, min_memory_bytes, snapshots=<factory>, function_profiles=<factory>)

Bases: object

Aggregated profiling results for a JAX session.

Parameters:
  • start_time (float)

  • end_time (float)

  • peak_memory_bytes (int)

  • average_memory_bytes (int)

  • min_memory_bytes (int)

  • snapshots (List[MemorySnapshot])

  • function_profiles (Dict[str, Dict[str, Any]])

start_time: float
end_time: float
peak_memory_bytes: int
average_memory_bytes: int
min_memory_bytes: int
snapshots: List[MemorySnapshot]
function_profiles: Dict[str, Dict[str, Any]]
property duration: float
property peak_memory_mb: float
property average_memory_mb: float
property min_memory_mb: float
property memory_growth_rate: float

Memory growth rate in bytes/second.
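This reference does not say how the aggregate fields are computed. The sketch below shows one plausible derivation from a series of (timestamp, device bytes) readings; it is not stormlog's code. The helper name `summarize` is hypothetical. The growth-rate formula (last reading minus first, divided by elapsed seconds) is an assumption, and stormlog might instead fit a slope or use start_time/end_time as endpoints.

```python
from typing import List, Tuple

# Hypothetical helper, for illustration only. Each sample is (timestamp_s, device_bytes).
# ASSUMPTION: growth rate = (last reading - first reading) / duration, in bytes/second.
def summarize(samples: List[Tuple[float, int]]) -> dict:
    if not samples:
        raise ValueError("need at least one snapshot")
    ordered = sorted(samples)  # chronological order by timestamp
    values = [v for _, v in ordered]
    duration = ordered[-1][0] - ordered[0][0]
    growth = (ordered[-1][1] - ordered[0][1]) / duration if duration > 0 else 0.0
    return {
        "duration": duration,
        "peak_memory_bytes": max(values),
        "average_memory_bytes": sum(values) // len(values),  # int, like the field type
        "min_memory_bytes": min(values),
        "memory_growth_rate": growth,
    }


stats = summarize([(0.0, 100), (1.0, 300), (2.0, 200)])
print(stats["peak_memory_bytes"], stats["memory_growth_rate"])  # 300 50.0
```

A single snapshot has zero duration, so the sketch reports a growth rate of 0.0 rather than dividing by zero.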

class stormlog.jax.profiler.JAXMemoryProfiler(device_index=0)

Bases: object

JAX memory profiler with snapshot capture and function profiling.

Provides:

  • capture_snapshot() – point-in-time device memory reading

  • profile_function() – decorator-based before/after profiling

  • profile_context() – with-block profiling

  • start_continuous_profiling() / stop_continuous_profiling() – periodic snapshot capture in the background

  • get_results() – aggregate into ProfileResult

Example:

from stormlog.jax.profiler import JAXMemoryProfiler

profiler = JAXMemoryProfiler()
with profiler:
    snapshot = profiler.capture_snapshot("after_init")
result = profiler.get_results()

Parameters:

device_index (int)

capture_snapshot(name='snapshot', *, operation_name=None)

Capture a point-in-time memory snapshot.

Parameters:
  • name (str) – Human-readable label for this snapshot.

  • operation_name (str | None) – Optional operation being profiled.

Returns:

A MemorySnapshot.

Return type:

MemorySnapshot
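The source does not specify how a device reading is obtained. JAX exposes per-device allocator statistics through `jax.Device.memory_stats()`, which returns a dict or `None`; `None` occurs on some backends, such as CPU. The sketch below shows one way such a dict could be mapped onto snapshot fields. It is a hypothetical illustration, not stormlog's implementation. `snapshot_fields` and `read_device_stats` are invented names. The keys `"bytes_in_use"` and `"bytes_reserved"` are assumptions about what the backend reports, so missing keys default to 0.

```python
import time
from typing import Any, Dict, Optional

# Hypothetical sketch: turn a jax.Device.memory_stats() result into snapshot-style fields.
# ASSUMPTION: "bytes_in_use" / "bytes_reserved" are the relevant keys; they vary by
# backend, and the stats object may be None, so missing values fall back to 0.
def snapshot_fields(name: str, device_id: int,
                    stats: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    stats = dict(stats or {})
    return {
        "timestamp": time.time(),
        "name": name,
        "device_id": device_id,
        "device_memory_bytes": int(stats.get("bytes_in_use", 0)),
        "device_memory_reserved_bytes": int(stats.get("bytes_reserved", 0)),
        "memory_stats": stats,
    }


def read_device_stats(device_index: int = 0) -> Optional[Dict[str, Any]]:
    import jax  # imported lazily so snapshot_fields works without JAX installed
    return jax.devices()[device_index].memory_stats()
```

With JAX available, `snapshot_fields("after_init", 0, read_device_stats(0))` would produce one reading for device 0.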

profile_function(func=None, *, name=None)

Decorator that profiles a function’s memory impact.

Usage:

@profiler.profile_function
def train_step():
    ...

# or with options:
@profiler.profile_function(name="custom_name")
def train_step():
    ...

Parameters:
  • func (F | None)

  • name (str | None)

Return type:

Any
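The signature profile_function(func=None, *, name=None) is the usual pattern for a decorator that works both bare and with arguments. The sketch below illustrates that pattern with a before/after memory delta. It is not stormlog's code. `DecoratorSketch` and its `read_memory` callable are hypothetical, and `read_memory` stands in for a real device reading. The per-function entry keys (`calls`, `total_delta_bytes`) are also invented for illustration.

```python
import functools
from typing import Any, Callable, Dict, Optional

# Hypothetical sketch of the optional-argument decorator pattern.
# read_memory is an assumed stand-in for a device-memory reading, not a stormlog API.
class DecoratorSketch:
    def __init__(self, read_memory: Callable[[], int]) -> None:
        self.read_memory = read_memory
        self.function_profiles: Dict[str, Dict[str, Any]] = {}

    def profile_function(self, func: Optional[Callable] = None, *,
                         name: Optional[str] = None) -> Any:
        def decorate(f: Callable) -> Callable:
            label = name or f.__name__

            @functools.wraps(f)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                before = self.read_memory()
                try:
                    return f(*args, **kwargs)
                finally:
                    # Record the delta even if the wrapped function raises.
                    after = self.read_memory()
                    entry = self.function_profiles.setdefault(
                        label, {"calls": 0, "total_delta_bytes": 0})
                    entry["calls"] += 1
                    entry["total_delta_bytes"] += after - before
            return wrapper

        # Bare use passes the function directly; @profile_function(name=...) passes None.
        return decorate(func) if func is not None else decorate


readings = iter([0, 100, 100, 150])  # fake memory readings for the demo
profiler = DecoratorSketch(lambda: next(readings))


@profiler.profile_function
def train_step():
    return "ok"


@profiler.profile_function(name="custom_name")
def eval_step():
    return "ok"


train_step()
eval_step()
print(profiler.function_profiles)
# {'train_step': {'calls': 1, 'total_delta_bytes': 100}, 'custom_name': {'calls': 1, 'total_delta_bytes': 50}}
```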

profile_context(name='context')

Context manager that captures before/after snapshots.

Usage:

import jax.numpy as jnp
a = b = jnp.ones((512, 512))
with profiler.profile_context("matmul") as snap_before:
    result = jnp.dot(a, b)

Parameters:

name (str)

Return type:

Iterator[MemorySnapshot]
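The Iterator[MemorySnapshot] return type is what a generator wrapped by `contextlib.contextmanager` is typically annotated with. The sketch below shows that mechanism with plain tuples in place of real snapshots. It is not stormlog's implementation. `make_profile_context` and `read_memory` are hypothetical. Yielding the "before" snapshot is inferred from the `snap_before` name in the usage example, not confirmed by the source.

```python
import time
from contextlib import contextmanager
from typing import Callable, Iterator, List, Tuple

# Hypothetical sketch; snapshots are (label, timestamp, bytes) tuples and
# read_memory is an assumed stand-in for a device reading, not a stormlog API.
Snapshot = Tuple[str, float, int]


def make_profile_context(read_memory: Callable[[], int], log: List[Snapshot]):
    @contextmanager
    def profile_context(name: str = "context") -> Iterator[Snapshot]:
        before = (f"{name}:before", time.time(), read_memory())
        log.append(before)
        try:
            yield before  # ASSUMPTION: the with-block receives the "before" snapshot
        finally:
            # The "after" snapshot is recorded even if the block raises.
            log.append((f"{name}:after", time.time(), read_memory()))
    return profile_context


readings = iter([1_000, 5_000])  # fake memory readings for the demo
log: List[Snapshot] = []
profile_context = make_profile_context(lambda: next(readings), log)

with profile_context("matmul") as snap_before:
    pass  # real code would run e.g. jnp.dot(a, b) here

print(snap_before[2], log[-1][2] - log[0][2])  # 1000 4000
```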

start_continuous_profiling(interval=1.0)

Start capturing snapshots in the background, one every interval seconds (default 1.0).

Parameters:

interval (float)

Return type:

None

stop_continuous_profiling()

Stop the background snapshot loop.

Return type:

None
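One common way to build a start/stop background sampling loop is a daemon thread driven by a `threading.Event`. The sketch below uses that technique; whether stormlog does the same is an assumption. `BackgroundSampler` and its `capture` callback are hypothetical, and `capture` stands in for a call to capture_snapshot().

```python
import threading
import time
from typing import Callable, List, Optional

# Hypothetical sketch; capture is an assumed stand-in for capture_snapshot().
class BackgroundSampler:
    def __init__(self, capture: Callable[[], None]) -> None:
        self._capture = capture
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None

    def start(self, interval: float = 1.0) -> None:
        if self._thread is not None and self._thread.is_alive():
            return  # already running
        self._stop.clear()

        def loop() -> None:
            # Event.wait returns False on timeout and True once stop() sets the event,
            # so stopping does not have to wait out a full interval.
            while not self._stop.wait(interval):
                self._capture()

        self._thread = threading.Thread(target=loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        self._stop.set()
        if self._thread is not None:
            self._thread.join()  # no captures happen after this returns
            self._thread = None


ticks: List[int] = []
sampler = BackgroundSampler(lambda: ticks.append(1))
sampler.start(interval=0.01)
time.sleep(0.2)
sampler.stop()
```

Waiting on the event instead of calling `time.sleep` lets `stop()` wake the thread immediately.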

get_results()

Aggregate captured snapshots into a ProfileResult.

Return type:

ProfileResult

reset()

Clear all captured data.

Return type:

None

stormlog.jax.profiler.get_global_profiler()

Get or create the global JAXMemoryProfiler instance.

Return type:

JAXMemoryProfiler

stormlog.jax.profiler.set_global_profiler(profiler)

Replace the global profiler instance.

Parameters:

profiler (JAXMemoryProfiler)

Return type:

None

stormlog.jax.profiler.clear_global_profiler()

Reset and discard the global profiler.

Return type:

None

stormlog.jax.profiler.clear_profiles()

Reset the global profiler without discarding it.

Return type:

None
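The four global-profiler functions describe a lazily created, replaceable module-level singleton. The key difference is between "reset and discard" and "reset without discarding". This standalone sketch illustrates those semantics. It is not stormlog's source. `ProfilerStub` is a hypothetical stand-in for JAXMemoryProfiler, and the lock reflects an assumption that callers may use the singleton from several threads.

```python
import threading
from typing import Optional

# Hypothetical sketch of the singleton semantics; ProfilerStub stands in for
# JAXMemoryProfiler, and the lock is an assumption about thread safety.
class ProfilerStub:
    def __init__(self) -> None:
        self.snapshots: list = []

    def reset(self) -> None:
        self.snapshots.clear()


_global: Optional[ProfilerStub] = None
_lock = threading.Lock()


def get_global_profiler() -> ProfilerStub:
    global _global
    with _lock:
        if _global is None:  # created on first use
            _global = ProfilerStub()
        return _global


def set_global_profiler(profiler: ProfilerStub) -> None:
    global _global
    with _lock:
        _global = profiler


def clear_global_profiler() -> None:
    # Reset and discard: the next get_global_profiler() builds a new instance.
    global _global
    with _lock:
        if _global is not None:
            _global.reset()
        _global = None


def clear_profiles() -> None:
    # Reset without discarding: the same instance survives with empty data.
    get_global_profiler().reset()
```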

stormlog.jax.profiler.get_profile_summaries(limit=None)

Return aggregated profile summaries from the global profiler.

Parameters:

limit (int | None) – Maximum number of summaries to return.

Returns:

A list of dicts, one per profiled function/context block.

Return type:

List[Dict[str, Any]]
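The reference describes the output only as one dict per profiled function or context block, capped by limit. The sketch below shows one way to flatten a function_profiles-style mapping into that shape. It is hypothetical: `profile_summaries` is an invented name, and the per-entry keys and largest-peak-first ordering are illustrative assumptions rather than documented stormlog behaviour.

```python
from typing import Any, Dict, List, Optional

# Hypothetical sketch. ASSUMPTIONS: entry keys and the largest-peak-first ordering
# are illustrative choices, not documented stormlog behaviour.
def profile_summaries(function_profiles: Dict[str, Dict[str, Any]],
                      limit: Optional[int] = None) -> List[Dict[str, Any]]:
    rows = [{"name": name, **stats} for name, stats in function_profiles.items()]
    rows.sort(key=lambda row: row.get("peak_memory_bytes", 0), reverse=True)
    return rows if limit is None else rows[:max(limit, 0)]


profiles = {
    "train_step": {"calls": 3, "peak_memory_bytes": 4096},
    "matmul": {"calls": 1, "peak_memory_bytes": 8192},
}
print(profile_summaries(profiles, limit=1))
# [{'name': 'matmul', 'calls': 1, 'peak_memory_bytes': 8192}]
```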