# Source code for stormlog.tracker

"""Real-time memory tracking and monitoring."""

import json
import logging
import os
import socket
import threading
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import torch

from .collector_health import (
    COLLECTOR_HEALTH_DEGRADED,
    COLLECTOR_HEALTH_HEALTHY,
    COLLECTOR_HEALTH_UNHEALTHY,
    CollectorHealthState,
    collector_retry_delay_seconds,
)
from .cuda_native_debug import (
    DEFAULT_TRACE_ALLOC_MAX_ENTRIES,
    capture_cuda_snapshot_artifacts,
    cuda_memory_history_supported,
    start_cuda_memory_history,
    stop_cuda_memory_history,
)
from .device_collectors import (
    DeviceMemorySample,
    DeviceMemorySampleResult,
    _resolve_device,
    build_device_memory_collector,
    detect_torch_runtime_backend,
)
from .oom_flight_recorder import (
    OOMFlightRecorder,
    OOMFlightRecorderConfig,
    classify_oom_exception,
)
from .phases import PhaseHandle, PhaseRecorder, PhaseToken
from .session import (
    SESSION_STATUS_COMPLETED,
    SESSION_STATUS_INCOMPLETE,
    SESSION_STATUS_RUNNING,
    SessionSummary,
    create_session_summary,
    finalize_session_summary,
    now_ns,
    update_session_summary,
)
from .telemetry import (
    resolve_distributed_identity,
    telemetry_event_from_record,
    telemetry_event_to_dict,
)
from .telemetry_sink import AppendOnlyTelemetrySink, TelemetrySinkConfig
from .utils import format_bytes, get_gpu_info

logger = logging.getLogger(__name__)


@dataclass
class TrackingEvent:
    """A single record produced by the memory tracker.

    The six leading fields are required and positional; every other field
    has a default so that callers can build minimal events.
    """

    # Creation time of the event; the tracker fills this with time.time().
    timestamp: float
    # Event label. The tracker in this module emits, among others:
    # 'allocation', 'deallocation', 'peak', 'warning', 'critical', 'error',
    # 'start', 'stop', 'sample', 'collector_degraded', 'collector_recovered'.
    event_type: str
    # Allocated / reserved byte counts taken from the sample attached to the event.
    memory_allocated: int
    memory_reserved: int
    # Signed change in allocated bytes associated with the event (0 for status events).
    memory_change: int
    # Device index reported by the sample.
    device_id: int

    # Session the event belongs to, when one is open.
    session_id: Optional[str] = None
    # Human-readable description of the event.
    context: Optional[str] = None

    # Distributed identity of the emitting process.
    job_id: Optional[str] = None
    rank: int = 0
    local_rank: int = 0
    world_size: int = 1

    # Free-form extra data; the tracker merges collector health fields into it.
    metadata: Optional[Dict[str, Any]] = None

    # Optional allocator / device level byte counts copied from the sample.
    active_memory: Optional[int] = None
    inactive_memory: Optional[int] = None
    device_used: Optional[int] = None
    device_free: Optional[int] = None
    device_total: Optional[int] = None

    # Name of the collector backend that produced the event.
    backend: str = "cuda"
