stormlog.jax
JAX support for Stormlog.
- class stormlog.jax.JAXMemoryProfiler(device_index=0)[source]
Bases: object
JAX memory profiler with snapshot capture and function profiling.
Provides:
capture_snapshot() – point-in-time device memory reading
profile_function() – decorator-based before/after profiling
profile_context() – with-block profiling
start_continuous_profiling() / stop_continuous_profiling()
get_results() – aggregate into ProfileResult
Example:
    profiler = JAXMemoryProfiler()
    with profiler:
        s = profiler.capture_snapshot("after_init")
    result = profiler.get_results()
- Parameters:
device_index (int)
- capture_snapshot(name='snapshot', *, operation_name=None)[source]
Capture a point-in-time memory snapshot.
- Parameters:
name (str) – Human-readable label for this snapshot.
operation_name (str | None) – Optional operation being profiled.
- Returns:
A MemorySnapshot.
- Return type:
MemorySnapshot
- profile_function(func=None, *, name=None)[source]
Decorator that profiles a function’s memory impact.
Usage:
    @profiler.profile_function
    def train_step(): ...

    # or with options:
    @profiler.profile_function(name="custom_name")
    def train_step(): ...
- Parameters:
func (F | None)
name (str | None)
- Return type:
Any
- profile_context(name='context')[source]
Context manager that captures before/after snapshots.
Usage:
    with profiler.profile_context("matmul") as snap_before:
        result = jax.numpy.dot(a, b)
- Parameters:
name (str)
- Return type:
Iterator[MemorySnapshot]
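The before/after pattern described for profile_context() can be sketched with contextlib. This is an illustrative stand-in, not Stormlog's implementation: Snapshot, snapshot_context, and read_memory_mb are hypothetical names, and the memory reader is a fake so the sketch runs without JAX.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Iterator, List


@dataclass
class Snapshot:
    # Hypothetical stand-in for a memory snapshot record.
    name: str
    memory_mb: float


@contextmanager
def snapshot_context(
    name: str,
    read_memory_mb: Callable[[], float],
    sink: List[Snapshot],
) -> Iterator[Snapshot]:
    """Record one snapshot before the block and one after it."""
    before = Snapshot(f"{name}_before", read_memory_mb())
    sink.append(before)
    try:
        yield before  # the caller receives the "before" snapshot
    finally:
        # Runs even if the block raises, so the pair is always complete.
        sink.append(Snapshot(f"{name}_after", read_memory_mb()))


# Fake memory reader: returns 100.0, then 140.0, on successive calls.
readings = iter([100.0, 140.0])
snapshots: List[Snapshot] = []

with snapshot_context("matmul", lambda: next(readings), snapshots) as snap_before:
    pass  # the profiled work would go here

delta_mb = snapshots[1].memory_mb - snap_before.memory_mb
```

Taking the second reading in a finally clause is a design choice of this sketch: a block that fails still leaves a matched pair of snapshots behind.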
- start_continuous_profiling(interval=1.0)[source]
Start background snapshot capture at interval seconds.
- Parameters:
interval (float)
- Return type:
None
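Background capture at a fixed interval is commonly built on a daemon thread and a threading.Event. The sketch below shows that general technique under stated assumptions: IntervalSampler and read_value are hypothetical names, and nothing here is taken from Stormlog's source.

```python
import threading
from typing import Callable, List, Optional


class IntervalSampler:
    """Hypothetical background sampler; not Stormlog's implementation."""

    def __init__(self, read_value: Callable[[], float]) -> None:
        self._read_value = read_value
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self.samples: List[float] = []

    def start(self, interval: float = 1.0) -> None:
        if self._thread is not None:
            return  # already running
        self._stop.clear()
        self._thread = threading.Thread(
            target=self._run, args=(interval,), daemon=True
        )
        self._thread.start()

    def _run(self, interval: float) -> None:
        # Take a sample, then wait up to `interval` seconds. wait() returns
        # True as soon as stop() sets the event, which ends the loop.
        while True:
            self.samples.append(self._read_value())
            if self._stop.wait(interval):
                break

    def stop(self) -> None:
        if self._thread is None:
            return
        self._stop.set()
        self._thread.join()
        self._thread = None


sampler = IntervalSampler(read_value=lambda: 42.0)
sampler.start(interval=0.01)
sampler.stop()
```

Waiting on the event instead of calling time.sleep() lets stop() return promptly even when the interval is long.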
- class stormlog.jax.MemoryTracker(sampling_interval=1.0, alert_threshold_mb=None, device_index=0, enable_logging=True, max_history=10000, job_id=None, rank=None, local_rank=None, world_size=None, telemetry_sink_config=None, save_device_profile_on_stop=False, enable_oom_flight_recorder=False, oom_dump_dir='oom_dumps', oom_buffer_size=None, oom_max_dumps=5, oom_max_total_mb=256)[source]
Bases: object
Real-time JAX device memory tracker.
- Parameters:
sampling_interval (float)
alert_threshold_mb (Optional[float])
device_index (int)
enable_logging (bool)
max_history (int)
job_id (Optional[str])
rank (Optional[int])
local_rank (Optional[int])
world_size (Optional[int])
telemetry_sink_config (Optional[TelemetrySinkConfig])
save_device_profile_on_stop (bool)
enable_oom_flight_recorder (bool)
oom_dump_dir (str)
oom_buffer_size (Optional[int])
oom_max_dumps (int)
oom_max_total_mb (int)
- get_session_summary()[source]
- Return type:
SessionSummary | None
- property oom_buffer_size: int
Resolved OOM ring-buffer size.
- add_alert_callback(callback)[source]
- Parameters:
callback (Callable[[Dict[str, Any]], None])
- Return type:
None
- remove_alert_callback(callback)[source]
Remove a previously registered alert callback.
- Parameters:
callback (Callable[[Dict[str, Any]], None])
- Return type:
None
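The add/remove callback pair follows a familiar observer pattern. The sketch below is a minimal illustration, not MemoryTracker itself: AlertRegistry, observe, and the keys of the alert dictionary are hypothetical, since the source only states that callbacks receive a Dict[str, Any].

```python
from typing import Any, Callable, Dict, List, Optional

AlertCallback = Callable[[Dict[str, Any]], None]


class AlertRegistry:
    """Hypothetical callback registry; not Stormlog's implementation."""

    def __init__(self, alert_threshold_mb: Optional[float] = None) -> None:
        self.alert_threshold_mb = alert_threshold_mb
        self._callbacks: List[AlertCallback] = []

    def add_alert_callback(self, callback: AlertCallback) -> None:
        self._callbacks.append(callback)

    def remove_alert_callback(self, callback: AlertCallback) -> None:
        # Ignore callbacks that were never registered.
        if callback in self._callbacks:
            self._callbacks.remove(callback)

    def observe(self, memory_mb: float) -> None:
        """Notify every callback when a sample exceeds the threshold."""
        threshold = self.alert_threshold_mb
        if threshold is None or memory_mb <= threshold:
            return
        # The alert payload keys are assumptions made for this sketch.
        alert = {"memory_mb": memory_mb, "threshold_mb": threshold}
        for callback in list(self._callbacks):
            callback(alert)


received: List[Dict[str, Any]] = []
on_alert = received.append

registry = AlertRegistry(alert_threshold_mb=1000.0)
registry.add_alert_callback(on_alert)
registry.observe(900.0)    # below the threshold: no alert
registry.observe(1200.0)   # above the threshold: one alert
registry.remove_alert_callback(on_alert)
registry.observe(1500.0)   # callback removed: no further alerts
```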
- enter_phase(name, *, metadata=None)[source]
- Parameters:
name (str)
metadata (Dict[str, Any] | None)
- Return type:
- phase(name, *, metadata=None)[source]
- Parameters:
name (str)
metadata (Dict[str, Any] | None)
- Return type:
Iterator[PhaseHandle]
- property last_oom_dump_path: str | None
Path to the most recent OOM dump bundle, or None.
- handle_exception(exc, context=None, metadata=None)[source]
Capture OOM diagnostics for recognized OOM exceptions.
- Parameters:
exc (BaseException)
context (str | None)
metadata (Dict[str, Any] | None)
- Return type:
str | None
- capture_oom(context='runtime', metadata=None)[source]
Capture an OOM diagnostic bundle if the wrapped block raises an OOM.
- Parameters:
context (str)
metadata (Dict[str, Any] | None)
- Return type:
Iterator[None]
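A context manager that records diagnostics only for OOM-like failures might look like the sketch below. Everything in it is an assumption made for illustration: capture_oom_sketch and looks_like_oom are hypothetical names, the RESOURCE_EXHAUSTED check is a heuristic based on how XLA typically words out-of-memory errors, and the source does not say how Stormlog recognises an OOM or whether it re-raises.

```python
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional


def looks_like_oom(exc: BaseException) -> bool:
    # Heuristic assumption: XLA out-of-memory errors usually mention
    # RESOURCE_EXHAUSTED in their message.
    return isinstance(exc, MemoryError) or "RESOURCE_EXHAUSTED" in str(exc)


@contextmanager
def capture_oom_sketch(
    dumps: List[Dict[str, Any]],
    context: str = "runtime",
    metadata: Optional[Dict[str, Any]] = None,
) -> Iterator[None]:
    """Append a diagnostic record when the wrapped block raises an OOM."""
    try:
        yield
    except BaseException as exc:
        if looks_like_oom(exc):
            dumps.append(
                {"context": context, "error": str(exc), "metadata": metadata or {}}
            )
        raise  # this sketch never swallows the original exception


dumps: List[Dict[str, Any]] = []

# An OOM-like error is recorded and still propagates to the caller.
try:
    with capture_oom_sketch(dumps, context="train_step"):
        raise RuntimeError("RESOURCE_EXHAUSTED: out of memory")
except RuntimeError:
    pass

# An unrelated error propagates without producing a record.
try:
    with capture_oom_sketch(dumps, context="train_step"):
        raise ValueError("bad shape")
except ValueError:
    pass
```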
- trigger_oom_dump(exception, context=None, metadata=None)[source]
Manually trigger an OOM diagnostic dump bundle.
- Parameters:
exception (BaseException)
context (str | None)
metadata (Dict[str, Any] | None)
- Return type:
str | None
- save_device_memory_profile(output_path)[source]
Save a JAX device memory profile to the given path.
Note
This method depends on jax.profiler.save_device_memory_profile, which is only available on GPU/TPU backends with JAX >= 0.4.1. On CPU-only installs or older JAX versions the call is a no-op and returns False. Availability is checked at runtime via hasattr guards, so no import error is raised.
- Parameters:
output_path (str)
- Return type:
bool
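The hasattr guard mentioned in the note can be sketched as follows. The helper name save_profile_if_supported is hypothetical, and fake module objects stand in for jax.profiler so the example runs without JAX installed.

```python
import types
from typing import Any, List


def save_profile_if_supported(profiler_module: Any, output_path: str) -> bool:
    """Call save_device_memory_profile only when the module provides it."""
    if not hasattr(profiler_module, "save_device_memory_profile"):
        return False  # nothing to call: report a no-op
    profiler_module.save_device_memory_profile(output_path)
    return True


# Fake modules stand in for jax.profiler in this sketch.
saved_paths: List[str] = []
new_profiler = types.SimpleNamespace(save_device_memory_profile=saved_paths.append)
old_profiler = types.SimpleNamespace()

ok_new = save_profile_if_supported(new_profiler, "memory.prof")
ok_old = save_profile_if_supported(old_profiler, "memory.prof")
```

With JAX installed, passing jax.profiler as profiler_module would exercise the real function instead of the fake.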
- class stormlog.jax.MemoryWatchdog(max_memory_mb=8000, cleanup_threshold_mb=6000, check_interval=5.0, device_index=0)[source]
Bases: object
Automatic memory management and cleanup for JAX workloads.
- Parameters:
max_memory_mb (float)
cleanup_threshold_mb (float)
check_interval (float)
device_index (int)
- add_cleanup_callback(callback)[source]
Add cleanup callback function.
- Parameters:
callback (Callable[[], None])
- Return type:
None
- force_cleanup(aggressive=False)[source]
Force immediate memory cleanup.
- Parameters:
aggressive (bool) – When True, also delete all live JAX arrays reachable via jax.live_arrays() (if available) before running garbage collection. Use with caution: this can invalidate arrays still referenced by user code.
- Return type:
None
- class stormlog.jax.MemoryAnalyzer(sensitivity=0.05, collective_sensitivity='medium', collective_threshold_overrides=None)[source]
Bases: object
Advanced analyzer for JAX memory profiling data.
Mirrors the public interface of the PyTorch and TensorFlow MemoryAnalyzer classes, adapting heuristics and recommendations for JAX’s XLA runtime and device memory model.
- Parameters:
sensitivity (float)
collective_sensitivity (str)
collective_threshold_overrides (Optional[Mapping[str, Any]])
- detect_memory_leaks(results)[source]
Detect potential memory leaks in JAX telemetry.
Uses linear regression over the memory-usage series to detect sustained upward drift.
- Parameters:
results (Any) – Object with a memory_usage attribute (numeric sequence of at least 10 samples).
- Returns:
List of leak-detection dicts with type, severity, description, and slope keys.
- Return type:
List[Dict[str, Any]]
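Drift detection by linear regression amounts to fitting a least-squares line to the samples and inspecting its slope. The sketch below illustrates the idea only: detect_drift, the min_slope_mb_per_sample cutoff, and the type and severity values are hypothetical, and only the 10-sample minimum and the four result keys come from the description above.

```python
from typing import Any, Dict, List, Sequence


def regression_slope(values: Sequence[float]) -> float:
    """Least-squares slope of values against their sample index."""
    n = len(values)
    mean_x = (n - 1) / 2.0
    mean_y = sum(values) / n
    cov = sum((i - mean_x) * (v - mean_y) for i, v in enumerate(values))
    var = sum((i - mean_x) ** 2 for i in range(n))
    return cov / var


def detect_drift(
    memory_usage: Sequence[float],
    min_slope_mb_per_sample: float = 0.5,  # hypothetical cutoff
) -> List[Dict[str, Any]]:
    # Mirror the documented minimum of 10 samples.
    if len(memory_usage) < 10:
        return []
    slope = regression_slope(memory_usage)
    if slope <= min_slope_mb_per_sample:
        return []
    return [
        {
            "type": "sustained_growth",  # hypothetical label
            "severity": "medium",        # hypothetical label
            "description": f"Memory grows by about {slope:.2f} MB per sample",
            "slope": slope,
        }
    ]


steady = [100.0] * 12
leaking = [100.0 + 2.0 * i for i in range(12)]

steady_findings = detect_drift(steady)
leak_findings = detect_drift(leaking)
```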
- detect_patterns(results)[source]
Detect allocation patterns in JAX telemetry.
Performs simplified autocorrelation analysis to identify periodic memory usage behaviour (e.g. per-step allocations in a training loop).
- Parameters:
results (Any) – Object with a memory_usage attribute (numeric sequence of at least 10 samples).
- Returns:
List of detected pattern dicts.
- Return type:
List[Dict[str, Any]]
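Periodicity can be found by scanning the autocorrelation of the series over candidate lags and keeping the strongest one. The following is a minimal sketch of that technique, not the analyzer's actual algorithm: dominant_period and the min_corr cutoff are hypothetical.

```python
from typing import Optional, Sequence


def autocorrelation(values: Sequence[float], lag: int) -> float:
    """Normalised autocorrelation of values at the given lag."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values)
    if var == 0.0:
        return 0.0  # a flat series has no periodic structure
    cov = sum((values[i] - mean) * (values[i + lag] - mean) for i in range(n - lag))
    return cov / var


def dominant_period(values: Sequence[float], min_corr: float = 0.5) -> Optional[int]:
    """Return the lag with the strongest autocorrelation above min_corr."""
    if len(values) < 10:
        return None
    best_lag: Optional[int] = None
    best_corr = min_corr
    for lag in range(2, len(values) // 2 + 1):
        corr = autocorrelation(values, lag)
        if corr > best_corr:
            best_lag, best_corr = lag, corr
    return best_lag


# A sawtooth that repeats every 4 samples, like per-step allocations.
sawtooth = [100.0, 150.0, 200.0, 250.0] * 6
flat = [100.0] * 12

sawtooth_period = dominant_period(sawtooth)
flat_period = dominant_period(flat)
```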
- analyze_fragmentation(profile_result)[source]
Analyse memory fragmentation patterns.
Computes fragmentation as 1 − (used / reserved) across profiling snapshots.
- Parameters:
profile_result (Any) – Profiling result with a snapshots attribute where each snapshot exposes device_memory_mb and device_memory_reserved_mb.
- Returns:
Dictionary with fragmentation_score, trend, max_fragmentation, and min_fragmentation.
- Return type:
Dict[str, float]
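The formula 1 − (used / reserved) can be applied per snapshot and then summarised. In the sketch below, the Snapshot dataclass is a hypothetical stand-in exposing the two documented attributes, and the choices of mean for fragmentation_score and last-minus-first for trend are assumptions, since the source does not define how those keys are computed.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence


@dataclass
class Snapshot:
    # Hypothetical stand-in exposing the two documented attributes.
    device_memory_mb: float
    device_memory_reserved_mb: float


def fragmentation_scores(snapshots: Sequence[Snapshot]) -> List[float]:
    """Apply 1 - (used / reserved) to every snapshot with a reserved reading."""
    scores = []
    for snap in snapshots:
        if snap.device_memory_reserved_mb <= 0:
            continue  # avoid dividing by zero
        scores.append(1.0 - snap.device_memory_mb / snap.device_memory_reserved_mb)
    return scores


def summarise_fragmentation(snapshots: Sequence[Snapshot]) -> Dict[str, float]:
    scores = fragmentation_scores(snapshots)
    if not scores:
        return {
            "fragmentation_score": 0.0,
            "trend": 0.0,
            "max_fragmentation": 0.0,
            "min_fragmentation": 0.0,
        }
    return {
        "fragmentation_score": sum(scores) / len(scores),  # assumption: mean
        "trend": scores[-1] - scores[0],  # assumption: last minus first
        "max_fragmentation": max(scores),
        "min_fragmentation": min(scores),
    }


summary = summarise_fragmentation(
    [Snapshot(800.0, 1000.0), Snapshot(600.0, 1000.0), Snapshot(500.0, 1000.0)]
)
```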
- analyze_efficiency(profile_result)[source]
Analyse memory usage efficiency.
Returns a score on a 0.0–1.0 scale (1.0 = excellent). The score starts at 1.0 and is reduced by penalties for high peak memory, high growth rate, fragmentation, and detected leaks.
- Parameters:
profile_result (Any) – Profiling result with peak_memory_mb and optionally memory_growth_rate, snapshots, and memory_usage.
- Returns:
Efficiency score in [0.0, 1.0].
- Return type:
float
- correlate_with_performance(profile_result)[source]
Correlate memory usage with performance metrics.
Analyses per-function efficiency based on memory consumption and execution duration.
- Parameters:
profile_result (Any) – Profiling result with function_profiles mapping function names to dicts containing calls, total_memory_delta, and total_duration.
- Returns:
Dictionary with memory_duration_correlation, function_efficiency, and recommendations.
- Return type:
Dict[str, Any]
- score_optimization(profile_result, events=None)[source]
Generate an overall optimisation score with recommendations.
Combines memory efficiency, fragmentation, and per-function performance scores into a single summary.
- Parameters:
profile_result (Any) – JAX profiling result object.
events (List | None) – Optional telemetry event series for gap analysis. When provided, the result includes gap_analysis and collective_attribution sections.
- Returns:
Dictionary with overall_score, categories, top_recommendations, and priority_actions.
- Return type:
Dict[str, Any]
- analyze_memory_gaps(events, *, phase_resolver=None)[source]
Classify allocator-vs-device hidden memory gaps over time.
- Parameters:
events (List) – Chronologically ordered telemetry samples.
phase_resolver (Any | None) – Optional PhaseReplayIndex for phase attribution.
- Returns:
Prioritised list of gap findings, sorted by descending severity and then by descending confidence. Returns an empty list when the gap_analysis sub-package is not available.
- Return type:
List
- analyze_collective_attribution(events, *, phase_resolver=None)[source]
Attribute hidden-memory spikes to collective communication phases.
- Parameters:
events (List) – Chronologically ordered telemetry samples.
phase_resolver (Any | None) – Optional PhaseReplayIndex for phase attribution.
- Returns:
List of CollectiveAttributionResult objects. Returns an empty list when the collective_attribution sub-package is not available.
- Return type:
List
- class stormlog.jax.JAXProfiler(device_index=0)[source]
Bases: object
High-level JAX profiling interface.
Provides convenience methods for profiling training loops and inference passes, analogous to stormlog.context_profiler.MemoryProfiler (PyTorch) and stormlog.tensorflow.context_profiler.TensorFlowProfiler.
- Parameters:
device_index (int) – Index of the JAX device to monitor (default 0).
Example:
    jp = JAXProfiler()
    jp.profile_training(train_step, dataset, epochs=3)
    result = jp.get_results()
- profile_training(train_step_fn, dataset, epochs=1, steps_per_epoch=None)[source]
Profile a JAX training loop.
The caller supplies a train_step_fn that is invoked once per batch. train_step_fn should accept a single batch as its first positional argument (additional arguments can be closed over or passed through the function itself).
- Parameters:
train_step_fn (Callable[[...], Any]) – A callable (batch) -> Any that executes a single training step.
dataset (Any) – An iterable of batches (must be re-iterable for multi-epoch training; generators are exhausted after epoch 0). Each epoch iterates over the full dataset (or up to steps_per_epoch batches).
epochs (int) – Number of epochs to profile.
steps_per_epoch (int | None) – Optional cap on the number of steps per epoch.
- Return type:
None
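The requirement that dataset be re-iterable is easiest to see with a plain epoch loop. The sketch below is not profile_training itself: run_epochs is a hypothetical helper that only counts steps, which is enough to show a generator running dry after the first epoch and the effect of steps_per_epoch.

```python
from itertools import islice
from typing import Any, Callable, Iterable, Optional


def run_epochs(
    train_step_fn: Callable[[Any], Any],
    dataset: Iterable[Any],
    epochs: int = 1,
    steps_per_epoch: Optional[int] = None,
) -> int:
    """Call train_step_fn once per batch and return the number of steps run."""
    steps = 0
    for _ in range(epochs):
        # islice(dataset, None) iterates over the whole dataset.
        for batch in islice(dataset, steps_per_epoch):
            train_step_fn(batch)
            steps += 1
    return steps


batches = [[1, 2], [3, 4], [5, 6]]

# A list can be iterated again in every epoch.
list_steps = run_epochs(sum, batches, epochs=2)

# A generator is exhausted after the first epoch.
gen_steps = run_epochs(sum, (b for b in batches), epochs=2)

# steps_per_epoch caps the number of batches consumed per epoch.
capped_steps = run_epochs(sum, batches, epochs=2, steps_per_epoch=2)
```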
- profile_inference(inference_fn, data, batch_size=32)[source]
Profile a JAX inference pass.
If data is an iterable of batches it is consumed directly; otherwise it is treated as a single array-like and sliced into batches of batch_size.
- Parameters:
inference_fn (Callable[[...], Any]) – A callable (batch) -> Any that runs inference on a single batch.
data (Any) – Input data: either an iterable of batches or a single array-like with a leading batch dimension.
batch_size (int) – Batch size used when data must be sliced.
- Return type:
None
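Slicing a single array-like into batches along its leading dimension can be sketched as below. slice_into_batches is a hypothetical helper, and the source does not say how profile_inference decides whether data is already batched, so that decision is left out.

```python
from typing import Any, Iterator, List


def slice_into_batches(data: Any, batch_size: int = 32) -> Iterator[Any]:
    """Yield consecutive slices of at most batch_size items along the leading axis."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


samples: List[int] = list(range(70))
batch_lengths = [len(batch) for batch in slice_into_batches(samples, batch_size=32)]
```

The helper relies only on len() and slicing, so it also applies to NumPy and JAX arrays, where the slice is taken along the leading dimension. The final batch is shorter whenever the length is not a multiple of batch_size.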
- class stormlog.jax.ProfiledFunction(func, profiler=None, name=None)[source]
Bases: object
Wrapper that automatically profiles every call to a function.
JAX does not use an nn.Module class hierarchy, so instead of wrapping a layer/module this class wraps any callable (pure functions, closures, jax.jit-compiled functions, etc.).
- Parameters:
func (Callable[..., Any]) – The callable to profile.
profiler (Optional['stormlog.jax.profiler.JAXMemoryProfiler']) – Explicit JAXMemoryProfiler. Falls back to the global profiler when None.
name (Optional[str]) – Label used in profiling output. Defaults to the callable’s __name__ or __class__.__name__.
Example:
    profiled_forward = ProfiledFunction(forward_fn, name="forward")
    output = profiled_forward(params, batch)
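The wrapping idea, including the documented fallback from __name__ to __class__.__name__, can be sketched with a small callable class. CallRecorder is a hypothetical stand-in that records call labels instead of measuring memory; it is not ProfiledFunction.

```python
from typing import Any, Callable, List, Optional


class CallRecorder:
    """Hypothetical wrapper that records every call; not Stormlog's class."""

    def __init__(self, func: Callable[..., Any], name: Optional[str] = None) -> None:
        self._func = func
        # Prefer an explicit name, then __name__, then the class name
        # (instances of callable classes have no __name__ attribute).
        self.name = name or getattr(func, "__name__", None) or type(func).__name__
        self.calls: List[str] = []

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        self.calls.append(self.name)  # a real profiler would snapshot memory here
        return self._func(*args, **kwargs)


def forward_fn(params: int, batch: List[int]) -> List[int]:
    return [params * x for x in batch]


class Scale:
    """A callable object, used to show the class-name fallback."""

    def __call__(self, x: int) -> int:
        return 2 * x


profiled_forward = CallRecorder(forward_fn, name="forward")
output = profiled_forward(2, [1, 2, 3])

default_name = CallRecorder(forward_fn).name
callable_name = CallRecorder(Scale()).name
```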
- stormlog.jax.profile_function(func=None, *, name=None, profiler=None)[source]
Decorator to profile a function’s JAX device-memory usage.
Can be used bare (@profile_function) or with keyword arguments (@profile_function(name="custom")).
- Parameters:
func (F | None) – Function to profile (when used as @profile_function).
name (str | None) – Custom name for the profiled function. Defaults to func.__name__.
profiler (JAXMemoryProfiler | None) – Explicit JAXMemoryProfiler to use. Falls back to the global profiler when None.
- Returns:
Decorated callable (or decorator factory when called with keyword arguments).
- Return type:
Callable[[F], F] | F
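A decorator that works both bare and with keyword arguments usually takes the function as an optional first parameter. The sketch below shows that general Python pattern: record_calls and call_log are hypothetical names, and the wrapper logs a label where a real profiler would measure memory.

```python
import functools
from typing import Any, Callable, List, Optional

call_log: List[str] = []


def record_calls(func: Optional[Callable[..., Any]] = None, *, name: Optional[str] = None):
    """Usable as @record_calls or as @record_calls(name="custom")."""

    def decorate(target: Callable[..., Any]) -> Callable[..., Any]:
        label = name or target.__name__

        @functools.wraps(target)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            call_log.append(label)  # a real profiler would measure memory here
            return target(*args, **kwargs)

        return wrapper

    if func is not None:
        return decorate(func)  # bare form: func is the decorated function
    return decorate  # keyword form: return the decorator itself


@record_calls
def train_step() -> str:
    return "trained"


@record_calls(name="custom_name")
def eval_step() -> str:
    return "evaluated"


train_result = train_step()
eval_result = eval_step()
```

Making name keyword-only keeps the two forms unambiguous: a positional argument is always the function being decorated.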
- stormlog.jax.profile_context(name='context', profiler=None)[source]
Context manager for profiling a block of code.
- Parameters:
name (str) – Label for the profiled block.
profiler (JAXMemoryProfiler | None) – Explicit profiler. Falls back to the global profiler.
- Yields:
The JAXMemoryProfiler being used.
- Return type:
Iterator[JAXMemoryProfiler]
Example:
    with profile_context("matmul") as prof:
        result = jax.numpy.dot(a, b)
- stormlog.jax.get_device_info(device_index=0)[source]
Return device kind, platform, and live memory statistics.
- Parameters:
device_index (int) – Index into jax.local_devices() (default 0).
- Returns:
Dictionary with keys kind, platform, device_id, process_index, memory_stats (raw dict from device.memory_stats()), and client.
- Return type:
Dict[str, Any]
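A dictionary of this shape could be assembled from a JAX device object roughly as follows. This is a sketch under assumptions: describe_device is a hypothetical helper, the mapping from device attributes to keys is a guess, and the client key is omitted because the source does not say how it is derived. A fake device keeps the example runnable without JAX.

```python
import types
from typing import Any, Dict


def describe_device(device: Any) -> Dict[str, Any]:
    """Collect most of the documented keys from a JAX-style device object."""
    stats = device.memory_stats() if hasattr(device, "memory_stats") else None
    return {
        "kind": getattr(device, "device_kind", None),
        "platform": getattr(device, "platform", None),
        "device_id": getattr(device, "id", None),
        "process_index": getattr(device, "process_index", None),
        # memory_stats() may return None on backends without statistics.
        "memory_stats": stats or {},
    }


# Fake device with the attribute names that JAX device objects expose.
fake_device = types.SimpleNamespace(
    device_kind="FakeGPU",
    platform="gpu",
    id=0,
    process_index=0,
    memory_stats=lambda: {"bytes_in_use": 1024, "bytes_limit": 4096},
)
info = describe_device(fake_device)

# A device whose backend reports no statistics yields an empty dict.
cpu_like = types.SimpleNamespace(
    device_kind="cpu", platform="cpu", id=0, process_index=0,
    memory_stats=lambda: None,
)
cpu_info = describe_device(cpu_like)
```

With JAX installed, describe_device(jax.local_devices()[0]) would read the same attributes from a real device.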
- stormlog.jax.get_system_info()[source]
Return full system and JAX environment report.
Includes JAX version, device list, platform, Python version, CPU count, and system memory statistics.
- Return type:
Dict[str, Any]
Modules
JAX Memory Analysis.
Stormlog-native memory visualisation for JAX (Directed Graph Dashboard).
JAX Stormlog CLI
JAX Context Profiling.
Diagnostic bundle builder for the JAX Stormlog diagnose command.
JAX environment configuration for Stormlog.
Generated protocol buffer code.
JAX Memory Profiler.
Real-time JAX Memory Tracking.
Utility functions for JAX memory profiling.
JAX Memory Visualization