stormlog.jax.tracker
Real-time JAX Memory Tracking.
This module provides real-time monitoring of JAX device memory usage, integrating with Stormlog’s shared telemetry, session, and phase tracking infrastructure.
Classes
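| Class | Description |
| --- | --- |
| JAXMemoryTracker | alias of MemoryTracker |
| MemoryTracker | Real-time JAX device memory tracker. |
| MemoryWatchdog | Automatic memory management and cleanup for JAX workloads. |
| TrackingResult | Results from real-time JAX memory tracking. |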
- class stormlog.jax.tracker.TrackingResult(start_time, end_time, samples_collected, peak_memory_bytes, min_memory_bytes, average_memory_bytes, alert_count, session_summary=None, telemetry_events=<factory>, memory_usage=<factory>, timestamps=<factory>, device_memory_profile_path=None, history_window_limit=0, history_retained_samples=0, history_dropped_samples=0, history_retained_events=0, history_dropped_events=0, history_retained_alerts=0, history_dropped_alerts=0)
Bases: object
Results from real-time JAX memory tracking.
- Parameters:
start_time (float)
end_time (float)
samples_collected (int)
peak_memory_bytes (int)
min_memory_bytes (int)
average_memory_bytes (int)
alert_count (int)
session_summary (SessionSummary | None)
telemetry_events (List[Dict[str, Any]])
memory_usage (List[int])
timestamps (List[float])
device_memory_profile_path (str | None)
history_window_limit (int)
history_retained_samples (int)
history_dropped_samples (int)
history_retained_events (int)
history_dropped_events (int)
history_retained_alerts (int)
history_dropped_alerts (int)
- start_time: float
- end_time: float
- samples_collected: int
- peak_memory_bytes: int
- min_memory_bytes: int
- average_memory_bytes: int
- alert_count: int
- session_summary: SessionSummary | None = None
- telemetry_events: List[Dict[str, Any]]
- memory_usage: List[int]
- timestamps: List[float]
- device_memory_profile_path: str | None = None
- history_window_limit: int = 0
- history_retained_samples: int = 0
- history_dropped_samples: int = 0
- history_retained_events: int = 0
- history_dropped_events: int = 0
- history_retained_alerts: int = 0
- history_dropped_alerts: int = 0
- property peak_memory_mb: float
Peak memory usage in MB.
- property average_memory_mb: float
Average memory usage in MB.
- property duration: float
Total tracking duration in seconds.
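The summary fields above can be related to the raw samples with a short sketch. It assumes, without confirmation from this page, that peak, minimum, and average are computed over the retained memory_usage samples, that the average is truncated to an integer, and that "MB" means 2**20 bytes. ResultSketch is a hypothetical stand-in, not the real TrackingResult.

```python
from dataclasses import dataclass, field
from typing import List

# Assumption: the page does not say whether "MB" means 10**6 or 2**20 bytes.
BYTES_PER_MB = 1024 * 1024


@dataclass
class ResultSketch:
    """Hypothetical stand-in for TrackingResult's summary fields."""
    start_time: float
    end_time: float
    memory_usage: List[int] = field(default_factory=list)

    @property
    def peak_memory_bytes(self) -> int:
        return max(self.memory_usage, default=0)

    @property
    def min_memory_bytes(self) -> int:
        return min(self.memory_usage, default=0)

    @property
    def average_memory_bytes(self) -> int:
        # Integer average, matching the int type of the real field.
        if not self.memory_usage:
            return 0
        return sum(self.memory_usage) // len(self.memory_usage)

    @property
    def peak_memory_mb(self) -> float:
        return self.peak_memory_bytes / BYTES_PER_MB

    @property
    def average_memory_mb(self) -> float:
        return self.average_memory_bytes / BYTES_PER_MB

    @property
    def duration(self) -> float:
        return self.end_time - self.start_time
```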
- class stormlog.jax.tracker.MemoryTracker(sampling_interval=1.0, alert_threshold_mb=None, device_index=0, enable_logging=True, max_history=10000, job_id=None, rank=None, local_rank=None, world_size=None, telemetry_sink_config=None, save_device_profile_on_stop=False, enable_oom_flight_recorder=False, oom_dump_dir='oom_dumps', oom_buffer_size=None, oom_max_dumps=5, oom_max_total_mb=256)
Bases: object
Real-time JAX device memory tracker.
- Parameters:
sampling_interval (float)
alert_threshold_mb (Optional[float])
device_index (int)
enable_logging (bool)
max_history (int)
job_id (Optional[str])
rank (Optional[int])
local_rank (Optional[int])
world_size (Optional[int])
telemetry_sink_config (Optional[TelemetrySinkConfig])
save_device_profile_on_stop (bool)
enable_oom_flight_recorder (bool)
oom_dump_dir (str)
oom_buffer_size (Optional[int])
oom_max_dumps (int)
oom_max_total_mb (int)
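The sampling_interval and max_history parameters suggest a polling loop that keeps a bounded window of samples, which would also account for the history_retained_samples and history_dropped_samples fields on TrackingResult. The sketch below is a hypothetical illustration of that pattern using a background thread and collections.deque(maxlen=...); SamplerSketch and its read_bytes callable are not part of stormlog.

```python
import threading
import time
from collections import deque
from typing import Callable, Deque, Tuple


class SamplerSketch:
    """Hypothetical polling sampler; not stormlog's implementation."""

    def __init__(self, read_bytes: Callable[[], int], sampling_interval: float = 1.0,
                 max_history: int = 10000) -> None:
        self._read_bytes = read_bytes
        self._interval = sampling_interval
        # deque(maxlen=...) silently discards the oldest entry once full.
        self.history: Deque[Tuple[float, int]] = deque(maxlen=max_history)
        self.samples_collected = 0
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    @property
    def dropped_samples(self) -> int:
        # Samples taken minus samples still retained in the window.
        return self.samples_collected - len(self.history)

    def _run(self) -> None:
        while not self._stop.is_set():
            self.history.append((time.time(), self._read_bytes()))
            self.samples_collected += 1
            # Event.wait doubles as a sleep that stop() can interrupt.
            self._stop.wait(self._interval)

    def start(self) -> None:
        self._thread.start()

    def stop(self) -> None:
        self._stop.set()
        self._thread.join()
```

On GPU and TPU backends, read_bytes could plausibly wrap jax.devices()[device_index].memory_stats()["bytes_in_use"]; memory_stats() can return None on CPU, so a real sampler would need a fallback. Whether stormlog reads memory this way is not stated here.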
- get_session_summary()
- Return type:
SessionSummary | None
- property oom_buffer_size: int
Resolved OOM ring-buffer size.
- add_alert_callback(callback)
Register a callback that receives each alert as a Dict[str, Any].
- Parameters:
callback (Callable[[Dict[str, Any]], None])
- Return type:
None
- remove_alert_callback(callback)
Remove a previously registered alert callback.
- Parameters:
callback (Callable[[Dict[str, Any]], None])
- Return type:
None
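How alerts reach callbacks is not spelled out on this page. A minimal sketch, assuming an alert fires whenever a sample exceeds alert_threshold_mb and every registered callback receives the same dictionary, looks like this; AlertDispatcherSketch and the payload keys are hypothetical.

```python
from typing import Any, Callable, Dict, List, Optional

AlertCallback = Callable[[Dict[str, Any]], None]


class AlertDispatcherSketch:
    """Hypothetical threshold check plus callback fan-out."""

    def __init__(self, alert_threshold_mb: Optional[float]) -> None:
        self.alert_threshold_mb = alert_threshold_mb
        self._callbacks: List[AlertCallback] = []
        self.alert_count = 0

    def add_alert_callback(self, callback: AlertCallback) -> None:
        self._callbacks.append(callback)

    def remove_alert_callback(self, callback: AlertCallback) -> None:
        # Removing a callback that was never added is a no-op in this sketch.
        if callback in self._callbacks:
            self._callbacks.remove(callback)

    def check(self, memory_bytes: int) -> None:
        if self.alert_threshold_mb is None:
            return  # No threshold configured: never alert.
        memory_mb = memory_bytes / (1024 * 1024)
        if memory_mb <= self.alert_threshold_mb:
            return
        self.alert_count += 1
        # Hypothetical payload keys; the page only says alerts are Dict[str, Any].
        alert = {"type": "memory_threshold", "memory_mb": memory_mb,
                 "threshold_mb": self.alert_threshold_mb}
        # Iterate over a copy so a callback may unregister itself safely.
        for callback in list(self._callbacks):
            callback(alert)
```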
- enter_phase(name, *, metadata=None)
- Parameters:
name (str)
metadata (Dict[str, Any] | None)
- Return type:
- phase(name, *, metadata=None)
- Parameters:
name (str)
metadata (Dict[str, Any] | None)
- Return type:
Iterator[PhaseHandle]
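Because phase() is annotated as returning Iterator[PhaseHandle], it is presumably a generator wrapped with contextlib.contextmanager and used in a with statement. The sketch below shows that pattern with hypothetical PhaseLogSketch and PhaseHandleSketch classes; the real PhaseHandle fields are not documented here.

```python
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional


@dataclass
class PhaseHandleSketch:
    """Hypothetical stand-in for stormlog's PhaseHandle."""
    name: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    start: float = 0.0
    end: Optional[float] = None


class PhaseLogSketch:
    def __init__(self) -> None:
        self.phases: List[PhaseHandleSketch] = []

    @contextmanager
    def phase(self, name: str, *, metadata: Optional[Dict[str, Any]] = None
              ) -> Iterator[PhaseHandleSketch]:
        handle = PhaseHandleSketch(name, dict(metadata or {}), start=time.monotonic())
        self.phases.append(handle)
        try:
            yield handle
        finally:
            # Close the phase even if the body raises.
            handle.end = time.monotonic()
```

With the real tracker, the equivalent call would presumably read: with tracker.phase("train_step", metadata={"step": 1}): ...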
- property last_oom_dump_path: str | None
Path to the most recent OOM dump bundle, or None.
- handle_exception(exc, context=None, metadata=None)
Capture OOM diagnostics for recognized OOM exceptions.
- Parameters:
exc (BaseException)
context (str | None)
metadata (Dict[str, Any] | None)
- Return type:
str | None
- capture_oom(context='runtime', metadata=None)
Capture an OOM diagnostic bundle if the wrapped block raises an OOM.
- Parameters:
context (str)
metadata (Dict[str, Any] | None)
- Return type:
Iterator[None]
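handle_exception() and capture_oom() together suggest a classify-then-dump flow. The sketch below assumes an exception counts as an OOM if it is a MemoryError or its message mentions RESOURCE_EXHAUSTED or "out of memory" (XLA out-of-memory errors usually carry the RESOURCE_EXHAUSTED status), that a bundle is a single JSON file, and that capture_oom re-raises the original exception. None of these details are confirmed by this page, and all function names are hypothetical.

```python
import json
import os
import uuid
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional

# Assumed OOM markers, stored lowercase for case-insensitive matching.
OOM_MARKERS = ("resource_exhausted", "out of memory")


def is_oom(exc: BaseException) -> bool:
    text = str(exc).lower()
    return isinstance(exc, MemoryError) or any(m in text for m in OOM_MARKERS)


def handle_exception_sketch(exc: BaseException, dump_dir: str,
                            context: Optional[str] = None,
                            metadata: Optional[Dict[str, Any]] = None) -> Optional[str]:
    """Write a JSON bundle for OOM-like exceptions; return its path, else None."""
    if not is_oom(exc):
        return None
    os.makedirs(dump_dir, exist_ok=True)
    path = os.path.join(dump_dir, f"oom_{uuid.uuid4().hex}.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"context": context, "error": repr(exc),
                   "metadata": metadata or {}}, f)
    return path


@contextmanager
def capture_oom_sketch(dump_dir: str, context: str = "runtime",
                       metadata: Optional[Dict[str, Any]] = None) -> Iterator[None]:
    try:
        yield
    except BaseException as exc:
        handle_exception_sketch(exc, dump_dir, context, metadata)
        raise  # Assumption: the original exception still propagates.
```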
- trigger_oom_dump(exception, context=None, metadata=None)
Manually trigger an OOM diagnostic dump bundle.
- Parameters:
exception (BaseException)
context (str | None)
metadata (Dict[str, Any] | None)
- Return type:
str | None
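The oom_max_dumps and oom_max_total_mb constructor parameters imply a retention limit on dump bundles. A minimal sketch of one possible policy follows, assuming each dump is a single file and the oldest files (by modification time) are evicted first; prune_dumps is a hypothetical helper, not a stormlog function.

```python
import os
from typing import List


def prune_dumps(dump_dir: str, max_dumps: int = 5, max_total_mb: int = 256) -> List[str]:
    """Delete the oldest dump files until both limits hold; return deleted paths."""
    paths = [os.path.join(dump_dir, name) for name in os.listdir(dump_dir)]
    files = [p for p in paths if os.path.isfile(p)]
    files.sort(key=os.path.getmtime)  # Oldest first.
    total = sum(os.path.getsize(p) for p in files)
    limit_bytes = max_total_mb * 1024 * 1024
    deleted: List[str] = []
    # Evict until both the count limit and the size budget are satisfied.
    while files and (len(files) > max_dumps or total > limit_bytes):
        oldest = files.pop(0)
        total -= os.path.getsize(oldest)
        os.remove(oldest)
        deleted.append(oldest)
    return deleted
```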
- save_device_memory_profile(output_path)
Save a JAX device memory profile to the given path.
Note
This method depends on jax.profiler.save_device_memory_profile, which is only available on GPU/TPU backends with JAX >= 0.4.1. On CPU-only installs or older JAX versions, the call is a no-op and returns False. Availability is checked at runtime via hasattr guards, so no import error is raised.
- Parameters:
output_path (str)
- Return type:
bool
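The runtime guard described in the note can be sketched as follows. jax.profiler.save_device_memory_profile and jax.default_backend are real JAX functions; treating a "cpu" default backend as unsupported mirrors the note and is an assumption about how stormlog detects CPU-only installs. The function name is hypothetical.

```python
def save_device_memory_profile_sketch(output_path: str) -> bool:
    """Return True if a profile was written, False if none could be produced."""
    try:
        import jax
        import jax.profiler
    except ImportError:
        return False  # JAX is not installed at all.
    save = getattr(jax.profiler, "save_device_memory_profile", None)
    if save is None:
        return False  # Older JAX without the API.
    if jax.default_backend() == "cpu":
        return False  # Mirrors the documented CPU-only no-op.
    save(output_path)  # Writes a pprof-format profile to output_path.
    return True
```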
- class stormlog.jax.tracker.MemoryWatchdog(max_memory_mb=8000, cleanup_threshold_mb=6000, check_interval=5.0, device_index=0)
Bases: object
Automatic memory management and cleanup for JAX workloads.
- Parameters:
max_memory_mb (float)
cleanup_threshold_mb (float)
check_interval (float)
device_index (int)
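The default thresholds (cleanup_threshold_mb=6000 below max_memory_mb=8000) suggest a soft limit that triggers cleanup and a higher hard ceiling. One hypothetical policy, which a watchdog would evaluate every check_interval seconds, is sketched below; the escalation to aggressive cleanup at max_memory_mb is an assumption, not documented behavior.

```python
from enum import Enum


class WatchdogAction(Enum):
    NONE = "none"
    CLEANUP = "cleanup"
    AGGRESSIVE_CLEANUP = "aggressive_cleanup"


def decide_action(used_mb: float, cleanup_threshold_mb: float = 6000,
                  max_memory_mb: float = 8000) -> WatchdogAction:
    """Hypothetical policy: soft limit triggers cleanup, hard limit escalates."""
    if used_mb >= max_memory_mb:
        return WatchdogAction.AGGRESSIVE_CLEANUP
    if used_mb >= cleanup_threshold_mb:
        return WatchdogAction.CLEANUP
    return WatchdogAction.NONE
```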
- add_cleanup_callback(callback)
Register a cleanup callback.
- Parameters:
callback (Callable[[], None])
- Return type:
None
- force_cleanup(aggressive=False)
Force immediate memory cleanup.
- Parameters:
aggressive (bool) – When True, also delete all live JAX arrays reachable via jax.live_arrays() (if available) before running garbage collection. Use with caution: this can invalidate arrays still referenced by user code.
- Return type:
None
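A sketch of the flow force_cleanup() describes: optionally delete every array returned by jax.live_arrays(), then call gc.collect(). It additionally assumes that the callbacks registered with add_cleanup_callback() run first. jax.live_arrays() and jax.Array.delete() are real JAX APIs in recent releases; the ordering and the function itself are illustrative assumptions.

```python
import gc
from typing import Callable, List


def force_cleanup_sketch(callbacks: List[Callable[[], None]], aggressive: bool = False) -> int:
    """Run cleanup callbacks, optionally free live JAX arrays, then collect garbage.

    Returns the number of unreachable objects reported by gc.collect().
    """
    for callback in callbacks:
        callback()
    if aggressive:
        try:
            import jax
            live_arrays = getattr(jax, "live_arrays", None)
        except ImportError:
            live_arrays = None  # JAX not installed: nothing to delete.
        if live_arrays is not None:
            for array in live_arrays():
                # Frees the device buffer even if user code still holds the array.
                array.delete()
    return gc.collect()
```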
- stormlog.jax.tracker.JAXMemoryTracker
alias of MemoryTracker