class MemoryTracker:
    """Real-time memory tracker with alerts and monitoring."""

    def __init__(
        self,
        device: Optional[Union[str, int, torch.device]] = None,
        sampling_interval: float = 0.1,
        max_events: int = 10000,
        enable_alerts: bool = True,
        enable_oom_flight_recorder: bool = False,
        oom_dump_dir: str = "oom_dumps",
        oom_buffer_size: Optional[int] = None,
        oom_max_dumps: int = 5,
        oom_max_total_mb: int = 256,
        job_id: Optional[str] = None,
        rank: Optional[int] = None,
        local_rank: Optional[int] = None,
        world_size: Optional[int] = None,
        enable_native_cuda_history: bool = False,
        native_history_max_entries: int = DEFAULT_TRACE_ALLOC_MAX_ENTRIES,
        telemetry_sink_config: Optional[TelemetrySinkConfig] = None,
    ):
        """
        Initialize the memory tracker.

        Args:
            device: GPU device to track; resolved and validated by
                ``_setup_device``
            sampling_interval: Sampling interval in seconds
            max_events: Maximum number of events to keep in memory
            enable_alerts: Whether to enable memory alerts
            enable_oom_flight_recorder: Enable automatic OOM dump artifacts
            oom_dump_dir: Directory used for OOM dump bundles
            oom_buffer_size: Event ring-buffer size used for OOM dumps;
                ``max_events`` is used when this is None or not positive
            oom_max_dumps: Maximum number of retained OOM dump bundles
            oom_max_total_mb: Maximum retained OOM dump storage in MB
            job_id: Optional job identifier, forwarded to
                ``resolve_distributed_identity`` together with ``os.environ``
            rank: Optional rank, forwarded the same way as ``job_id``
            local_rank: Optional local rank, forwarded the same way
            world_size: Optional world size, forwarded the same way
            enable_native_cuda_history: Stored on the tracker; ``capture_oom``
                checks it before starting CUDA memory-history recording
            native_history_max_entries: Entry limit handed to the CUDA
                memory-history recorder by ``capture_oom``
            telemetry_sink_config: When given, an ``AppendOnlyTelemetrySink``
                is created from it and receives every tracked event

        Raises:
            ValueError: If ``sampling_interval``, ``max_events`` or
                ``native_history_max_entries`` is not positive, or if
                ``_setup_device`` rejects the device.
            RuntimeError: If ``_setup_device`` finds the requested backend
                unavailable.
        """
        if sampling_interval <= 0:
            raise ValueError("sampling_interval must be > 0")
        if max_events <= 0:
            raise ValueError("max_events must be >= 1")
        if native_history_max_entries <= 0:
            raise ValueError("native_history_max_entries must be >= 1")

        self.device = self._setup_device(device)
        self.collector = build_device_memory_collector(self.device)
        self.backend = self.collector.name()
        self.collector_capabilities = self.collector.capabilities()
        self.sampling_interval = sampling_interval
        self.max_events = max_events
        self.enable_alerts = enable_alerts
        self.enable_native_cuda_history = enable_native_cuda_history
        self.native_history_max_entries = native_history_max_entries
        self._telemetry_sink_config = telemetry_sink_config
        self._telemetry_sink = (
            AppendOnlyTelemetrySink(telemetry_sink_config)
            if telemetry_sink_config is not None
            else None
        )
        self.last_oom_dump_path: Optional[str] = None
        self.distributed_identity = resolve_distributed_identity(
            job_id=job_id,
            rank=rank,
            local_rank=local_rank,
            world_size=world_size,
            env=os.environ,
        )
        self.session_source = "stormlog.tracker"
        self._session_summary: Optional[SessionSummary] = None

        # A missing or non-positive recorder size falls back to max_events.
        recorder_buffer_size = (
            oom_buffer_size if oom_buffer_size is not None else max_events
        )
        if recorder_buffer_size <= 0:
            recorder_buffer_size = max_events
        self._oom_flight_recorder = OOMFlightRecorder(
            OOMFlightRecorderConfig(
                enabled=enable_oom_flight_recorder,
                dump_dir=oom_dump_dir,
                buffer_size=recorder_buffer_size,
                max_dumps=oom_max_dumps,
                max_total_mb=oom_max_total_mb,
            )
        )

        # Tracking state
        self.events: deque[TrackingEvent] = deque(maxlen=max_events)
        self._history_dropped_events = 0
        self.is_tracking = False
        self._tracking_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        self._collector_health = CollectorHealthState()
        self._last_observed_sample: Optional[DeviceMemorySample] = None
        self._last_sink_diagnostics: Dict[str, int] = self._empty_sink_diagnostics()
        self._collector_retry_backoff_initial_s = 1.0
        self._collector_retry_backoff_factor = 2.0
        self._collector_retry_backoff_cap_s = 30.0
        self._phase_state = PhaseRecorder()

        # Memory thresholds for alerts
        self.thresholds: Dict[str, float] = {
            "memory_warning_percent": 80.0,  # Warn at 80% memory usage
            "memory_critical_percent": 95.0,  # Critical at 95% memory usage
            "memory_leak_threshold": float(100 * 1024 * 1024),  # 100MB growth
            "fragmentation_threshold": 0.3,  # 30% fragmentation
        }

        # Alert callbacks
        self.alert_callbacks: List[Callable[[TrackingEvent], None]] = []

        # Statistics (same initial values that a new session resets to)
        self.stats: Dict[str, Any] = self._default_stats()

        # Get memory limits with backend-aware fallback.
        self.gpu_info = get_gpu_info(self.device) if self.device.type == "cuda" else {}
        initial_result = self.collector.sample_with_diagnostics()
        initial_sample = initial_result.sample
        if initial_sample is not None:
            self._last_observed_sample = initial_sample
        total_memory = (
            initial_sample.total_bytes if initial_sample is not None else None
        )
        if total_memory is None:
            # The collector gave no total; use the GPU info value if it is numeric.
            fallback_total = self.gpu_info.get("total_memory", 0)
            total_memory = (
                int(fallback_total) if isinstance(fallback_total, (int, float)) else 0
            )
        self.total_memory = int(total_memory)

    @staticmethod
    def _default_stats() -> Dict[str, Any]:
        """Return a fresh statistics mapping with every entry at its initial value."""
        return {
            "peak_memory": 0,
            "total_allocations": 0,
            "total_deallocations": 0,
            "total_allocation_bytes": 0,
            "total_deallocation_bytes": 0,
            "alert_count": 0,
            "tracking_start_time": None,
            "last_memory_check": 0,
        }

    @staticmethod
    def _empty_sink_diagnostics() -> Dict[str, int]:
        """Return sink diagnostics with every counter set to zero."""
        return {
            "rollover_count": 0,
            "pruned_segment_count": 0,
            "pruned_bytes": 0,
            "final_retained_files": 0,
            "final_retained_bytes": 0,
        }

    def _reset_collector_session_state(self) -> None:
        """Reset per-session collector state before a fresh tracking run."""
        self._set_collector_health(
            status=COLLECTOR_HEALTH_HEALTHY,
            telemetry_partial=False,
        )
        self._last_observed_sample = None
        self.stats["last_memory_check"] = 0

    def _reset_tracking_state_for_new_session(self) -> None:
        """Clear per-session in-memory state before starting a new run."""
        self.events.clear()
        self._history_dropped_events = 0
        self._last_sink_diagnostics = self._empty_sink_diagnostics()
        self.last_oom_dump_path = None
        # update() keeps any extra keys that were added to stats elsewhere.
        self.stats.update(self._default_stats())
        self._oom_flight_recorder.clear()

    def _open_session(self) -> SessionSummary:
        """Create the active runtime session summary for a tracking run.

        Returns the existing summary unchanged when one is already set.
        Otherwise a running summary is built from the host name, process id
        and distributed identity; if a telemetry sink exposing
        ``start_session`` is attached, the summary it returns is stored.
        """
        if self._session_summary is not None:
            return self._session_summary
        summary = create_session_summary(
            source=self.session_source,
            status=SESSION_STATUS_RUNNING,
            started_at_ns=now_ns(),
            host=socket.gethostname(),
            pid=os.getpid(),
            job_id=self.distributed_identity["job_id"],
            rank=self.distributed_identity["rank"],
            local_rank=self.distributed_identity["local_rank"],
            world_size=self.distributed_identity["world_size"],
        )
        self._session_summary = summary
        if self._telemetry_sink is not None and hasattr(
            self._telemetry_sink, "start_session"
        ):
            self._session_summary = self._telemetry_sink.start_session(summary)
        return self._session_summary

    def _ensure_telemetry_sink(self) -> None:
        """Recreate the telemetry sink from its config if it is currently unset."""
        if self._telemetry_sink is None and self._telemetry_sink_config is not None:
            self._telemetry_sink = AppendOnlyTelemetrySink(self._telemetry_sink_config)
def get_session_summary(self) -> Optional[SessionSummary]:
    """Return the stored session summary, or None when no session was opened.

    The value is whatever the tracker last stored: the summary of the
    running session or of the most recent one.
    """
    summary = self._session_summary
    return summary
@property
def oom_buffer_size(self) -> int:
    """Resolved OOM ring-buffer size, read from the flight recorder's config."""
    return self._oom_flight_recorder.config.buffer_size

def _setup_device(
    self, device: Union[str, int, torch.device, None]
) -> torch.device:
    """Setup and validate the device for tracking.

    Args:
        device: Device specification passed to ``_resolve_device``.

    Returns:
        ``torch.device("cuda:<index>")`` with an explicit index for CUDA
        devices, or the resolved device unchanged for MPS.

    Raises:
        ValueError: If the resolved device type is neither "cuda" nor "mps",
            or the CUDA index is not below ``torch.cuda.device_count()``.
        RuntimeError: If the required backend is not available.
    """
    resolved_device = _resolve_device(device)
    if resolved_device.type not in {"cuda", "mps"}:
        raise ValueError(
            "Only CUDA/ROCm or MPS devices are supported for GPU memory tracking"
        )

    if resolved_device.type == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA/ROCm backend is not available in this runtime")
        # A device without an explicit index means the current CUDA device.
        device_index = (
            resolved_device.index
            if resolved_device.index is not None
            else torch.cuda.current_device()
        )
        if device_index >= torch.cuda.device_count():
            raise ValueError(f"Device {resolved_device} is not available")
        return torch.device(f"cuda:{device_index}")

    # Only "mps" can reach this point.
    if detect_torch_runtime_backend() != "mps":
        raise RuntimeError("MPS backend is not available in this runtime")
    return resolved_device

def _safe_sample(self) -> DeviceMemorySample:
    """Collect one backend sample for ad-hoc diagnostics with fallback values.

    When the collector yields no sample, the failure is logged at debug
    level and a zeroed sample from ``_empty_sample`` is returned instead.
    """
    result = self.collector.sample_with_diagnostics()
    if result.sample is not None:
        return result.sample
    logger.debug(
        "Could not sample %s memory: %s",
        self.backend,
        result.core_error or "unknown collector error",
    )
    return self._empty_sample()

