stormlog.jax

JAX support for Stormlog.

class stormlog.jax.JAXMemoryProfiler(device_index=0)[source]

Bases: object

JAX memory profiler with snapshot capture and function profiling.

Provides:

  • capture_snapshot() – point-in-time device memory reading

  • profile_function() – decorator-based before/after profiling

  • profile_context() – with-block profiling

  • start_continuous_profiling() / stop_continuous_profiling()

  • get_results() – aggregate into ProfileResult

Example:

profiler = JAXMemoryProfiler()
with profiler:
    s = profiler.capture_snapshot("after_init")
result = profiler.get_results()

Parameters:

device_index (int)

capture_snapshot(name='snapshot', *, operation_name=None)[source]

Capture a point-in-time memory snapshot.

Parameters:
  • name (str) – Human-readable label for this snapshot.

  • operation_name (str | None) – Optional operation being profiled.

Returns:

A MemorySnapshot.

Return type:

MemorySnapshot

profile_function(func=None, *, name=None)[source]

Decorator that profiles a function’s memory impact.

Usage:

@profiler.profile_function
def train_step():
    ...

# or with options:
@profiler.profile_function(name="custom_name")
def train_step():
    ...

Parameters:
  • func (F | None)

  • name (str | None)

Return type:

Any

profile_context(name='context')[source]

Context manager that captures before/after snapshots.

Usage:

with profiler.profile_context("matmul") as snap_before:
    result = jax.numpy.dot(a, b)

Parameters:

name (str)

Return type:

Iterator[MemorySnapshot]

start_continuous_profiling(interval=1.0)[source]

Start background snapshot capture at interval seconds.

Parameters:

interval (float)

Return type:

None

stop_continuous_profiling()[source]

Stop the background snapshot loop.

Return type:

None
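
The start/stop pair above describes a background loop that samples at a fixed interval. A minimal sketch of that pattern, using only the standard library, is shown here. PeriodicSampler and sample_fn are hypothetical names introduced for illustration; this is not Stormlog's implementation, only one common way to build such a loop.

```python
import threading
import time


class PeriodicSampler:
    """Hypothetical helper: call sample_fn every `interval` seconds in a thread."""

    def __init__(self, sample_fn, interval=1.0):
        self.sample_fn = sample_fn
        self.interval = interval
        self.samples = []  # list of (timestamp, value) tuples
        self._stop = threading.Event()
        self._thread = None

    def start(self):
        if self._thread is not None:
            return  # already running
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        # Event.wait returns True once stop() sets the event, which ends the loop.
        while not self._stop.wait(self.interval):
            self.samples.append((time.monotonic(), self.sample_fn()))

    def stop(self):
        self._stop.set()
        if self._thread is not None:
            self._thread.join()
            self._thread = None
```

Waiting on an Event instead of calling time.sleep lets stop() interrupt the loop promptly rather than waiting out a full interval.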

get_results()[source]

Aggregate captured snapshots into a ProfileResult.

Return type:

ProfileResult

reset()[source]

Clear all captured data.

Return type:

None

class stormlog.jax.MemoryTracker(sampling_interval=1.0, alert_threshold_mb=None, device_index=0, enable_logging=True, max_history=10000, job_id=None, rank=None, local_rank=None, world_size=None, telemetry_sink_config=None, save_device_profile_on_stop=False, enable_oom_flight_recorder=False, oom_dump_dir='oom_dumps', oom_buffer_size=None, oom_max_dumps=5, oom_max_total_mb=256)[source]

Bases: object

Real-time JAX device memory tracker.

Parameters:
  • sampling_interval (float)

  • alert_threshold_mb (Optional[float])

  • device_index (int)

  • enable_logging (bool)

  • max_history (int)

  • job_id (Optional[str])

  • rank (Optional[int])

  • local_rank (Optional[int])

  • world_size (Optional[int])

  • telemetry_sink_config (Optional[TelemetrySinkConfig])

  • save_device_profile_on_stop (bool)

  • enable_oom_flight_recorder (bool)

  • oom_dump_dir (str)

  • oom_buffer_size (Optional[int])

  • oom_max_dumps (int)

  • oom_max_total_mb (int)

get_session_summary()[source]
Return type:

SessionSummary | None

property oom_buffer_size: int

Resolved OOM ring-buffer size.

add_alert_callback(callback)[source]
Parameters:

callback (Callable[[Dict[str, Any]], None])

Return type:

None

remove_alert_callback(callback)[source]

Remove a previously registered alert callback.

Parameters:

callback (Callable[[Dict[str, Any]], None])

Return type:

None

set_alert_threshold(threshold_mb)[source]
Parameters:

threshold_mb (float)

Return type:

None

check_alerts()[source]
Return type:

bool

start_tracking()[source]
Return type:

None

stop_tracking()[source]
Return type:

TrackingResult

get_current_memory()[source]

Get current memory usage in MB.

Return type:

float

get_statistics()[source]
Return type:

dict[str, Any]

get_tracking_results()[source]

Get current tracking results without stopping.

Return type:

TrackingResult

enter_phase(name, *, metadata=None)[source]
Parameters:
  • name (str)

  • metadata (Dict[str, Any] | None)

Return type:

PhaseHandle

phase(name, *, metadata=None)[source]
Parameters:
  • name (str)

  • metadata (Dict[str, Any] | None)

Return type:

Iterator[PhaseHandle]

property last_oom_dump_path: str | None

Path to the most recent OOM dump bundle, or None.

handle_exception(exc, context=None, metadata=None)[source]

Capture OOM diagnostics for recognized OOM exceptions.

Parameters:
  • exc (BaseException)

  • context (str | None)

  • metadata (Dict[str, Any] | None)

Return type:

str | None

capture_oom(context='runtime', metadata=None)[source]

Capture an OOM diagnostic bundle if the wrapped block raises an OOM.

Parameters:
  • context (str)

  • metadata (Dict[str, Any] | None)

Return type:

Iterator[None]

trigger_oom_dump(exception, context=None, metadata=None)[source]

Manually trigger an OOM diagnostic dump bundle.

Parameters:
  • exception (BaseException)

  • context (str | None)

  • metadata (Dict[str, Any] | None)

Return type:

str | None

save_device_memory_profile(output_path)[source]

Save a JAX device memory profile to the given path.

Note

This method depends on jax.profiler.save_device_memory_profile, which is only available on GPU/TPU backends with JAX >= 0.4.1. On CPU-only installs or older JAX versions, the call is a no-op and returns False. Availability is checked at runtime via hasattr guards, so no import error is raised.

Parameters:

output_path (str)

Return type:

bool
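
The runtime guard described in the note can be sketched as follows. save_profile_if_supported is a hypothetical helper written for illustration; it takes the profiler module as an argument so the guard can be exercised without JAX installed. With JAX present, the module passed in would be jax.profiler.

```python
import types


def save_profile_if_supported(profiler_module, output_path):
    """Return True if a profile was written, False if the API is missing."""
    # getattr with a default plays the role of a hasattr guard: a missing
    # attribute yields None instead of raising AttributeError.
    save = getattr(profiler_module, "save_device_memory_profile", None)
    if not callable(save):
        return False  # unsupported backend or older JAX: treat as a no-op
    save(output_path)
    return True


# Stand-in for a JAX build that lacks the function (illustration only).
module_without_api = types.SimpleNamespace()
```

Typical use, assuming JAX is installed, would be save_profile_if_supported(jax.profiler, "memory.prof").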

save_device_memory_profile_to_dir(output_dir=None)[source]

Save a JAX device memory profile to a directory with an auto-generated filename.

Parameters:

output_dir (str | None)

Return type:

str | None

class stormlog.jax.MemoryWatchdog(max_memory_mb=8000, cleanup_threshold_mb=6000, check_interval=5.0, device_index=0)[source]

Bases: object

Automatic memory management and cleanup for JAX workloads.

Parameters:
  • max_memory_mb (float)

  • cleanup_threshold_mb (float)

  • check_interval (float)

  • device_index (int)

add_cleanup_callback(callback)[source]

Add a cleanup callback function.

Parameters:

callback (Callable[[], None])

Return type:

None

start()[source]

Start the memory watchdog.

Return type:

None

stop()[source]

Stop the memory watchdog.

Return type:

None

force_cleanup(aggressive=False)[source]

Force immediate memory cleanup.

Parameters:

aggressive (bool) – When True, also delete all live JAX arrays reachable via jax.live_arrays() (if available) before running garbage collection. Use with caution — this can invalidate arrays still referenced by user code.

Return type:

None

class stormlog.jax.MemoryAnalyzer(sensitivity=0.05, collective_sensitivity='medium', collective_threshold_overrides=None)[source]

Bases: object

Advanced analyzer for JAX memory profiling data.

Mirrors the public interface of the PyTorch and TensorFlow MemoryAnalyzer classes, adapting heuristics and recommendations for JAX’s XLA runtime and device memory model.

Parameters:
  • sensitivity (float)

  • collective_sensitivity (str)

  • collective_threshold_overrides (Optional[Mapping[str, Any]])

detect_memory_leaks(results)[source]

Detect potential memory leaks in JAX telemetry.

Uses linear regression over the memory-usage series to detect sustained upward drift.

Parameters:

results (Any) – Object with a memory_usage attribute (numeric sequence of at least 10 samples).

Returns:

List of leak-detection dicts with type, severity, description, and slope keys.

Return type:

List[Dict[str, Any]]
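
To illustrate the idea of regression-based drift detection, here is a minimal sketch in pure Python. The function names and the slope threshold are assumptions made for this example; only the 10-sample minimum comes from the description above. The analyzer's actual thresholds and severity rules are not documented here.

```python
from typing import Sequence


def linear_slope(samples: Sequence[float]) -> float:
    """Least-squares slope of samples against their index (MB per sample)."""
    n = len(samples)
    if n < 2:
        return 0.0
    mean_x = (n - 1) / 2.0
    mean_y = sum(samples) / n
    cov = sum((i - mean_x) * (y - mean_y) for i, y in enumerate(samples))
    var = sum((i - mean_x) ** 2 for i in range(n))
    return cov / var


def looks_like_leak(samples, min_samples=10, slope_threshold_mb=0.5):
    # min_samples mirrors the documented 10-sample minimum; the slope
    # threshold is a hypothetical value chosen for illustration.
    if len(samples) < min_samples:
        return False
    return linear_slope(samples) > slope_threshold_mb


# Memory that oscillates around 100 MB versus memory that grows 2 MB per sample.
steady = [100.0, 101.0, 99.0, 100.0, 101.0, 99.0, 100.0, 101.0, 99.0, 100.0]
growing = [100.0 + 2.0 * i for i in range(12)]
```

A positive slope alone does not prove a leak; warm-up allocations and XLA compilation caches can also produce early growth, so a sustained trend over a long window is more informative.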

detect_patterns(results)[source]

Detect allocation patterns in JAX telemetry.

Performs simplified autocorrelation analysis to identify periodic memory usage behaviour (e.g. per-step allocations in a training loop).

Parameters:

results (Any) – Object with a memory_usage attribute (numeric sequence of at least 10 samples).

Returns:

List of detected pattern dicts.

Return type:

List[Dict[str, Any]]
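
A simplified autocorrelation check for periodic behaviour might look like the sketch below. The helper names and the choice to scan lags from 2 up to half the series length are assumptions for illustration, not the analyzer's documented behaviour.

```python
def autocorrelation(samples, lag):
    """Normalised autocorrelation of a series at the given positive lag."""
    n = len(samples)
    if lag <= 0 or lag >= n:
        raise ValueError("lag must satisfy 0 < lag < len(samples)")
    mean = sum(samples) / n
    centered = [s - mean for s in samples]
    denom = sum(c * c for c in centered)
    if denom == 0:
        return 0.0  # constant series: no structure to correlate
    num = sum(centered[i] * centered[i + lag] for i in range(n - lag))
    return num / denom


def dominant_period(samples, max_lag=None):
    """Return (lag, correlation) for the strongest positive correlation."""
    n = len(samples)
    if max_lag is None:
        max_lag = n // 2
    best_lag, best_corr = None, 0.0
    for lag in range(2, max_lag + 1):
        corr = autocorrelation(samples, lag)
        if corr > best_corr:
            best_lag, best_corr = lag, corr
    return best_lag, best_corr


# A sawtooth that repeats every 4 samples, like a 4-phase training step.
sawtooth = [0.0, 1.0, 2.0, 3.0] * 6
```

For the sawtooth series, the strongest correlation falls at the lag equal to the repeat length, which is how a per-step allocation cycle would show up.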

analyze_fragmentation(profile_result)[source]

Analyse memory fragmentation patterns.

Computes fragmentation as 1 - (used / reserved) across profiling snapshots.

Parameters:

profile_result (Any) – Profiling result with a snapshots attribute where each snapshot exposes device_memory_mb and device_memory_reserved_mb.

Returns:

Dictionary with fragmentation_score, trend, max_fragmentation, and min_fragmentation.

Return type:

Dict[str, float]
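
As a worked illustration of the ratio 1 - (used / reserved), the sketch below computes per-snapshot fragmentation and a few summary values. Snapshot is a hypothetical stand-in exposing the two documented attributes; averaging the per-snapshot values into fragmentation_score is an assumption, and the trend field is omitted because its definition is not given here.

```python
from dataclasses import dataclass


@dataclass
class Snapshot:
    """Hypothetical stand-in for a profiling snapshot."""
    device_memory_mb: float
    device_memory_reserved_mb: float


def fragmentation(snapshot):
    reserved = snapshot.device_memory_reserved_mb
    if reserved <= 0:
        return 0.0  # nothing reserved, so the ratio is undefined; report 0
    return 1.0 - snapshot.device_memory_mb / reserved


def summarize(snapshots):
    scores = [fragmentation(s) for s in snapshots]
    if not scores:
        return {"fragmentation_score": 0.0,
                "max_fragmentation": 0.0,
                "min_fragmentation": 0.0}
    return {
        # Assumption: the overall score is the mean of per-snapshot values.
        "fragmentation_score": sum(scores) / len(scores),
        "max_fragmentation": max(scores),
        "min_fragmentation": min(scores),
    }


snapshots = [Snapshot(750.0, 1000.0), Snapshot(500.0, 1000.0)]
```

With 750 MB used of 1000 MB reserved the value is 0.25; with 500 MB used it is 0.5.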

analyze_efficiency(profile_result)[source]

Analyse memory usage efficiency.

Returns a score on a 0.0–1.0 scale (1.0 = excellent). The score starts at 1.0 and is reduced by penalties for high peak memory, high growth rate, fragmentation, and detected leaks.

Parameters:

profile_result (Any) – Profiling result with peak_memory_mb and optionally memory_growth_rate, snapshots, and memory_usage.

Returns:

Efficiency score in [0.0, 1.0].

Return type:

float
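
The penalty-based scoring described above can be pictured with the following sketch. Every threshold and weight in it is hypothetical and chosen only to show the shape of the calculation (start at 1.0, subtract penalties, clamp to [0.0, 1.0]); the analyzer's real values are not documented here.

```python
def efficiency_score(peak_memory_mb, growth_rate=0.0, fragmentation=0.0,
                     leak_count=0):
    """Illustrative score in [0.0, 1.0]; all constants are hypothetical."""
    score = 1.0
    if peak_memory_mb > 8000:       # hypothetical "high peak" threshold
        score -= 0.2
    if growth_rate > 1.0:           # hypothetical "high growth" threshold
        score -= 0.2
    # Fragmentation is expected in [0, 1]; clamp it before weighting.
    score -= 0.3 * min(max(fragmentation, 0.0), 1.0)
    score -= 0.1 * leak_count       # each detected leak costs a fixed amount
    return min(1.0, max(0.0, score))
```

Clamping at the end keeps the result on the documented scale even when several penalties stack.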

correlate_with_performance(profile_result)[source]

Correlate memory usage with performance metrics.

Analyses per-function efficiency based on memory consumption and execution duration.

Parameters:

profile_result (Any) – Profiling result with function_profiles mapping function names to dicts containing calls, total_memory_delta, and total_duration.

Returns:

Dictionary with memory_duration_correlation, function_efficiency, and recommendations.

Return type:

Dict[str, Any]

score_optimization(profile_result, events=None)[source]

Generate an overall optimisation score with recommendations.

Combines memory efficiency, fragmentation, and per-function performance scores into a single summary.

Parameters:
  • profile_result (Any) – JAX profiling result object.

  • events (List | None) – Optional telemetry event series for gap analysis. When provided, the result includes gap_analysis and collective_attribution sections.

Returns:

Dictionary with overall_score, categories, top_recommendations, and priority_actions.

Return type:

Dict[str, Any]

analyze_memory_gaps(events, *, phase_resolver=None)[source]

Classify allocator-vs-device hidden memory gaps over time.

Parameters:
  • events (List) – Chronologically ordered telemetry samples.

  • phase_resolver (Any | None) – Optional PhaseReplayIndex for phase attribution.

Returns:

Prioritised list of gap findings (severity desc, confidence desc). Returns an empty list when the gap_analysis sub-package is not available.

Return type:

List

analyze_collective_attribution(events, *, phase_resolver=None)[source]

Attribute hidden-memory spikes to collective communication phases.

Parameters:
  • events (List) – Chronologically ordered telemetry samples.

  • phase_resolver (Any | None) – Optional PhaseReplayIndex for phase attribution.

Returns:

List of CollectiveAttributionResult objects. Returns an empty list when the collective_attribution sub-package is not available.

Return type:

List

class stormlog.jax.JAXProfiler(device_index=0)[source]

Bases: object

High-level JAX profiling interface.

Provides convenience methods for profiling training loops and inference passes, analogous to stormlog.context_profiler.MemoryProfiler (PyTorch) and stormlog.tensorflow.context_profiler.TensorFlowProfiler.

Parameters:

device_index (int) – Index of the JAX device to monitor (default 0).

Example:

jp = JAXProfiler()
jp.profile_training(train_step, dataset, epochs=3)
result = jp.get_results()

profile_training(train_step_fn, dataset, epochs=1, steps_per_epoch=None)[source]

Profile a JAX training loop.

The caller supplies a train_step_fn that is invoked once per batch. train_step_fn should accept a single batch as its first positional argument (additional arguments can be closed over or passed through the function itself).

Parameters:
  • train_step_fn (Callable[[...], Any]) – A callable (batch) -> Any that executes a single training step.

  • dataset (Any) – An iterable of batches (must be re-iterable for multi-epoch training; generators are exhausted after epoch 0). Each epoch iterates over the full dataset (or up to steps_per_epoch batches).

  • epochs (int) – Number of epochs to profile.

  • steps_per_epoch (int | None) – Optional cap on the number of steps per epoch.

Return type:

None
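
The note about re-iterable datasets can be seen in a small sketch of an epoch loop. run_epochs is a hypothetical stand-in that only counts steps; it does no profiling and is not the method's implementation. It shows why a list works across epochs while a generator is exhausted after the first, and how a steps_per_epoch cap could be applied.

```python
from itertools import islice


def run_epochs(train_step_fn, dataset, epochs=1, steps_per_epoch=None):
    """Call train_step_fn once per batch; return steps executed per epoch."""
    steps_run = []
    for _ in range(epochs):
        batches = iter(dataset)  # a generator returns itself here
        if steps_per_epoch is not None:
            batches = islice(batches, steps_per_epoch)
        count = 0
        for batch in batches:
            train_step_fn(batch)
            count += 1
        steps_run.append(count)
    return steps_run


batches_list = [[1, 2], [3, 4], [5, 6]]


def batch_generator():
    # A one-shot source of batches: it cannot be restarted.
    yield from batches_list
```

Passing batches_list with epochs=2 runs three steps in each epoch, whereas passing batch_generator() runs three steps in the first epoch and none in the second.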

profile_inference(inference_fn, data, batch_size=32)[source]

Profile a JAX inference pass.

If data is an iterable of batches it is consumed directly; otherwise it is treated as a single array-like and sliced into batches of batch_size.

Parameters:
  • inference_fn (Callable[[...], Any]) – A callable (batch) -> Any that runs inference on a single batch.

  • data (Any) – Input data – either an iterable of batches or a single array-like with a leading batch dimension.

  • batch_size (int) – Batch size used when data must be sliced.

Return type:

None
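
For the case where data is a single array-like, slicing along the leading dimension can be sketched as below. iter_batches is a hypothetical helper; how the method itself decides between "iterable of batches" and "single array-like" is not specified here, so only the slicing step is shown.

```python
def iter_batches(data, batch_size=32):
    """Yield consecutive slices of at most batch_size items from data."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Works for anything that supports len() and slicing, such as lists
    # or arrays with a leading batch dimension.
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]
```

The final batch may be smaller than batch_size when the data length is not an exact multiple of it.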

get_results()[source]

Return the aggregated ProfileResult.

Return type:

ProfileResult

reset()[source]

Clear all captured profiling data.

Return type:

None

class stormlog.jax.ProfiledFunction(func, profiler=None, name=None)[source]

Bases: object

Wrapper that automatically profiles every call to a function.

JAX does not use an nn.Module class hierarchy, so instead of wrapping a layer/module this class wraps any callable (pure functions, closures, jax.jit-compiled functions, etc.).

Parameters:
  • func (Callable[..., Any]) – The callable to profile.

  • profiler (Optional['stormlog.jax.profiler.JAXMemoryProfiler']) – Explicit JAXMemoryProfiler. Falls back to the global profiler when None.

  • name (Optional[str]) – Label used in profiling output. Defaults to the callable’s __name__ or __class__.__name__.

Example:

profiled_forward = ProfiledFunction(forward_fn, name="forward")
output = profiled_forward(params, batch)

stormlog.jax.profile_function(func=None, *, name=None, profiler=None)[source]

Decorator to profile a function’s JAX device-memory usage.

Can be used bare (@profile_function) or with keyword arguments (@profile_function(name="custom")).

Parameters:
  • func (F | None) – Function to profile (when used as @profile_function).

  • name (str | None) – Custom name for the profiled function. Defaults to func.__name__.

  • profiler (JAXMemoryProfiler | None) – Explicit JAXMemoryProfiler to use. Falls back to the global profiler when None.

Returns:

Decorated callable (or decorator factory when called with keyword arguments).

Return type:

Callable[[F], F] | F
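
The dual calling convention (bare or with keyword arguments) follows a common Python decorator pattern, sketched below. record_calls and CALLS are hypothetical names that simply log a label per call; they stand in for the profiling work and are not part of Stormlog.

```python
import functools

CALLS = []  # labels of wrapped functions, in call order


def record_calls(func=None, *, name=None):
    """Decorator usable as @record_calls or @record_calls(name="...")."""

    def decorate(f):
        label = name or f.__name__

        @functools.wraps(f)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            CALLS.append(label)
            return f(*args, **kwargs)

        return wrapper

    if func is not None:
        # Bare form: Python passed the function directly.
        return decorate(func)
    # Keyword form: return the real decorator to be applied next.
    return decorate


@record_calls
def forward(x):
    return x * 2


@record_calls(name="custom")
def backward(x):
    return x - 1
```

Making name keyword-only is what lets the decorator tell the two forms apart: a positional argument can only be the function being decorated.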

stormlog.jax.profile_context(name='context', profiler=None)[source]

Context manager for profiling a block of code.

Parameters:
  • name (str) – Label for the profiled block.

  • profiler (JAXMemoryProfiler | None) – Explicit profiler. Falls back to the global profiler.

Yields:

The JAXMemoryProfiler being used.

Return type:

Iterator[JAXMemoryProfiler]

Example:

with profile_context("matmul") as prof:
    result = jax.numpy.dot(a, b)

stormlog.jax.get_device_info(device_index=0)[source]

Return device kind, platform, and live memory statistics.

Parameters:

device_index (int) – Index into jax.local_devices() (default 0).

Returns:

Dictionary with keys kind, platform, device_id, process_index, memory_stats (raw dict from device.memory_stats()), and client.

Return type:

Dict[str, Any]

stormlog.jax.get_system_info()[source]

Return full system and JAX environment report.

Includes JAX version, device list, platform, Python version, CPU count, and system memory statistics.

Return type:

Dict[str, Any]

Modules

analyzer

JAX Memory Analysis.

attributed_viz

Stormlog-native memory visualisation for JAX (Directed Graph Dashboard).

cli

JAX Stormlog CLI.

context_profiler

JAX Context Profiling.

diagnose

Diagnostic bundle builder for the JAX Stormlog diagnose command.

jax_env

JAX environment configuration for Stormlog.

pprof_parser

profile_pb2

Generated protocol buffer code.

profiler

JAX Memory Profiler.

tracker

Real-time JAX Memory Tracking.

utils

Utility functions for JAX memory profiling.

visualizer

JAX Memory Visualization.