stormlog.jax.profiler
JAX Memory Profiler.
Provides snapshot-based memory profiling, function/context profiling, and a global profiler singleton for JAX workloads.
Functions
- clear_global_profiler(): Reset and discard the global profiler.
- Reset the global profiler without discarding it.
- get_global_profiler(): Get or create the global JAXMemoryProfiler instance.
- Return aggregated profile summaries from the global profiler.
- set_global_profiler(profiler): Replace the global profiler instance.
Classes
- JAXMemoryProfiler: JAX memory profiler with snapshot capture and function profiling.
- MemorySnapshot: Point-in-time JAX memory snapshot.
- ProfileResult: Aggregated profiling results for a JAX session.
- class stormlog.jax.profiler.MemorySnapshot(timestamp, name, device_memory_bytes, cpu_memory_bytes, device_id, device_memory_reserved_bytes=0, memory_stats=<factory>, operation_name=None)[source]
Bases: object
Point-in-time JAX memory snapshot.
- Parameters:
timestamp (float)
name (str)
device_memory_bytes (int)
cpu_memory_bytes (int)
device_id (int)
device_memory_reserved_bytes (int)
memory_stats (Dict[str, Any])
operation_name (str | None)
- timestamp: float
- name: str
- device_memory_bytes: int
- cpu_memory_bytes: int
- device_id: int
- device_memory_reserved_bytes: int = 0
- memory_stats: Dict[str, Any]
- operation_name: str | None = None
- property device_memory_mb: float
- property device_memory_reserved_mb: float
- property cpu_memory_mb: float
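The three *_mb properties are derived from the corresponding byte fields. The sketch below uses a stand-in dataclass (SnapshotSketch, a hypothetical name that is not part of stormlog) to show one plausible conversion. The divisor of 1024 * 1024 is an assumption: this page does not say whether "MB" means 10^6 or 2^20 bytes.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

_BYTES_PER_MB = 1024 * 1024  # assumption: binary megabytes (2**20 bytes)


@dataclass
class SnapshotSketch:
    """Hypothetical stand-in mirroring the documented MemorySnapshot fields."""

    timestamp: float
    name: str
    device_memory_bytes: int
    cpu_memory_bytes: int
    device_id: int
    device_memory_reserved_bytes: int = 0
    memory_stats: Dict[str, Any] = field(default_factory=dict)
    operation_name: Optional[str] = None

    @property
    def device_memory_mb(self) -> float:
        return self.device_memory_bytes / _BYTES_PER_MB

    @property
    def device_memory_reserved_mb(self) -> float:
        return self.device_memory_reserved_bytes / _BYTES_PER_MB

    @property
    def cpu_memory_mb(self) -> float:
        return self.cpu_memory_bytes / _BYTES_PER_MB


snap = SnapshotSketch(
    timestamp=0.0,
    name="after_init",
    device_memory_bytes=512 * 1024 * 1024,
    cpu_memory_bytes=256 * 1024 * 1024,
    device_id=0,
)
print(snap.device_memory_mb)  # 512.0 under the binary-MB assumption
```

The memory_stats default uses a factory so that each snapshot gets its own dictionary, which matches the `<factory>` default shown in the signature above.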
- class stormlog.jax.profiler.ProfileResult(start_time, end_time, peak_memory_bytes, average_memory_bytes, min_memory_bytes, snapshots=<factory>, function_profiles=<factory>)[source]
Bases: object
Aggregated profiling results for a JAX session.
- Parameters:
start_time (float)
end_time (float)
peak_memory_bytes (int)
average_memory_bytes (int)
min_memory_bytes (int)
snapshots (List[MemorySnapshot])
function_profiles (Dict[str, Dict[str, Any]])
- start_time: float
- end_time: float
- peak_memory_bytes: int
- average_memory_bytes: int
- min_memory_bytes: int
- snapshots: List[MemorySnapshot]
- function_profiles: Dict[str, Dict[str, Any]]
- property duration: float
- property peak_memory_mb: float
- property average_memory_mb: float
- property min_memory_mb: float
- property memory_growth_rate: float
Memory growth rate in bytes/second.
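This page states only that memory_growth_rate is expressed in bytes/second, not how it is computed. One simple way to obtain such a figure, shown here as an assumption rather than as the library's method, is to divide the net change between the first and last sample by the elapsed time. The function name growth_rate_bytes_per_second is hypothetical.

```python
from typing import Sequence, Tuple


def growth_rate_bytes_per_second(samples: Sequence[Tuple[float, int]]) -> float:
    """Net memory change between first and last sample, divided by elapsed time.

    `samples` is a time-ordered sequence of (timestamp_seconds, memory_bytes).
    Returns 0.0 when there are fewer than two samples or no time has elapsed.
    """
    if len(samples) < 2:
        return 0.0
    (t0, m0), (t1, m1) = samples[0], samples[-1]
    elapsed = t1 - t0
    if elapsed <= 0:
        return 0.0
    return (m1 - m0) / elapsed


# Three samples over two seconds, growing from 1000 to 3000 bytes.
samples = [(0.0, 1_000), (1.0, 1_500), (2.0, 3_000)]
rate = growth_rate_bytes_per_second(samples)
print(rate)  # 1000.0
```

An endpoint difference ignores what happens between the first and last sample; a least-squares fit over all samples would be a more noise-tolerant alternative.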
- class stormlog.jax.profiler.JAXMemoryProfiler(device_index=0)[source]
Bases: object
JAX memory profiler with snapshot capture and function profiling.
Provides:
- capture_snapshot() – point-in-time device memory reading
- profile_function() – decorator-based before/after profiling
- profile_context() – with-block profiling
- start_continuous_profiling() / stop_continuous_profiling()
- get_results() – aggregate into ProfileResult
Example:
    profiler = JAXMemoryProfiler()
    with profiler:
        s = profiler.capture_snapshot("after_init")
    result = profiler.get_results()
- Parameters:
device_index (int)
- capture_snapshot(name='snapshot', *, operation_name=None)[source]
Capture a point-in-time memory snapshot.
- Parameters:
name (str) – Human-readable label for this snapshot.
operation_name (str | None) – Optional operation being profiled.
- Returns:
The captured snapshot.
- Return type:
MemorySnapshot
- profile_function(func=None, *, name=None)[source]
Decorator that profiles a function’s memory impact.
Usage:
    @profiler.profile_function
    def train_step():
        ...

    # or with options:
    @profiler.profile_function(name="custom_name")
    def train_step():
        ...
- Parameters:
func (F | None)
name (str | None)
- Return type:
Any
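The signature profile_function(func=None, *, name=None) is what allows both the bare form and the form with options. The sketch below illustrates that general pattern in plain Python. It is not stormlog's implementation: profile_function_sketch and the calls list are hypothetical, and the "before"/"after" records stand in for real memory snapshots.

```python
import functools
from typing import Any, Callable, Dict, List, Optional

# Hypothetical record store; a real profiler would keep memory snapshots here.
calls: List[Dict[str, Any]] = []


def profile_function_sketch(
    func: Optional[Callable[..., Any]] = None, *, name: Optional[str] = None
) -> Any:
    def decorate(f: Callable[..., Any]) -> Callable[..., Any]:
        label = name or f.__name__

        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            calls.append({"name": label, "phase": "before"})
            try:
                return f(*args, **kwargs)
            finally:
                # Runs even if f raises, so every "before" has an "after".
                calls.append({"name": label, "phase": "after"})

        return wrapper

    if func is not None:
        return decorate(func)  # used as a bare @decorator
    return decorate  # used as @decorator(name=...)


@profile_function_sketch
def train_step() -> str:
    return "ok"


@profile_function_sketch(name="custom_name")
def eval_step() -> str:
    return "done"


train_step()
eval_step()
```

Making name keyword-only keeps the two call forms unambiguous: a positional argument can only ever be the decorated function.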
- profile_context(name='context')[source]
Context manager that captures before/after snapshots.
Usage:
    with profiler.profile_context("matmul") as snap_before:
        result = jax.numpy.dot(a, b)
- Parameters:
name (str)
- Return type:
Iterator[MemorySnapshot]
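The Iterator[MemorySnapshot] return type is typical of a generator wrapped with contextlib.contextmanager, where the yielded value becomes the target of the as clause. The following is a minimal sketch of that pattern. The names take_snapshot and profile_context_sketch, the dictionary snapshots, and the "_before"/"_after" label suffixes are all assumptions made for illustration.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator, List

# Hypothetical snapshot store; plain dicts stand in for MemorySnapshot objects.
snapshots: List[Dict[str, object]] = []


def take_snapshot(name: str) -> Dict[str, object]:
    snap: Dict[str, object] = {"name": name, "timestamp": time.time()}
    snapshots.append(snap)
    return snap


@contextmanager
def profile_context_sketch(name: str = "context") -> Iterator[Dict[str, object]]:
    before = take_snapshot(f"{name}_before")
    try:
        yield before  # bound to the `as` target of the with statement
    finally:
        # Captured even if the body raises.
        take_snapshot(f"{name}_after")


with profile_context_sketch("matmul") as snap_before:
    total = sum(range(10))  # placeholder for the profiled work
```

Only the "before" snapshot is available inside the block, because the "after" snapshot does not exist until the block exits.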
- start_continuous_profiling(interval=1.0)[source]
Start capturing snapshots in the background, one every interval seconds.
- Parameters:
interval (float)
- Return type:
None
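Background capture at a fixed interval is commonly built from a daemon thread and a threading.Event. The sketch below shows that approach under stated assumptions: ContinuousSamplerSketch and its read_bytes callable are hypothetical, and nothing on this page confirms that stormlog uses threads internally.

```python
import threading
import time
from typing import Callable, List, Optional


class ContinuousSamplerSketch:
    """Hypothetical sampler that polls `read_bytes` from a background thread."""

    def __init__(self, read_bytes: Callable[[], int]) -> None:
        self._read_bytes = read_bytes
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self.samples: List[int] = []

    def start(self, interval: float = 1.0) -> None:
        if self._thread is not None:
            return  # already running
        self._stop.clear()
        self._thread = threading.Thread(
            target=self._run, args=(interval,), daemon=True
        )
        self._thread.start()

    def _run(self, interval: float) -> None:
        # Event.wait returns False on timeout and True once stop() sets it.
        while not self._stop.wait(interval):
            self.samples.append(self._read_bytes())

    def stop(self) -> None:
        if self._thread is None:
            return
        self._stop.set()
        self._thread.join()
        self._thread = None


sampler = ContinuousSamplerSketch(read_bytes=lambda: 1024)
sampler.start(interval=0.01)
time.sleep(0.2)
sampler.stop()
```

Waiting on the event instead of calling time.sleep lets stop() interrupt the wait immediately, so shutdown does not have to sit out a full interval.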
- get_results()[source]
Aggregate captured snapshots into a ProfileResult.
- Return type:
ProfileResult
- stormlog.jax.profiler.get_global_profiler()[source]
Get or create the global JAXMemoryProfiler instance.
- Return type:
JAXMemoryProfiler
- stormlog.jax.profiler.set_global_profiler(profiler)[source]
Replace the global profiler instance.
- Parameters:
profiler (JAXMemoryProfiler)
- Return type:
None
- stormlog.jax.profiler.clear_global_profiler()[source]
Reset and discard the global profiler.
- Return type:
None
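The get/set/clear trio above describes a lazily created module-level singleton. A minimal sketch of that pattern follows. ProfilerSketch and the *_sketch functions are hypothetical stand-ins, and the lock is an assumption, since this page does not state whether the real functions are thread-safe.

```python
import threading
from typing import Optional


class ProfilerSketch:
    """Hypothetical stand-in for JAXMemoryProfiler."""

    def __init__(self, device_index: int = 0) -> None:
        self.device_index = device_index


_global: Optional[ProfilerSketch] = None
_lock = threading.Lock()  # assumption: guard against concurrent creation


def get_global_sketch() -> ProfilerSketch:
    """Return the shared instance, creating it on first use."""
    global _global
    with _lock:
        if _global is None:
            _global = ProfilerSketch()
        return _global


def set_global_sketch(profiler: ProfilerSketch) -> None:
    """Replace the shared instance."""
    global _global
    with _lock:
        _global = profiler


def clear_global_sketch() -> None:
    """Discard the shared instance; the next get creates a new one."""
    global _global
    with _lock:
        _global = None


first = get_global_sketch()
same = get_global_sketch()
clear_global_sketch()
fresh = get_global_sketch()
```

In this sketch, repeated calls to get return the same object until clear is called, after which a new instance is created on demand.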