def _empty_sample(self) -> DeviceMemorySample:
    """Build a zeroed sample for status-only events without live telemetry."""
    device_id = 0
    if self.device.type == "cuda":
        try:
            device_id = (
                self.device.index
                if self.device.index is not None
                else torch.cuda.current_device()
            )
        except Exception:
            # Best effort: any failure resolving the index falls back to 0.
            device_id = 0
    return DeviceMemorySample(
        allocated_bytes=0,
        reserved_bytes=0,
        used_bytes=0,
        free_bytes=None,
        total_bytes=None,
        active_bytes=None,
        inactive_bytes=None,
        device_id=device_id,
    )

def _event_sample(self, sample: Optional[DeviceMemorySample]) -> DeviceMemorySample:
    """Choose the sample to attach to an event.

    Preference order: the given sample, then the last observed sample,
    then a zeroed placeholder from ``_empty_sample``.
    """
    if sample is not None:
        return sample
    if self._last_observed_sample is not None:
        return self._last_observed_sample
    return self._empty_sample()

@staticmethod
def _collector_error_message(result: DeviceMemorySampleResult) -> Optional[str]:
    """Summarize the errors carried by a collector result.

    Returns ``result.core_error`` when it is truthy. Otherwise returns the
    values of ``result.errors`` with duplicates removed (first occurrence
    order kept) joined by "; ", or None when there are no messages.
    """
    if result.core_error:
        return result.core_error
    # dict.fromkeys de-duplicates while preserving first-seen order.
    unique_messages = list(dict.fromkeys(result.errors.values()))
    if not unique_messages:
        return None
    return "; ".join(unique_messages)

def _set_collector_health(
    self,
    *,
    status: str,
    telemetry_partial: bool,
    partial_fields: tuple[str, ...] = (),
    last_error: Optional[str] = None,
    consecutive_failures: int = 0,
    next_retry_epoch_s: Optional[float] = None,
) -> None:
    """Replace the collector health record with a new state built from the arguments."""
    self._collector_health = CollectorHealthState(
        status=status,
        telemetry_partial=telemetry_partial,
        partial_fields=partial_fields,
        last_error=last_error,
        consecutive_failures=consecutive_failures,
        next_retry_epoch_s=next_retry_epoch_s,
    )

def _retry_collection_due(self, now: float) -> bool:
    """Return True when no retry time is scheduled or ``now`` has reached it."""
    retry_at = self._collector_health.next_retry_epoch_s
    return retry_at is None or now >= retry_at

def _transition_to_core_failure(
    self,
    result: DeviceMemorySampleResult,
    *,
    event_time: float,
) -> None:
    """Record a collection attempt that produced no sample.

    Increments the consecutive-failure count, computes the retry delay with
    ``collector_retry_delay_seconds`` and marks the collector unhealthy.
    The next retry is scheduled at ``event_time`` plus the delay when the
    delay is positive. A "collector_degraded" event is emitted only when
    the previous status was healthy.
    """
    previous_health = self._collector_health
    consecutive_failures = previous_health.consecutive_failures + 1
    retry_delay_s = collector_retry_delay_seconds(
        consecutive_failures,
        initial_delay_s=self._collector_retry_backoff_initial_s,
        factor=self._collector_retry_backoff_factor,
        max_delay_s=self._collector_retry_backoff_cap_s,
    )
    next_retry_epoch_s = event_time + retry_delay_s if retry_delay_s > 0 else None
    error_message = self._collector_error_message(result) or "Collector unavailable"
    self._set_collector_health(
        status=COLLECTOR_HEALTH_UNHEALTHY,
        telemetry_partial=True,
        last_error=error_message,
        consecutive_failures=consecutive_failures,
        next_retry_epoch_s=next_retry_epoch_s,
    )
    # Announce only the first step away from the healthy state.
    if previous_health.status == COLLECTOR_HEALTH_HEALTHY:
        self._add_event(
            "collector_degraded",
            0,
            "Collector unavailable; telemetry paused until recovery.",
            metadata={
                "collector_transition": "degraded",
                "collector_degraded_from": previous_health.status,
                "collector_degradation_reason": error_message,
                "collector_retry_delay_s": retry_delay_s,
            },
        )

def _transition_to_sampled_state(
    self,
    result: DeviceMemorySampleResult,
    *,
    sample: DeviceMemorySample,
) -> bool:
    """Update collector health after a sample was obtained.

    A partial result marks the collector degraded and emits
    "collector_degraded" if it was healthy before. A complete result marks
    it healthy and emits "collector_recovered" if it was not healthy before.

    Returns:
        True when the result is partial, False otherwise.
    """
    previous_health = self._collector_health
    is_partial = result.is_partial
    error_message = self._collector_error_message(result)

    if is_partial:
        self._set_collector_health(
            status=COLLECTOR_HEALTH_DEGRADED,
            telemetry_partial=True,
            partial_fields=result.partial_fields,
            last_error=error_message,
            consecutive_failures=0,
            next_retry_epoch_s=None,
        )
        if previous_health.status == COLLECTOR_HEALTH_HEALTHY:
            self._add_event(
                "collector_degraded",
                0,
                "Collector degraded; telemetry is partial.",
                metadata={
                    "collector_transition": "degraded",
                    "collector_degraded_from": previous_health.status,
                    "collector_degradation_reason": error_message,
                },
                sample=sample,
            )
        return True

    # Capture the previous state before it is overwritten so the recovery
    # event can report what the collector recovered from.
    recovered = previous_health.status != COLLECTOR_HEALTH_HEALTHY
    previous_error = previous_health.last_error
    previous_failures = previous_health.consecutive_failures
    previous_status = previous_health.status
    self._set_collector_health(
        status=COLLECTOR_HEALTH_HEALTHY,
        telemetry_partial=False,
    )
    if recovered:
        self._add_event(
            "collector_recovered",
            0,
            "Collector recovered; full telemetry sampling resumed.",
            metadata={
                "collector_transition": "recovered",
                "collector_recovered_from": previous_status,
                "collector_previous_error": previous_error,
                "collector_previous_failure_count": previous_failures,
            },
            sample=sample,
        )
    return False

def _run_tracking_iteration(self, last_allocated: int) -> int:
    """Run one collection iteration, preserving health state across failures.

    Args:
        last_allocated: Allocated byte count from the previous iteration.

    Returns:
        ``last_allocated`` unchanged when a retry is not yet due or the
        collector produced no sample; otherwise the allocated byte count
        of the new sample.
    """
    now = time.time()
    if not self._retry_collection_due(now):
        return last_allocated

    result = self.collector.sample_with_diagnostics()
    if result.sample is None:
        self._transition_to_core_failure(result, event_time=now)
        return last_allocated

    sample = result.sample
    self._last_observed_sample = sample
    current_allocated = sample.allocated_bytes
    current_reserved = sample.reserved_bytes
    memory_change = current_allocated - last_allocated
    is_partial = self._transition_to_sampled_state(result, sample=sample)

    self.stats["last_memory_check"] = now

    if current_allocated > self.stats["peak_memory"]:
        self.stats["peak_memory"] = current_allocated
        self._add_event(
            "peak",
            memory_change,
            f"New peak memory: {format_bytes(current_allocated)}",
            sample=sample,
        )

    # Count the change as an allocation or deallocation; zero change emits neither.
    if memory_change > 0:
        self.stats["total_allocations"] += 1
        self.stats["total_allocation_bytes"] += memory_change
        self._add_event(
            "allocation",
            memory_change,
            f"Memory allocated: {format_bytes(memory_change)}",
            sample=sample,
        )
    elif memory_change < 0:
        self.stats["total_deallocations"] += 1
        self.stats["total_deallocation_bytes"] += abs(memory_change)
        self._add_event(
            "deallocation",
            memory_change,
            f"Memory freed: {format_bytes(abs(memory_change))}",
            sample=sample,
        )

    if self.enable_alerts:
        self._check_alerts(
            current_allocated,
            current_reserved,
            memory_change,
            sample=sample,
        )

    # Every successful iteration ends with a "sample" event.
    partial_fields = ", ".join(result.partial_fields)
    sample_context = (
        f"Collected partial telemetry sample ({partial_fields})."
        if is_partial
        else "Collected telemetry sample."
    )
    self._add_event(
        "sample",
        0,
        sample_context,
        sample=sample,
    )

    return current_allocated
