Source code for stormlog.jax.tracker

"""
Real-time JAX Memory Tracking.

This module provides real-time monitoring of JAX device memory usage,
integrating with Stormlog's shared telemetry, session, and phase
tracking infrastructure.
"""

from __future__ import annotations

import json
import logging
import os
import socket
import threading
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Optional

from stormlog.collector_health import (
    COLLECTOR_HEALTH_HEALTHY,
    COLLECTOR_HEALTH_UNHEALTHY,
    CollectorHealthState,
    collector_retry_delay_seconds,
)
from stormlog.oom_flight_recorder import (
    OOMFlightRecorder,
    OOMFlightRecorderConfig,
    classify_oom_exception,
)
from stormlog.phases import PhaseHandle, PhaseRecorder, PhaseToken
from stormlog.session import (
    SESSION_STATUS_COMPLETED,
    SESSION_STATUS_INCOMPLETE,
    SESSION_STATUS_RUNNING,
    SessionSummary,
    create_session_summary,
    finalize_session_summary,
    now_ns,
    update_session_summary,
)
from stormlog.telemetry import (
    resolve_distributed_identity,
    telemetry_event_from_record,
    telemetry_event_to_dict,
)
from stormlog.telemetry_sink import AppendOnlyTelemetrySink, TelemetrySinkConfig

from .jax_env import configure_jax_logging

configure_jax_logging()

import jax  # noqa: E402

JAX_AVAILABLE = True

try:
    import psutil

    PSUTIL_AVAILABLE = True
except ImportError:
    PSUTIL_AVAILABLE = False
    psutil = None

logger = logging.getLogger(__name__)