def start_tracking(self) -> None:
    """Begin background sampling for a fresh tracking session.

    Does nothing when tracking is already active. Otherwise the state of any
    previous run is cleared, a new session is opened, the daemon sampling
    thread is launched and a "start" event is recorded.
    """
    if self.is_tracking:
        return

    # Wipe whatever a previous run left behind.
    self._reset_collector_session_state()
    self._reset_tracking_state_for_new_session()
    self._session_summary = None
    self._phase_state.reset()
    self._ensure_telemetry_sink()
    self._stop_event.clear()

    self.stats["tracking_start_time"] = time.time()
    self._open_session()

    # Daemon thread, so it does not keep the interpreter alive at exit.
    worker = threading.Thread(target=self._tracking_loop, daemon=True)
    self._tracking_thread = worker
    worker.start()
    self.is_tracking = True

    # Add initial event
    self._add_event("start", 0, "Memory tracking started")
def stop_tracking(self) -> None:
    """End the current tracking session.

    Does nothing when tracking is not active. Otherwise the sampling thread
    is signalled and joined for at most one second, a "stop" event is
    recorded, the telemetry sink is closed, the session summary (if any) is
    finalized and the phase recorder is reset.
    """
    if not self.is_tracking:
        return

    self.is_tracking = False
    self._stop_event.set()
    worker = self._tracking_thread
    if worker:
        worker.join(timeout=1.0)

    # Add final event
    self._add_event("stop", 0, "Memory tracking stopped")
    self._close_telemetry_sink()

    # Read the summary only now: closing the sink may have replaced it.
    summary = self._session_summary
    if summary is not None:
        self._session_summary = finalize_session_summary(
            summary,
            ended_at_ns=now_ns(),
        )
    self._phase_state.reset()
def _tracking_loop(self) -> None: """Main tracking loop running in background thread.""" last_allocated = 0 while not self._stop_event.wait(self.sampling_interval): try: last_allocated = self._run_tracking_iteration(last_allocated) self._flush_telemetry_sink() except Exception as exc: self._add_event("error", 0, f"Tracking error: {str(exc)}") self._flush_telemetry_sink(force=True) time.sleep(1.0) # Back off on unexpected tracker logic errors def _add_event( self, event_type: str, memory_change: int, context: str, metadata: Optional[Dict[str, Any]] = None, sample: Optional[DeviceMemorySample] = None, ) -> None: """Add a tracking event.""" snapshot = self._event_sample(sample) current_allocated = snapshot.allocated_bytes current_reserved = snapshot.reserved_bytes event_metadata = dict(metadata or {}) event_metadata.update(self._collector_health.to_dict()) event = TrackingEvent( timestamp=time.time(), event_type=event_type, memory_allocated=current_allocated, memory_reserved=current_reserved, memory_change=memory_change, device_id=snapshot.device_id, session_id=( self._open_session().session_id if self._session_summary is None else self._session_summary.session_id ), context=context, job_id=self.distributed_identity["job_id"], rank=self.distributed_identity["rank"], local_rank=self.distributed_identity["local_rank"], world_size=self.distributed_identity["world_size"], metadata=event_metadata, active_memory=snapshot.active_bytes, inactive_memory=snapshot.inactive_bytes, device_used=snapshot.used_bytes, device_free=snapshot.free_bytes, device_total=snapshot.total_bytes, backend=self.backend, ) if len(self.events) == self.max_events: self._history_dropped_events += 1 self.events.append(event) self._oom_flight_recorder.record_event(self._tracking_event_payload(event)) self._append_to_telemetry_sink(event) # Trigger callbacks for alerts if event_type in ["warning", "critical", "error"]: self.stats["alert_count"] += 1 for callback in self.alert_callbacks: try: 
callback(event) except Exception as exc: logger.debug("Alert callback error (suppressed): %s", exc)
def enter_phase(
    self, name: str, *, metadata: Optional[Dict[str, Any]] = None
) -> PhaseHandle:
    """Open a structured workload phase on the active tracking session.

    Args:
        name: Name of the phase.
        metadata: Optional attributes handed to the phase recorder.

    Returns:
        A handle whose close callback emits the matching phase-exit event.

    Raises:
        RuntimeError: If tracking is not active.
    """
    if not self.is_tracking:
        raise RuntimeError("Tracking must be active before entering a phase.")

    active_session = self._open_session()
    token, boundary = self._phase_state.enter(
        session_id=active_session.session_id,
        rank=self.distributed_identity["rank"],
        name=name,
        attrs=metadata,
    )

    # Record the phase boundary as a zero-change event.
    self._add_event(
        boundary.event_type,
        0,
        boundary.context,
        metadata=boundary.metadata,
    )

    def _close_scope():
        # Bound to this phase's token so the handle closes exactly this scope.
        return self._emit_phase_exit(token)

    return PhaseHandle(
        scope_id=boundary.scope_id,
        name=name,
        path=boundary.path,
        close_callback=_close_scope,
    )
@contextmanager
def phase(self, name: str, *, metadata: Optional[Dict[str, Any]] = None) -> Any:
    """Context manager wrapping a block in phase enter and exit records.

    The phase is entered through ``enter_phase`` and its handle is yielded;
    the handle is closed when the block ends, whether or not it raised.
    """
    phase_handle = self.enter_phase(name, metadata=metadata)
    try:
        yield phase_handle
    finally:
        # Always close, so the exit record is emitted even on error.
        self._close_phase_handle(phase_handle)
def _close_phase_handle(self, handle: PhaseHandle) -> None:
    """Close a phase handle by calling its ``close`` method."""
    handle.close()

def _emit_phase_exit(self, token: PhaseToken) -> None:
    """Exit the phase identified by ``token`` and record the boundary event."""
    boundary = self._phase_state.exit(token)
    self._add_event(
        boundary.event_type,
        0,
        boundary.context,
        metadata=boundary.metadata,
    )

def _check_alerts(
    self,
    allocated: int,
    reserved: int,
    change: int,
    *,
    sample: Optional[DeviceMemorySample] = None,
) -> bool:
    """Check for memory alerts and warnings.

    Compares usage, allocation size and fragmentation against
    ``self.thresholds`` and emits a "critical" or "warning" event for each
    threshold that is crossed.

    Returns:
        True if at least one alert event was emitted. Always False when
        ``total_memory`` is 0, because no checks run in that case.
    """
    if self.total_memory == 0:
        return False

    # Memory usage percentage
    usage_percent = (allocated / self.total_memory) * 100
    emitted = False

    # Critical memory usage
    if usage_percent >= self.thresholds["memory_critical_percent"]:
        self._add_event(
            "critical",
            change,
            f"CRITICAL: Memory usage at {usage_percent:.1f}%",
            {"usage_percent": usage_percent},
            sample=sample,
        )
        emitted = True

    # Warning memory usage
    elif usage_percent >= self.thresholds["memory_warning_percent"]:
        self._add_event(
            "warning",
            change,
            f"WARNING: Memory usage at {usage_percent:.1f}%",
            {"usage_percent": usage_percent},
            sample=sample,
        )
        emitted = True

    # Large allocation warning
    if change > self.thresholds["memory_leak_threshold"]:
        self._add_event(
            "warning",
            change,
            f"Large allocation detected: {format_bytes(change)}",
            {"large_allocation": True},
            sample=sample,
        )
        emitted = True

    # Fragmentation warning
    if reserved > 0:
        # Share of reserved memory that is not currently allocated.
        fragmentation = (reserved - allocated) / reserved
        if fragmentation > self.thresholds["fragmentation_threshold"]:
            self._add_event(
                "warning",
                change,
                f"High fragmentation: {fragmentation:.1%}",
                {"fragmentation": fragmentation},
                sample=sample,
            )
            emitted = True

    return emitted

@staticmethod
def _tracking_event_payload(event: TrackingEvent) -> Dict[str, Any]:
    """Serialize a TrackingEvent into a stable JSON-safe payload."""
    return {
        "timestamp": event.timestamp,
        "event_type": event.event_type,
        "session_id": event.session_id,
        "memory_allocated": event.memory_allocated,
        "memory_reserved": event.memory_reserved,
        "memory_change": event.memory_change,
        "device_id": event.device_id,
        "context": event.context,
        "job_id": event.job_id,
        "rank": event.rank,
        "local_rank": event.local_rank,
        "world_size": event.world_size,
        # Copy so later changes to the payload do not touch the event.
        "metadata": dict(event.metadata or {}),
        "active_memory": event.active_memory,
        "inactive_memory": event.inactive_memory,
        "device_used": event.device_used,
        "device_free": event.device_free,
        "device_total": event.device_total,
        "backend": event.backend,
    }

def _telemetry_record_from_event(self, event: TrackingEvent) -> Dict[str, Any]:
    """Convert a tracking event into a telemetry record dictionary.

    Collector capability fields are merged into the event metadata, missing
    device totals are filled from tracker state where allowed, and the
    record is passed through ``telemetry_event_from_record`` and
    ``telemetry_event_to_dict``.
    """
    host = socket.gethostname()
    pid = os.getpid()
    sampling_interval_ms = int(round(self.sampling_interval * 1000))
    session_id = event.session_id or self._open_session().session_id
    default_collector = str(
        self.collector_capabilities.get(
            "telemetry_collector", "stormlog.cuda_tracker"
        )
    )
    capability_metadata = {
        "backend": self.backend,
        "supports_device_total": bool(
            self.collector_capabilities.get("supports_device_total", False)
        ),
        "supports_device_free": bool(
            self.collector_capabilities.get("supports_device_free", False)
        ),
        "sampling_source": str(
            self.collector_capabilities.get("sampling_source", "unknown")
        ),
    }
    metadata = dict(event.metadata or {})
    metadata.update(capability_metadata)
    partial_fields = set(metadata.get("collector_partial_fields", []) or [])
    # Without a device-level figure, use the larger of allocated and reserved.
    device_used = event.device_used
    if device_used is None:
        device_used = max(event.memory_allocated, event.memory_reserved)
    # Fill a missing total from the tracker unless the collector flagged the
    # total as a partial field or the tracker total is 0.
    event_total = event.device_total
    if (
        event_total is None
        and "device_total_bytes" not in partial_fields
        and self.total_memory
    ):
        event_total = self.total_memory

    legacy = {
        "session_id": session_id,
        "timestamp": event.timestamp,
        "event_type": event.event_type,
        "memory_allocated": event.memory_allocated,
        "memory_reserved": event.memory_reserved,
        "memory_change": event.memory_change,
        "allocator_active_bytes": event.active_memory,
        "allocator_inactive_bytes": event.inactive_memory,
        "device_used_bytes": device_used,
        "device_free_bytes": event.device_free,
        "device_total_bytes": event_total,
        "device_id": event.device_id,
        "context": event.context,
        "job_id": event.job_id,
        "rank": event.rank,
        "local_rank": event.local_rank,
        "world_size": event.world_size,
        "metadata": metadata,
        "total_memory": event_total,
        "pid": pid,
        "host": host,
        "collector": default_collector,
        "sampling_interval_ms": sampling_interval_ms,
    }
    telemetry_event = telemetry_event_from_record(
        legacy,
        default_collector=default_collector,
        default_sampling_interval_ms=sampling_interval_ms,
        default_session_id=session_id,
    )
    return telemetry_event_to_dict(telemetry_event)

def _append_to_telemetry_sink(self, event: TrackingEvent) -> None:
    """Append one event to the telemetry sink; disable the sink on failure."""
    if self._telemetry_sink is None:
        return
    try:
        self._telemetry_sink.append(self._telemetry_record_from_event(event))
        self._last_sink_diagnostics = self._telemetry_sink.get_diagnostics()
    except Exception as exc:
        self._disable_telemetry_sink("append", exc)

def _flush_telemetry_sink(self, *, force: bool = False) -> None:
    """Flush the telemetry sink; disable the sink on failure.

    Args:
        force: Passed through to the sink's ``flush``.
    """
    if self._telemetry_sink is None:
        return
    try:
        self._telemetry_sink.flush(force=force)
        self._last_sink_diagnostics = self._telemetry_sink.get_diagnostics()
    except Exception as exc:
        self._disable_telemetry_sink("flush", exc)

def _close_telemetry_sink(self) -> None:
    """Close the telemetry sink with a completed status and detach it.

    If closing or reading diagnostics fails, the sink is handed to
    ``_disable_telemetry_sink`` instead.
    """
    if self._telemetry_sink is None:
        return
    try:
        self._close_sink_with_status(
            self._telemetry_sink,
            SESSION_STATUS_COMPLETED,
        )
        self._last_sink_diagnostics = self._telemetry_sink.get_diagnostics()
    except Exception as exc:
        self._disable_telemetry_sink("close", exc)
    else:
        # Detach only after a clean close.
        self._telemetry_sink = None