[docs] @dataclass class TrackingResult: """Results from real-time JAX memory tracking.""" start_time: float end_time: float samples_collected: int peak_memory_bytes: int min_memory_bytes: int average_memory_bytes: int alert_count: int session_summary: Optional[SessionSummary] = None telemetry_events: List[Dict[str, Any]] = field(default_factory=list) memory_usage: List[int] = field(default_factory=list) timestamps: List[float] = field(default_factory=list) device_memory_profile_path: Optional[str] = None history_window_limit: int = 0 history_retained_samples: int = 0 history_dropped_samples: int = 0 history_retained_events: int = 0 history_dropped_events: int = 0 history_retained_alerts: int = 0 history_dropped_alerts: int = 0 @property def peak_memory_mb(self) -> float: """Peak memory usage in MB.""" return self.peak_memory_bytes / (1024 * 1024) @property def average_memory_mb(self) -> float: """Average memory usage in MB.""" return self.average_memory_bytes / (1024 * 1024) @property def duration(self) -> float: """Total tracking duration in seconds.""" return self.end_time - self.start_time
@dataclass(frozen=True) class _TrackingResultData: retained_memory_usage: List[int] retained_timestamps: List[float] retained_events: List[Dict[str, Any]] retained_alerts: List[Dict[str, Any]] total_samples_observed: int peak_memory: int min_memory: int sum_memory: int dropped_samples: int dropped_events: int dropped_alerts: int
[docs] class MemoryTracker: """Real-time JAX device memory tracker.""" def __init__( self, sampling_interval: float = 1.0, alert_threshold_mb: Optional[float] = None, device_index: int = 0, enable_logging: bool = True, max_history: int = 10_000, job_id: Optional[str] = None, rank: Optional[int] = None, local_rank: Optional[int] = None, world_size: Optional[int] = None, telemetry_sink_config: Optional[TelemetrySinkConfig] = None, save_device_profile_on_stop: bool = False, enable_oom_flight_recorder: bool = False, oom_dump_dir: str = "oom_dumps", oom_buffer_size: Optional[int] = None, oom_max_dumps: int = 5, oom_max_total_mb: int = 256, ): if not JAX_AVAILABLE: raise ImportError( "JAX not available. Install with `pip install 'stormlog[jax]'`." ) if sampling_interval <= 0: raise ValueError("sampling_interval must be > 0") if max_history <= 0: raise ValueError("max_history must be >= 1") self.sampling_interval = sampling_interval self.alert_threshold_mb = alert_threshold_mb self.device_index = device_index self.enable_logging = enable_logging self.max_history = max_history self.save_device_profile_on_stop = save_device_profile_on_stop recorder_buffer_size = ( oom_buffer_size if oom_buffer_size is not None else max_history ) if recorder_buffer_size <= 0: recorder_buffer_size = max_history self._oom_flight_recorder = OOMFlightRecorder( OOMFlightRecorderConfig( enabled=enable_oom_flight_recorder, dump_dir=oom_dump_dir, buffer_size=recorder_buffer_size, max_dumps=oom_max_dumps, max_total_mb=oom_max_total_mb, ) ) self._last_oom_dump_path: Optional[str] = None self._device = None self._device_bytes_limit: Optional[int] = None self._last_reserved_bytes: Optional[int] = None try: devices = jax.local_devices() if device_index < len(devices): self._device = devices[device_index] try: stats = self._device.memory_stats() if stats and "bytes_limit" in stats: self._device_bytes_limit = int(stats["bytes_limit"]) except Exception: pass except Exception as exc: logger.debug("Could not resolve JAX device %d: %s", device_index, exc) # Cache a scalar sentinel for sync barriers — avoids re-allocating # a device array on every sample. self._sync_sentinel: Any = None if self._device is not None: try: self._sync_sentinel = jax.numpy.zeros((), device=self._device) except Exception: pass # Cache invariant per-process values used in every telemetry record. self._cached_pid = os.getpid() self._cached_hostname = socket.gethostname() self._telemetry_sink_config = telemetry_sink_config self._telemetry_sink = ( AppendOnlyTelemetrySink(telemetry_sink_config) if telemetry_sink_config is not None else None ) self.distributed_identity = resolve_distributed_identity( job_id=job_id, rank=rank, local_rank=local_rank, world_size=world_size, env=os.environ, ) self.session_source = "stormlog.jax.tracker" self._session_summary: Optional[SessionSummary] = None # Tracking state self.tracking = False self.tracking_thread: Optional[threading.Thread] = None self._memory_usage: deque[int] = deque(maxlen=max_history) self._timestamps: deque[float] = deque(maxlen=max_history) self._events: deque[Dict[str, Any]] = deque(maxlen=max_history) self._alerts: deque[Dict[str, Any]] = deque(maxlen=max_history) self._history_dropped_samples = 0 self._history_dropped_events = 0 self._history_dropped_alerts = 0 self._collector_failure_event_count = 0 self._total_samples_observed = 0 self._peak_memory_bytes = 0 self._min_memory_bytes: Optional[int] = None self._sum_memory_bytes = 0 self._last_sink_diagnostics: Dict[str, Any] = self._empty_sink_diagnostics() self._phase_state = PhaseRecorder() # Thread sync self._lock = threading.Lock() self._stop_event = threading.Event() self._collector_health = CollectorHealthState() self._last_successful_memory_bytes: Optional[int] = None self._session_start_time: Optional[float] = None self._session_end_time: Optional[float] = None self._collector_retry_backoff_initial_s = 1.0 self._collector_retry_backoff_factor = 2.0 self._collector_retry_backoff_cap_s = 30.0 self.alert_callbacks: List[Callable[[Dict[str, Any]], None]] = [] if enable_logging: logger.info( "JAX Memory Tracker initialized for device index %d", device_index ) @staticmethod def _empty_sink_diagnostics() -> Dict[str, Any]: return { "rollover_count": 0, "pruned_segment_count": 0, "pruned_bytes": 0, "final_retained_files": 0, "final_retained_bytes": 0, } def _reset_history(self) -> None: self._memory_usage.clear() self._timestamps.clear() self._events.clear() self._alerts.clear() self._history_dropped_samples = 0 self._history_dropped_events = 0 self._history_dropped_alerts = 0 self._collector_failure_event_count = 0 self._total_samples_observed = 0 self._peak_memory_bytes = 0 self._min_memory_bytes = None self._sum_memory_bytes = 0 self._last_sink_diagnostics = self._empty_sink_diagnostics() self._last_oom_dump_path = None self._oom_flight_recorder.clear() def _tracking_result_data(self) -> _TrackingResultData: return _TrackingResultData( retained_memory_usage=list(self._memory_usage), retained_timestamps=list(self._timestamps), retained_events=list(self._events), retained_alerts=list(self._alerts), total_samples_observed=self._total_samples_observed, peak_memory=self._peak_memory_bytes, min_memory=( self._min_memory_bytes if self._min_memory_bytes is not None else 0 ), sum_memory=self._sum_memory_bytes, dropped_samples=self._history_dropped_samples, dropped_events=self._history_dropped_events, dropped_alerts=self._history_dropped_alerts, ) def _ensure_session_summary(self) -> SessionSummary: if self._session_summary is None: summary = create_session_summary( source=self.session_source, status=SESSION_STATUS_RUNNING, started_at_ns=now_ns(), host=socket.gethostname(), pid=os.getpid(), job_id=self.distributed_identity["job_id"], rank=self.distributed_identity["rank"], local_rank=self.distributed_identity["local_rank"], world_size=self.distributed_identity["world_size"], ) if self._telemetry_sink is not None and hasattr( self._telemetry_sink, "start_session" ): summary = self._telemetry_sink.start_session(summary) self._session_summary = summary return self._session_summary
[docs] def get_session_summary(self) -> Optional[SessionSummary]: return self._session_summary
@property def oom_buffer_size(self) -> int: """Resolved OOM ring-buffer size.""" return self._oom_flight_recorder.config.buffer_size def _ensure_telemetry_sink(self) -> None: if self._telemetry_sink is None and self._telemetry_sink_config is not None: self._telemetry_sink = AppendOnlyTelemetrySink(self._telemetry_sink_config) def _get_current_device_memory(self) -> int: if self._device is None: return 0 # Flush XLA async dispatch before reading memory stats. # JAX dispatches operations asynchronously to XLA. Without this # synchronisation barrier, memory_stats() may return stale values # that exclude memory from operations still "in flight". The # trade-off is a small allocation (1-element array) and a forced # sync that can slightly perturb workload timing at high sample # rates. At the default 1 s interval the overhead should be negligible. # # Exceptions are intentionally NOT caught here — they propagate to # _run_tracking_iteration which routes them through the # collector-health degradation path instead of recording a # synthetic zero-memory sample. if self._sync_sentinel is not None: self._sync_sentinel.block_until_ready() elif hasattr(jax.numpy, "zeros"): jax.numpy.zeros((), device=self._device).block_until_ready() stats = self._device.memory_stats() if stats: if "bytes_reserved" in stats: self._last_reserved_bytes = int(stats["bytes_reserved"]) else: self._last_reserved_bytes = None return int(stats.get("bytes_in_use", 0)) return 0 def _get_current_cpu_memory(self) -> int: if PSUTIL_AVAILABLE and psutil is not None: return int(psutil.Process().memory_info().rss) return 0 def _get_current_memory_bytes(self) -> int: """Return current memory usage in bytes (no unit conversion).""" if self._device is not None and self._device.platform != "cpu": return self._get_current_device_memory() return self._get_current_cpu_memory() def _get_current_memory(self) -> float: """Return current memory usage in MB.""" return self._get_current_memory_bytes() / (1024 * 1024) def _build_telemetry_event_record( self, *, timestamp: float, memory_bytes: int, event_type: str = "sample", context: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: sampling_interval_ms = int(round(self.sampling_interval * 1000)) # Use the real bytes_reserved value from JAX memory_stats() when # available, so downstream consumers can see actual XLA # preallocation overhead. When the value is unavailable we fall # back to aliasing allocated_bytes and flag it as approximate. reserved_bytes = self._last_reserved_bytes if reserved_bytes is None: reserved_bytes = memory_bytes is_approximate = True else: is_approximate = False session = self._ensure_session_summary() health_dict = self._collector_health.to_dict() meta = {**metadata, **health_dict} if metadata else health_dict if is_approximate: meta["allocator_reserved_approximate"] = True legacy = { "session_id": session.session_id, "timestamp": timestamp, "type": event_type, "memory_mb": memory_bytes / (1024 * 1024), "allocator_allocated_bytes": memory_bytes, "allocator_reserved_bytes": reserved_bytes, "device_id": self.device_index, "context": context, "metadata": meta, "collector": "stormlog.jax.memory_tracker", "sampling_interval_ms": sampling_interval_ms, "pid": self._cached_pid, "host": self._cached_hostname, "job_id": self.distributed_identity["job_id"], "rank": self.distributed_identity["rank"], "local_rank": self.distributed_identity["local_rank"], "world_size": self.distributed_identity["world_size"], } if self._device_bytes_limit is not None: legacy["device_total_bytes"] = self._device_bytes_limit event = telemetry_event_from_record( legacy, default_collector="stormlog.jax.memory_tracker", default_sampling_interval_ms=sampling_interval_ms, default_session_id=session.session_id, ) return telemetry_event_to_dict(event) def _append_event( self, *, timestamp: float, memory_bytes: int, event_type: str, context: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> None: record = self._build_telemetry_event_record( timestamp=timestamp, memory_bytes=memory_bytes, event_type=event_type, context=context, metadata=metadata, ) with self._lock: if len(self._events) == self.max_history: self._history_dropped_events += 1 self._events.append(record) self._append_to_telemetry_sink(record) if self._oom_flight_recorder.config.enabled: self._oom_flight_recorder.record_event(record) def _set_collector_health( self, *, status: str, telemetry_partial: bool, last_error: Optional[str] = None, consecutive_failures: int = 0, next_retry_epoch_s: Optional[float] = None, ) -> None: self._collector_health = CollectorHealthState( status=status, telemetry_partial=telemetry_partial, last_error=last_error, consecutive_failures=consecutive_failures, next_retry_epoch_s=next_retry_epoch_s, ) def _retry_collection_due(self, now: float) -> bool: retry_at = self._collector_health.next_retry_epoch_s return retry_at is None or now >= retry_at def _status_memory_value(self) -> int: return self._last_successful_memory_bytes or 0 def _transition_to_failure(self, timestamp: float, exc: BaseException) -> None: previous_health = self._collector_health consecutive_failures = previous_health.consecutive_failures + 1 retry_delay_s = collector_retry_delay_seconds( consecutive_failures, initial_delay_s=self._collector_retry_backoff_initial_s, factor=self._collector_retry_backoff_factor, max_delay_s=self._collector_retry_backoff_cap_s, ) next_retry_epoch_s = timestamp + retry_delay_s if retry_delay_s > 0 else None error_message = str(exc) self._set_collector_health( status=COLLECTOR_HEALTH_UNHEALTHY, telemetry_partial=True, last_error=error_message, consecutive_failures=consecutive_failures, next_retry_epoch_s=next_retry_epoch_s, ) if previous_health.status == COLLECTOR_HEALTH_HEALTHY: self._collector_failure_event_count += 1 self._append_event( timestamp=timestamp, memory_bytes=self._status_memory_value(), event_type="collector_degraded", context="Collector unavailable; telemetry paused until recovery.", metadata={ "collector_transition": "degraded", "collector_degraded_from": previous_health.status, "collector_degradation_reason": error_message, "collector_retry_delay_s": retry_delay_s, }, ) if self.enable_logging: logger.warning("Could not get memory usage: %s", error_message) def _transition_to_success(self, timestamp: float) -> None: previous_health = self._collector_health if previous_health.status != COLLECTOR_HEALTH_HEALTHY: previous_error = previous_health.last_error previous_failures = previous_health.consecutive_failures self._set_collector_health( status=COLLECTOR_HEALTH_HEALTHY, telemetry_partial=False, ) self._collector_failure_event_count += 1 self._append_event( timestamp=timestamp, memory_bytes=self._status_memory_value(), event_type="collector_recovered", context="Collector recovered; full telemetry sampling resumed.", metadata={ "collector_transition": "recovered", "collector_recovered_from": previous_health.status, "collector_previous_error": previous_error, "collector_previous_failure_count": previous_failures, }, ) else: self._set_collector_health( status=COLLECTOR_HEALTH_HEALTHY, telemetry_partial=False, ) def _run_tracking_iteration(self) -> None: current_time = time.time() if not self._retry_collection_due(current_time): return try: current_memory = self._get_current_memory_bytes() except Exception as exc: self._transition_to_failure(current_time, exc) return self._last_successful_memory_bytes = current_memory self._transition_to_success(current_time) with self._lock: if len(self._memory_usage) == self.max_history: self._history_dropped_samples += 1 self._memory_usage.append(current_memory) self._timestamps.append(current_time) self._total_samples_observed += 1 self._peak_memory_bytes = max(self._peak_memory_bytes, current_memory) if self._min_memory_bytes is None: self._min_memory_bytes = current_memory else: self._min_memory_bytes = min(self._min_memory_bytes, current_memory) self._sum_memory_bytes += current_memory self._append_event( timestamp=current_time, memory_bytes=current_memory, event_type="sample", ) if self.alert_threshold_mb is not None: current_memory_mb = current_memory / (1024 * 1024) if current_memory_mb > self.alert_threshold_mb: self._trigger_alert(current_memory_mb, current_time) def _tracking_loop(self) -> None: while not self._stop_event.is_set(): try: self._run_tracking_iteration() self._flush_telemetry_sink() self._stop_event.wait(self.sampling_interval) except Exception as e: if self.enable_logging: logger.error("Error in tracking loop: %s", e) self._flush_telemetry_sink(force=True) self._stop_event.wait(self.sampling_interval) def _trigger_alert(self, memory_mb: float, timestamp: float) -> None: alert = { "timestamp": timestamp, "memory_mb": memory_mb, "threshold_mb": self.alert_threshold_mb, "message": f"Memory usage {memory_mb:.1f} MB exceeds threshold {self.alert_threshold_mb:.1f} MB", } with self._lock: if len(self._alerts) == self.max_history: self._history_dropped_alerts += 1 self._alerts.append(alert) if self.enable_logging: logger.warning(alert["message"]) for callback in self.alert_callbacks: try: callback(alert) except Exception as e: if self.enable_logging: logger.error("Error in alert callback: %s", e)
[docs] def add_alert_callback(self, callback: Callable[[Dict[str, Any]], None]) -> None: self.alert_callbacks.append(callback)
[docs] def remove_alert_callback(self, callback: Callable[[Dict[str, Any]], None]) -> None: """Remove a previously registered alert callback.""" try: self.alert_callbacks.remove(callback) except ValueError: pass
[docs] def set_alert_threshold(self, threshold_mb: float) -> None: self.alert_threshold_mb = threshold_mb if self.enable_logging: logger.info("Updated alert threshold to %s MB", threshold_mb)
[docs] def check_alerts(self) -> bool: with self._lock: recent_alerts = [ alert for alert in self._alerts if time.time() - alert["timestamp"] < 10.0 ] return len(recent_alerts) > 0
[docs] def start_tracking(self) -> None: if self.tracking: if self.enable_logging: logger.warning("Tracking already started") return self._session_start_time = time.time() self._session_end_time = None self._session_summary = None self._phase_state.reset() self._ensure_telemetry_sink() self._stop_event.clear() with self._lock: self._reset_history() self._last_successful_memory_bytes = None self._set_collector_health( status=COLLECTOR_HEALTH_HEALTHY, telemetry_partial=False, ) self._ensure_session_summary() self.tracking_thread = threading.Thread(target=self._tracking_loop, daemon=True) self.tracking_thread.start() self.tracking = True self._append_event( timestamp=self._session_start_time, memory_bytes=self._status_memory_value(), event_type="start", context="Memory tracking started", ) if self.enable_logging: logger.info( "Started JAX memory tracking with %ss interval", self.sampling_interval )
[docs] def stop_tracking(self) -> TrackingResult: if not self.tracking: if self.enable_logging: logger.warning("Tracking not started") return self._create_empty_result() self.tracking = False self._stop_event.set() if self.tracking_thread: self.tracking_thread.join(timeout=5.0) self._session_end_time = time.time() self._append_event( timestamp=self._session_end_time, memory_bytes=self._status_memory_value(), event_type="stop", context="Memory tracking stopped", ) profile_path = None if self.save_device_profile_on_stop: profile_path = self.save_device_memory_profile_to_dir() self._close_telemetry_sink() if self._session_summary is not None: self._session_summary = finalize_session_summary( self._session_summary, ended_at_ns=now_ns(), ) self._phase_state.reset() result = self._create_tracking_result() result.device_memory_profile_path = profile_path if self.enable_logging: logger.info( "Stopped memory tracking. Peak usage: %.1f MB", result.peak_memory_bytes / (1024 * 1024), ) return result
[docs] def get_current_memory(self) -> float: """Get current memory usage in MB.""" try: return self._get_current_memory() except Exception: return float(self._last_successful_memory_bytes or 0) / (1024 * 1024)
[docs] def get_statistics(self) -> dict[str, Any]: with self._lock: retained_events = len(self._events) retained_samples = len(self._memory_usage) retained_alerts = len(self._alerts) peak_memory = self._peak_memory_bytes if self._total_samples_observed else 0 average_memory = ( self._sum_memory_bytes / self._total_samples_observed if self._total_samples_observed else 0 ) min_memory = ( self._min_memory_bytes if self._min_memory_bytes is not None and self._total_samples_observed else 0 ) collector_failure_event_count = self._collector_failure_event_count dropped_samples = self._history_dropped_samples dropped_events = self._history_dropped_events dropped_alerts = self._history_dropped_alerts tracking_start = self._session_start_time tracking_end = self._session_end_time tracking_duration = ( (tracking_end or time.time()) - tracking_start if isinstance(tracking_start, (int, float)) else 0.0 ) current_memory_mb = ( self._last_successful_memory_bytes / (1024 * 1024) if self._collector_health.status == COLLECTOR_HEALTH_HEALTHY and self._last_successful_memory_bytes is not None else None ) return { "current_memory_mb": current_memory_mb, "peak_memory_mb": peak_memory / (1024 * 1024), "average_memory_mb": average_memory / (1024 * 1024), "min_memory_mb": min_memory / (1024 * 1024), "collector_failure_event_count": collector_failure_event_count, "total_events": retained_events, "tracking_duration_seconds": tracking_duration, "history_window_limit": self.max_history, "history_retained_samples": retained_samples, "history_dropped_samples": dropped_samples, "history_retained_events": retained_events, "history_dropped_events": dropped_events, "history_retained_alerts": retained_alerts, "history_dropped_alerts": dropped_alerts, **self._last_sink_diagnostics, "session_id": ( self._session_summary.session_id if self._session_summary is not None else None ), "session_status": ( self._session_summary.status if self._session_summary is not None else None ), **self._collector_health.to_dict(), "oom_flight_recorder_enabled": self._oom_flight_recorder.config.enabled, "last_oom_dump_path": self._last_oom_dump_path, }
[docs] def get_tracking_results(self) -> TrackingResult: """Get current tracking results without stopping.""" return self._create_tracking_result()
def _create_tracking_result(self) -> TrackingResult: with self._lock: result_data = self._tracking_result_data() if ( not result_data.retained_memory_usage and not result_data.retained_events and not result_data.retained_alerts ): return self._create_empty_result() session_start = self._session_start_time session_end = self._session_end_time if session_start is None: session_start = ( result_data.retained_timestamps[0] if result_data.retained_timestamps else time.time() ) if session_end is None: session_end = ( result_data.retained_timestamps[-1] if result_data.retained_timestamps else time.time() ) return TrackingResult( start_time=session_start, end_time=session_end, samples_collected=result_data.total_samples_observed, memory_usage=result_data.retained_memory_usage, timestamps=result_data.retained_timestamps, peak_memory_bytes=( result_data.peak_memory if result_data.total_samples_observed else 0 ), average_memory_bytes=( int(result_data.sum_memory / result_data.total_samples_observed) if result_data.total_samples_observed else 0 ), min_memory_bytes=( result_data.min_memory if result_data.total_samples_observed else 0 ), alert_count=len(result_data.retained_alerts) + result_data.dropped_alerts, telemetry_events=result_data.retained_events, session_summary=self._session_summary, history_window_limit=self.max_history, history_retained_samples=len(result_data.retained_memory_usage), history_dropped_samples=result_data.dropped_samples, history_retained_events=len(result_data.retained_events), history_dropped_events=result_data.dropped_events, history_retained_alerts=len(result_data.retained_alerts), history_dropped_alerts=result_data.dropped_alerts, ) def _create_empty_result(self) -> TrackingResult: current_time = time.time() start_time = self._session_start_time or current_time end_time = self._session_end_time or start_time return TrackingResult( start_time=start_time, end_time=end_time, samples_collected=0, memory_usage=[], timestamps=[], peak_memory_bytes=0, average_memory_bytes=0, min_memory_bytes=0, alert_count=0, session_summary=self._session_summary, history_window_limit=self.max_history, ) # -- Telemetry Sink Management ----------------------------------------- def _append_to_telemetry_sink(self, record: Dict[str, Any]) -> None: if self._telemetry_sink is None: return try: self._telemetry_sink.append(record) self._last_sink_diagnostics = self._telemetry_sink.get_diagnostics() except Exception as exc: self._disable_telemetry_sink("append", exc) def _flush_telemetry_sink(self, *, force: bool = False) -> None: if self._telemetry_sink is None: return try: self._telemetry_sink.flush(force=force) self._last_sink_diagnostics = self._telemetry_sink.get_diagnostics() except Exception as exc: self._disable_telemetry_sink("flush", exc) def _close_telemetry_sink(self) -> None: if self._telemetry_sink is None: return try: self._close_sink_with_status( self._telemetry_sink, SESSION_STATUS_COMPLETED, ) self._last_sink_diagnostics = self._telemetry_sink.get_diagnostics() except Exception as exc: self._disable_telemetry_sink("close", exc) else: self._telemetry_sink = None def _disable_telemetry_sink(self, operation: str, exc: Exception) -> None: sink = self._telemetry_sink if sink is None: return self._telemetry_sink = None logger.warning( "Disabling JAX telemetry sink after %s failure: %s", operation, exc, ) if self._session_summary is not None: self._session_summary = update_session_summary( self._session_summary, status=SESSION_STATUS_INCOMPLETE, ended_at_ns=now_ns(), ) try: self._close_sink_with_status(sink, SESSION_STATUS_INCOMPLETE) if hasattr(sink, "get_diagnostics"): self._last_sink_diagnostics = sink.get_diagnostics() except Exception as close_exc: logger.debug( "JAX telemetry sink close failed after %s error: %s", operation, close_exc, ) @staticmethod def _close_sink_with_status(sink: Any, status: str) -> None: try: sink.close(session_status=status) except TypeError: sink.close() # -- Phases ------------------------------------------------------------
[docs] def enter_phase( self, name: str, *, metadata: Optional[Dict[str, Any]] = None ) -> PhaseHandle: if not self.tracking: raise RuntimeError("Tracking must be active before entering a phase.") session = self._ensure_session_summary() token, boundary = self._phase_state.enter( session_id=session.session_id, rank=self.distributed_identity["rank"], name=name, attrs=metadata, ) self._append_event( timestamp=time.time(), memory_bytes=self._status_memory_value(), event_type=boundary.event_type, context=boundary.context, metadata=boundary.metadata, ) return PhaseHandle( scope_id=boundary.scope_id, name=name, path=boundary.path, close_callback=lambda: self._emit_phase_exit(token), )
[docs] @contextmanager def phase( self, name: str, *, metadata: Optional[Dict[str, Any]] = None ) -> Iterator[PhaseHandle]: handle = self.enter_phase(name, metadata=metadata) try: yield handle finally: handle.close()
def _emit_phase_exit(self, token: PhaseToken) -> None: boundary = self._phase_state.exit(token) self._append_event( timestamp=time.time(), memory_bytes=self._status_memory_value(), event_type=boundary.event_type, context=boundary.context, metadata=boundary.metadata, ) # -- Context manager protocol ------------------------------------------ def __enter__(self) -> "MemoryTracker": self.start_tracking() return self def __exit__(self, *exc: Any) -> None: self.stop_tracking() # -- OOM Flight Recorder ----------------------------------------------- @property def last_oom_dump_path(self) -> Optional[str]: """Path to the most recent OOM dump bundle, or None.""" return self._last_oom_dump_path @last_oom_dump_path.setter def last_oom_dump_path(self, value: Optional[str]) -> None: self._last_oom_dump_path = value
[docs] def handle_exception( self, exc: BaseException, context: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> Optional[str]: """Capture OOM diagnostics for recognized OOM exceptions.""" classification = classify_oom_exception(exc) if not classification.is_oom or classification.reason is None: return None if not self._oom_flight_recorder.config.enabled: return None dump_metadata: Dict[str, Any] = { "tracker_stats": self.get_statistics(), "sampling_interval_s": self.sampling_interval, "session_id": None, "job_id": self.distributed_identity["job_id"], "rank": self.distributed_identity["rank"], "local_rank": self.distributed_identity["local_rank"], "world_size": self.distributed_identity["world_size"], } if metadata: dump_metadata.update(metadata) current_memory = self._status_memory_value() reserved_bytes = self._last_reserved_bytes is_approximate = False if reserved_bytes is None: reserved_bytes = current_memory is_approximate = True dump_metadata.update( { "sample_allocated_bytes": current_memory, "sample_reserved_bytes": reserved_bytes, "sample_device_id": self.device_index, "sample_total_bytes": self._device_bytes_limit, } ) if is_approximate: dump_metadata["allocator_reserved_approximate"] = True self._append_event( timestamp=time.time(), memory_bytes=current_memory, event_type="error", context=f"OOM detected ({classification.reason})", metadata={"oom_reason": classification.reason}, ) session_summary = self._session_summary dump_metadata["tracker_stats"] = self.get_statistics() dump_metadata["session_id"] = ( session_summary.session_id if session_summary is not None else None ) try: dump_path = self._oom_flight_recorder.dump( reason=classification.reason, exception=exc, context=context, backend="jax", metadata=dump_metadata, session_summary=session_summary, ) except Exception as dump_exc: logger.debug("OOM flight recorder dump failed: %s", dump_exc) return None # Enrich the dump with a JAX device memory profile artifact if dump_path is not None: profile_path = self.save_device_memory_profile_to_dir( output_dir=dump_path, ) if profile_path: try: manifest_path = Path(dump_path) / "manifest.json" if manifest_path.exists(): manifest = json.loads(manifest_path.read_text(encoding="utf-8")) files = list(manifest.get("files", [])) profile_name = Path(profile_path).name if profile_name not in files: files.append(profile_name) # Automatically generate the HTML visualization try: from .attributed_viz import render_jax_attributed_html from .pprof_parser import parse_jax_memory_profile html_path = str( Path(dump_path) / "jax-device-memory-graph.html" ) profile_data = parse_jax_memory_profile(profile_path) render_jax_attributed_html(profile_data, html_path) html_name = Path(html_path).name if html_name not in files: files.append(html_name) manifest["jax_device_profile_html"] = html_name except Exception as viz_exc: logger.debug( "Could not generate HTML memory graph: %s", viz_exc ) manifest["files"] = files manifest["jax_device_profile"] = True manifest_path.write_text( json.dumps(manifest, indent=2), encoding="utf-8", ) except Exception as enrich_exc: logger.debug( "Could not enrich OOM manifest with profile: %s", enrich_exc, ) self._last_oom_dump_path = dump_path return dump_path
[docs] @contextmanager def capture_oom( self, context: str = "runtime", metadata: Optional[Dict[str, Any]] = None, ) -> Iterator[None]: """Capture an OOM diagnostic bundle if the wrapped block raises an OOM.""" try: yield except Exception as exc: dump_path = self.handle_exception(exc, context=context, metadata=metadata) if dump_path: logger.error("OOM flight recorder dump saved to: %s", dump_path) raise
[docs] def trigger_oom_dump( self, exception: BaseException, context: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> Optional[str]: """Manually trigger an OOM diagnostic dump bundle.""" return self.handle_exception(exception, context=context, metadata=metadata)
# -- Device Memory Profile Export --------------------------------------
[docs] def save_device_memory_profile(self, output_path: str) -> bool: """Save a JAX device memory profile to the given path. .. note:: This method depends on ``jax.profiler.save_device_memory_profile`` which is only available on GPU/TPU backends with JAX >= 0.4.1. On CPU-only installs or older JAX versions the call is a no-op and returns ``False``. The availability is checked at runtime via ``hasattr`` guards so no import error is raised. """ if not JAX_AVAILABLE: return False try: if hasattr(jax, "profiler") and hasattr( jax.profiler, "save_device_memory_profile" ): jax.profiler.save_device_memory_profile(output_path) if self.enable_logging: logger.info("Saved JAX device memory profile to %s", output_path) return True else: logger.warning( "jax.profiler.save_device_memory_profile is not available." ) return False except Exception as exc: logger.error("Failed to save JAX device memory profile: %s", exc) return False
[docs] def save_device_memory_profile_to_dir( self, output_dir: Optional[str] = None ) -> Optional[str]: """Save a JAX device memory profile to a directory with an auto-generated filename.""" if output_dir is None: output_dir = os.getcwd() try: os.makedirs(output_dir, exist_ok=True) timestamp = int(time.time()) filename = f"jax-device-memory-{timestamp}.prof" filepath = os.path.join(output_dir, filename) if self.save_device_memory_profile(filepath): return filepath except Exception as exc: logger.error("Failed to setup directory for memory profile: %s", exc) return None
[docs] class MemoryWatchdog: """Automatic memory management and cleanup for JAX workloads.""" def __init__( self, max_memory_mb: float = 8000, cleanup_threshold_mb: float = 6000, check_interval: float = 5.0, device_index: int = 0, ): if not JAX_AVAILABLE: raise ImportError("JAX not available.") self.max_memory_mb = max_memory_mb self.cleanup_threshold_mb = cleanup_threshold_mb self.check_interval = check_interval self._device = None try: devices = jax.local_devices() if device_index < len(devices): self._device = devices[device_index] except Exception: pass self.active = False self.watchdog_thread: Optional[threading.Thread] = None self._stop_event = threading.Event() self.cleanup_callbacks: List[Callable[[], None]] = [] logger.info("JAX Memory Watchdog initialized with %s MB limit", max_memory_mb)
[docs] def add_cleanup_callback(self, callback: Callable[[], None]) -> None: """Add cleanup callback function.""" self.cleanup_callbacks.append(callback)
def _get_memory_usage(self) -> float: if self._device is None: return 0.0 try: stats = self._device.memory_stats() if stats: return int(stats.get("bytes_in_use", 0)) / (1024 * 1024) return 0.0 except Exception as exc: logger.debug("Watchdog could not get device memory: %s", exc) return 0.0 def _cleanup_memory(self) -> None: try: if hasattr(jax, "clear_caches"): jax.clear_caches() import gc gc.collect() for callback in self.cleanup_callbacks: try: callback() except Exception as e: logger.error("Error in cleanup callback: %s", e) logger.info("Performed JAX memory cleanup") except Exception as e: logger.error("Error during JAX memory cleanup: %s", e) def _watchdog_loop(self) -> None: while not self._stop_event.is_set(): try: current_memory = self._get_memory_usage() if current_memory > self.max_memory_mb: logger.warning( "Memory usage %.1f MB exceeds limit %s MB - forcing cleanup", current_memory, self.max_memory_mb, ) self._cleanup_memory() elif current_memory > self.cleanup_threshold_mb: logger.info( "Memory usage %.1f MB above threshold %s MB - performing cleanup", current_memory, self.cleanup_threshold_mb, ) self._cleanup_memory() self._stop_event.wait(self.check_interval) except Exception as e: logger.error("Error in watchdog loop: %s", e) break
[docs] def start(self) -> None: """Start memory watchdog.""" if self.active: logger.warning("Watchdog already active") return self.active = True self._stop_event.clear() self.watchdog_thread = threading.Thread(target=self._watchdog_loop, daemon=True) self.watchdog_thread.start() logger.info("Started JAX memory watchdog")
[docs] def stop(self) -> None: """Stop memory watchdog.""" if not self.active: return self.active = False self._stop_event.set() if self.watchdog_thread: self.watchdog_thread.join(timeout=5.0) logger.info("Stopped JAX memory watchdog")
[docs] def force_cleanup(self, aggressive: bool = False) -> None: """Force immediate memory cleanup. Args: aggressive: When *True*, also delete all live JAX arrays reachable via ``jax.live_arrays()`` (if available) before running garbage collection. Use with caution — this can invalidate arrays still referenced by user code. """ if aggressive: try: if hasattr(jax, "live_arrays"): for arr in jax.live_arrays(): try: arr.delete() except Exception: pass except Exception as e: logger.debug("Aggressive cleanup of live arrays failed: %s", e) self._cleanup_memory()
# Aliases for backward compatibility JAXMemoryTracker = MemoryTracker