def _disable_telemetry_sink(self, operation: str, exc: Exception) -> None:
    """Detach the telemetry sink after a failed operation.

    Logs a warning, marks the session summary (if any) as incomplete and
    makes a best-effort attempt to close the sink with an incomplete status.

    Args:
        operation: Name of the failed operation, used in log messages.
        exc: The exception that caused the failure.
    """
    sink = self._telemetry_sink
    if sink is None:
        return
    # Detach first so no further events are routed to the failing sink.
    self._telemetry_sink = None
    logger.warning(
        "Disabling telemetry sink after %s failure: %s",
        operation,
        exc,
    )
    if self._session_summary is not None:
        self._session_summary = update_session_summary(
            self._session_summary,
            status=SESSION_STATUS_INCOMPLETE,
            ended_at_ns=now_ns(),
        )
    try:
        self._close_sink_with_status(sink, SESSION_STATUS_INCOMPLETE)
        if hasattr(sink, "get_diagnostics"):
            self._last_sink_diagnostics = sink.get_diagnostics()
    except Exception as close_exc:
        logger.debug(
            "Telemetry sink close failed after %s error: %s", operation, close_exc
        )

@staticmethod
def _close_sink_with_status(sink: Any, status: str) -> None:
    """Close a sink, passing the session status when the sink accepts it.

    Falls back to a plain ``close()`` when the keyword call raises TypeError.
    """
    try:
        sink.close(session_status=status)
    except TypeError:
        sink.close()
def handle_exception(
    self,
    exc: BaseException,
    context: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
    """Write an OOM diagnostic dump when ``exc`` is a recognized OOM error.

    Args:
        exc: The exception to classify.
        context: Optional label forwarded to the flight recorder.
        metadata: Optional extra fields merged into the dump metadata.

    Returns:
        The dump path reported by the flight recorder, or None when the
        exception is not classified as OOM, the flight recorder is disabled,
        or writing the dump failed.
    """
    verdict = classify_oom_exception(exc)
    if not verdict.is_oom or verdict.reason is None:
        return None
    if not self._oom_flight_recorder.config.enabled:
        return None

    identity = self.distributed_identity
    dump_metadata: Dict[str, Any] = {
        "tracker_stats": self.get_statistics(),
        "collector_capabilities": dict(self.collector_capabilities),
        "total_memory_bytes": self.total_memory,
        "sampling_interval_s": self.sampling_interval,
        "session_id": None,
        "job_id": identity["job_id"],
        "rank": identity["rank"],
        "local_rank": identity["local_rank"],
        "world_size": identity["world_size"],
    }
    if metadata:
        dump_metadata.update(metadata)

    # Take a fresh sample and record its figures in the dump metadata.
    sample = self._safe_sample()
    dump_metadata["sample_allocated_bytes"] = sample.allocated_bytes
    dump_metadata["sample_reserved_bytes"] = sample.reserved_bytes
    dump_metadata["sample_used_bytes"] = sample.used_bytes
    dump_metadata["sample_free_bytes"] = sample.free_bytes
    dump_metadata["sample_total_bytes"] = sample.total_bytes
    dump_metadata["sample_device_id"] = sample.device_id

    self._add_event(
        "error",
        0,
        f"OOM detected ({verdict.reason})",
        metadata={"oom_reason": verdict.reason},
        sample=sample,
    )

    # Refresh the statistics and session id after the error event was recorded.
    session_summary = getattr(self, "_session_summary", None)
    dump_metadata["tracker_stats"] = self.get_statistics()
    if session_summary is not None:
        dump_metadata["session_id"] = session_summary.session_id
    else:
        dump_metadata["session_id"] = None

    try:
        dump_path = self._oom_flight_recorder.dump(
            reason=verdict.reason,
            exception=exc,
            context=context,
            backend=self.backend,
            metadata=dump_metadata,
            session_summary=session_summary,
        )
    except Exception as dump_exc:
        logger.debug("OOM flight recorder dump failed: %s", dump_exc)
        return None

    self.last_oom_dump_path = dump_path
    return dump_path
def _capture_native_history_dump(self, bundle_dir: Path) -> None:
    """Add CUDA allocator snapshot artifacts to an existing OOM dump bundle.

    Captures the snapshot files into ``bundle_dir``, then records them in
    the bundle's ``manifest.json`` and ``metadata.json`` and re-runs the
    flight recorder's retention pruning. Every stage is best effort: a
    failure is logged at debug level and the remaining stages still run,
    except that nothing else happens if the capture itself fails.

    Args:
        bundle_dir: Directory of the OOM dump bundle to extend.
    """
    try:
        native_files = capture_cuda_snapshot_artifacts(
            bundle_dir,
            device=self.device,
            history_recorded=True,
        )
    except Exception as exc:
        logger.debug("CUDA native history dump failed: %s", exc)
        return

    manifest_file = bundle_dir / "manifest.json"
    metadata_file = bundle_dir / "metadata.json"

    # Stage 1: list the new artifacts in the manifest without duplicates.
    try:
        manifest_doc = json.loads(manifest_file.read_text(encoding="utf-8"))
        listed = list(manifest_doc.get("files", []))
        for artifact in native_files:
            if artifact in listed:
                continue
            listed.append(artifact)
        manifest_doc["files"] = listed
        manifest_doc["native_history_enabled"] = True
        manifest_doc["native_history_files"] = native_files
        manifest_file.write_text(
            json.dumps(manifest_doc, indent=2), encoding="utf-8"
        )
    except Exception as exc:
        logger.debug("Could not update OOM manifest with native history: %s", exc)

    # Stage 2: mirror the same information into the custom metadata.
    try:
        metadata_doc = json.loads(metadata_file.read_text(encoding="utf-8"))
        custom = dict(metadata_doc.get("custom_metadata", {}))
        custom["native_history_enabled"] = True
        custom["native_history_files"] = native_files
        metadata_doc["custom_metadata"] = custom
        metadata_file.write_text(
            json.dumps(metadata_doc, indent=2),
            encoding="utf-8",
        )
    except Exception as exc:
        logger.debug("Could not update OOM metadata with native history: %s", exc)

    # Stage 3: re-run retention pruning now that files were added to the bundle.
    try:
        dump_root = Path(self._oom_flight_recorder.config.dump_dir)
        self._oom_flight_recorder._prune_retention(dump_root)
    except Exception as exc:
        logger.debug(
            "Could not reapply OOM retention after native history: %s", exc
        )
@contextmanager
def capture_oom(
    self,
    context: str = "runtime",
    metadata: Optional[Dict[str, Any]] = None,
) -> Any:
    """Capture an OOM diagnostic bundle if the wrapped block raises OOM.

    CUDA native memory-history recording is started for the duration of the
    block when it is enabled on the tracker, the flight recorder is enabled,
    the backend is ``"cuda"`` and the runtime supports it. Any exception
    raised by the block is passed to ``handle_exception`` and then re-raised
    unchanged.

    Args:
        context: Label forwarded to ``handle_exception``.
        metadata: Extra metadata forwarded to ``handle_exception``.
    """
    wants_history = (
        self.enable_native_cuda_history
        and self._oom_flight_recorder.config.enabled
        and self.backend == "cuda"
        and cuda_memory_history_supported()
    )

    history_active = False
    if wants_history:
        try:
            start_cuda_memory_history(
                device=self.device,
                trace_alloc_max_entries=self.native_history_max_entries,
            )
        except Exception as exc:
            logger.debug("Could not start CUDA native history recording: %s", exc)
        else:
            history_active = True

    try:
        yield
    except Exception as exc:
        dump_path = self.handle_exception(exc, context=context, metadata=metadata)
        if dump_path and history_active and self.backend == "cuda":
            bundle = Path(dump_path)
            MemoryTracker._capture_native_history_dump(self, bundle)
            # Extending the bundle re-runs retention pruning; if the bundle
            # is gone afterwards, do not report a path.
            if not bundle.exists():
                self.last_oom_dump_path = None
                dump_path = None
        if dump_path:
            logger.error("OOM flight recorder dump saved to: %s", dump_path)
        raise
    finally:
        if history_active:
            try:
                stop_cuda_memory_history(device=self.device)
            except Exception as exc:
                logger.debug(
                    "Could not stop CUDA native history recording: %s",
                    exc,
                )
def add_alert_callback(self, callback: Callable[[TrackingEvent], None]) -> None:
    """Register ``callback`` in the tracker's list of alert callbacks.

    Args:
        callback: Callable taking the ``TrackingEvent`` of an alert.
    """
    registered = self.alert_callbacks
    registered.append(callback)
def remove_alert_callback(self, callback: Callable[[TrackingEvent], None]) -> None:
    """Unregister the first occurrence of ``callback``; no-op if absent.

    Args:
        callback: A callable previously passed to ``add_alert_callback``.
    """
    try:
        self.alert_callbacks.remove(callback)
    except ValueError:
        # Not registered: removing an unknown callback is deliberately silent.
        pass
def get_events(
    self,
    event_type: Optional[str] = None,
    last_n: Optional[int] = None,
    since: Optional[float] = None,
) -> List[TrackingEvent]:
    """
    Get tracking events with optional filtering.

    Filters are applied in order: type, then time, then the ``last_n`` limit.

    Args:
        event_type: Keep only events whose ``event_type`` equals this value.
            ``None`` or an empty string disables the filter.
        last_n: Keep only the last N events that survived the other filters.
            ``None`` disables the limit; ``0`` or a negative value yields an
            empty list.
        since: Keep only events with ``timestamp >= since``. ``None``
            disables the filter.

    Returns:
        List of filtered events, in the order they are stored.
    """
    # Work on a snapshot so filtering never iterates the live container.
    events = list(self.events)

    if event_type:
        events = [e for e in events if e.event_type == event_type]

    # Compare against None so an explicit 0.0 is still treated as a bound.
    if since is not None:
        events = [e for e in events if e.timestamp >= since]

    # ``events[-0:]`` is the whole list, so a non-positive limit has to be
    # handled explicitly instead of being sliced.
    if last_n is not None:
        events = events[-last_n:] if last_n > 0 else []

    return events
def get_memory_timeline(self, interval: float = 1.0) -> Dict[str, List]:
    """
    Get memory usage timeline with specified interval.

    The time between the first and last stored event is cut into buckets of
    ``interval`` seconds starting at the first event's timestamp. Each
    non-empty bucket contributes one point taken from the last event stored
    in that bucket; empty buckets are skipped.

    Args:
        interval: Time interval in seconds for aggregation. Must be > 0.

    Returns:
        Dictionary with parallel lists ``timestamps`` (bucket start times),
        ``allocated`` and ``reserved``.

    Raises:
        ValueError: If ``interval`` is not positive (the bucket walk could
            never terminate).
    """
    if interval <= 0:
        raise ValueError(f"interval must be positive, got {interval!r}")

    # Snapshot first, as get_events does, so the first/last lookup and the
    # bucketing pass all see the same sequence.
    snapshot = list(self.events)
    if not snapshot:
        return {"timestamps": [], "allocated": [], "reserved": []}

    start_time = snapshot[0].timestamp
    end_time = snapshot[-1].timestamp
    last_bucket = int((end_time - start_time) // interval)

    # Single pass: later events overwrite earlier ones in the same bucket,
    # which leaves the last stored event per bucket.
    latest_by_bucket: Dict[int, Any] = {}
    for event in snapshot:
        offset = event.timestamp - start_time
        if offset < 0:
            continue  # before the first event's bucket
        bucket = int(offset // interval)
        if bucket > last_bucket:
            continue  # beyond the bucket that contains the last event
        latest_by_bucket[bucket] = event

    ordered = sorted(latest_by_bucket)
    return {
        "timestamps": [start_time + bucket * interval for bucket in ordered],
        "allocated": [latest_by_bucket[b].memory_allocated for b in ordered],
        "reserved": [latest_by_bucket[b].memory_reserved for b in ordered],
    }
def get_statistics(self) -> Dict[str, Any]:
    """Get comprehensive tracking statistics.

    Returns:
        A copy of the tracker's counters extended with event-history
        figures, backend and session information, current memory values
        from the last observed sample (``None`` while the collector is
        unhealthy or no sample exists), average allocation sizes, the
        collector-health fields, the last sink diagnostics and, once a
        tracking start time is set, duration and rate figures.
    """
    # One clock reading for the whole snapshot; the cutoff is loop-invariant
    # and must not be recomputed for every event.
    now = time.time()
    recent_cutoff = now - 3600
    events_last_hour = sum(1 for e in self.events if e.timestamp > recent_cutoff)

    current_stats = self.stats.copy()

    # The last sample is only trusted while the collector is not unhealthy.
    sample = (
        self._last_observed_sample
        if self._collector_health.status != COLLECTOR_HEALTH_UNHEALTHY
        else None
    )
    summary = self._session_summary

    current_stats.update(
        {
            "total_events": len(self.events),
            "events_last_hour": events_last_hour,
            "history_window_limit_events": self.max_events,
            "history_retained_events": len(self.events),
            "history_dropped_events": self._history_dropped_events,
            "backend": self.backend,
            "oom_flight_recorder_enabled": self._oom_flight_recorder.config.enabled,
            "last_oom_dump_path": self.last_oom_dump_path,
            "session_id": summary.session_id if summary is not None else None,
            "session_status": summary.status if summary is not None else None,
            "current_memory_allocated": (
                sample.allocated_bytes if sample is not None else None
            ),
            "current_memory_reserved": (
                sample.reserved_bytes if sample is not None else None
            ),
            "memory_utilization_percent": (
                (sample.used_bytes / self.total_memory * 100)
                if sample is not None and self.total_memory > 0
                else None
            ),
            # max(..., 1) guards the division while nothing has been counted.
            "average_allocation_size": self.stats["total_allocation_bytes"]
            / max(self.stats["total_allocations"], 1),
            "average_deallocation_size": self.stats["total_deallocation_bytes"]
            / max(self.stats["total_deallocations"], 1),
        }
    )
    current_stats.update(self._collector_health.to_dict())
    current_stats.update(self._last_sink_diagnostics)

    if self.stats["tracking_start_time"]:
        tracking_duration = now - self.stats["tracking_start_time"]
        # Rates are computed over at least one second.
        rate_window = max(tracking_duration, 1)
        current_stats.update(
            {
                "tracking_duration_seconds": tracking_duration,
                "allocations_per_second": self.stats["total_allocations"]
                / rate_window,
                "bytes_allocated_per_second": self.stats["total_allocation_bytes"]
                / rate_window,
            }
        )

    return current_stats
def export_events(self, filename: str, format: str = "csv") -> None:
    """
    Export tracking events to file.

    Does nothing (and creates no file) when there are no events.

    Args:
        filename: Output filename
        format: Export format ('csv' or 'json')

    Raises:
        ValueError: If ``format`` is neither ``'csv'`` nor ``'json'`` and
            there is at least one event to export.
    """
    if not self.events:
        return

    if format not in ("csv", "json"):
        raise ValueError(f"Unsupported format: {format}")

    # Convert events to canonical telemetry records.
    records = [self._telemetry_record_from_event(event) for event in self.events]

    if format == "csv":
        # Imported lazily so JSON export works without pandas installed.
        import pandas as pd

        pd.DataFrame(records).to_csv(filename, index=False)
    else:
        # Explicit encoding keeps the output independent of the locale.
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(records, f, indent=2, default=str)
def clear_events(self) -> None:
    """Drop all stored events and zero the counters derived from them.

    Other entries of ``self.stats`` (for example the tracking start time)
    are left untouched.
    """
    self.events.clear()
    self._history_dropped_events = 0

    # Counters that describe the now-discarded history.
    resettable = (
        "peak_memory",
        "total_allocations",
        "total_deallocations",
        "total_allocation_bytes",
        "total_deallocation_bytes",
        "alert_count",
    )
    for counter in resettable:
        self.stats[counter] = 0
def set_threshold(self, threshold_name: str, value: Union[int, float]) -> None:
    """
    Set alert threshold.

    Only thresholds that already exist in ``self.thresholds`` can be set.

    Args:
        threshold_name: Name of the threshold
        value: Threshold value

    Raises:
        ValueError: If ``threshold_name`` is not a known threshold.
    """
    if threshold_name not in self.thresholds:
        raise ValueError(f"Unknown threshold: {threshold_name}")
    self.thresholds[threshold_name] = value
def get_alerts(self, last_n: Optional[int] = None) -> List[TrackingEvent]:
    """Get all alert events (warnings, critical, errors).

    Args:
        last_n: Keep only the last N alerts. ``None`` returns all of them;
            ``0`` or a negative value yields an empty list.

    Returns:
        Alert events in the order they are stored.
    """
    alert_types = ("warning", "critical", "error")
    alerts = [e for e in self.events if e.event_type in alert_types]

    # ``alerts[-0:]`` is the whole list, so a non-positive limit has to be
    # handled explicitly instead of being sliced.
    if last_n is not None:
        alerts = alerts[-last_n:] if last_n > 0 else []

    return alerts
def __enter__(self) -> "MemoryTracker": """Context manager entry.""" self.start_tracking() return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: """Context manager exit.""" self.stop_tracking()
class MemoryWatchdog:
    """Memory watchdog for automated memory management.

    Registers itself as an alert callback on a tracker and, when alerts
    arrive, empties the backend's allocator cache (and optionally runs the
    garbage collector). Automatic cleanups are rate limited by
    ``min_cleanup_interval``.
    """

    def __init__(
        self,
        tracker: MemoryTracker,
        auto_cleanup: bool = True,
        cleanup_threshold: float = 0.9,
        aggressive_cleanup_threshold: float = 0.95,
    ):
        """
        Initialize memory watchdog.

        Args:
            tracker: MemoryTracker instance to monitor
            auto_cleanup: Whether to automatically clean up memory
            cleanup_threshold: Usage fraction (0-1) at which a "warning"
                alert triggers a cleanup
            aggressive_cleanup_threshold: Usage fraction (0-1) at which a
                triggered cleanup is performed aggressively
        """
        self.tracker = tracker
        self.auto_cleanup = auto_cleanup
        self.cleanup_threshold = cleanup_threshold
        self.aggressive_cleanup_threshold = aggressive_cleanup_threshold

        self.cleanup_count = 0
        self.last_cleanup_time = 0.0
        self.min_cleanup_interval = 30.0  # Minimum 30 seconds between cleanups

        # Register last: the callback reads the state initialized above, so
        # it must exist before the tracker can invoke it.
        self.tracker.add_alert_callback(self._handle_alert)

    def _handle_alert(self, event: TrackingEvent) -> None:
        """Handle memory alerts.

        A cleanup runs for every "critical" event and for "warning" events
        whose ``usage_percent`` metadata reaches ``cleanup_threshold``. The
        cleanup is aggressive for critical events and whenever the reported
        usage reaches ``aggressive_cleanup_threshold``.
        """
        if not self.auto_cleanup:
            return

        current_time = time.time()

        # Avoid too frequent cleanups
        if current_time - self.last_cleanup_time < self.min_cleanup_interval:
            return

        # A missing or None usage value counts as 0 so the comparisons below
        # cannot fail on it.
        usage_percent = 0
        if event.metadata:
            usage_percent = event.metadata.get("usage_percent", 0) or 0

        is_critical = event.event_type == "critical"
        warning_over_threshold = (
            event.event_type == "warning"
            and usage_percent >= self.cleanup_threshold * 100
        )
        if not (is_critical or warning_over_threshold):
            return

        aggressive = (
            is_critical or usage_percent >= self.aggressive_cleanup_threshold * 100
        )
        self._perform_cleanup(aggressive=aggressive)
        self.last_cleanup_time = current_time

    def _perform_cleanup(self, aggressive: bool = False) -> None:
        """Perform memory cleanup.

        Args:
            aggressive: Also run the garbage collector and empty the cache a
                second time (with a device synchronize on CUDA/ROCm).

        The outcome is recorded on the tracker as a "cleanup" event, or as
        an "error" event if the cleanup raised. ``cleanup_count`` counts
        attempts, including failed ones.
        """
        import gc

        self.cleanup_count += 1

        try:
            backend = self.tracker.backend
            if backend in {"cuda", "rocm"}:
                torch.cuda.empty_cache()
                if aggressive:
                    # Collect first so freed tensors can leave the cache.
                    gc.collect()
                    torch.cuda.synchronize()
                    torch.cuda.empty_cache()
            elif backend == "mps":
                import torch.mps as torch_mps

                can_empty_cache = hasattr(torch_mps, "empty_cache")
                if can_empty_cache:
                    torch_mps.empty_cache()
                if aggressive:
                    gc.collect()
                    if can_empty_cache:
                        torch_mps.empty_cache()
            elif aggressive:
                # No allocator cache on other backends; only GC applies.
                gc.collect()

            # Log cleanup event
            cleanup_type = "aggressive" if aggressive else "standard"
            self.tracker._add_event(
                "cleanup", 0, f"Performed {cleanup_type} memory cleanup"
            )

        except Exception as e:
            self.tracker._add_event("error", 0, f"Cleanup failed: {e}")

    def force_cleanup(self, aggressive: bool = False) -> None:
        """Force immediate memory cleanup.

        Bypasses ``auto_cleanup`` and the minimum interval, and does not
        update ``last_cleanup_time``.
        """
        self._perform_cleanup(aggressive)

    def get_cleanup_stats(self) -> Dict[str, Any]:
        """Get cleanup statistics."""
        return {
            "cleanup_count": self.cleanup_count,
            "last_cleanup_time": self.last_cleanup_time,
            "auto_cleanup_enabled": self.auto_cleanup,
            "cleanup_threshold": self.cleanup_threshold,
            "aggressive_cleanup_threshold": self.aggressive_cleanup_threshold,
        }