"""
Real-time TensorFlow Memory Tracking
This module provides real-time monitoring of GPU memory usage during TensorFlow
model training and inference, with configurable alerts and automatic cleanup.
"""
import logging
import os
import socket
import threading
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
from .tf_env import configure_tensorflow_logging
# Apply this package's TensorFlow logging configuration (from .tf_env) before
# the import below; presumably so it takes effect when TensorFlow loads.
configure_tensorflow_logging()
try:
    import tensorflow as tf

    TF_AVAILABLE = True
except ImportError:
    # TensorFlow is optional at module import time; MemoryTracker and
    # MemoryWatchdog check TF_AVAILABLE and raise ImportError when constructed
    # without it.
    TF_AVAILABLE = False
    tf = None
from stormlog.collector_health import (
COLLECTOR_HEALTH_HEALTHY,
COLLECTOR_HEALTH_UNHEALTHY,
CollectorHealthState,
collector_retry_delay_seconds,
)
from stormlog.phases import PhaseHandle, PhaseRecorder, PhaseToken
from stormlog.session import (
SESSION_STATUS_COMPLETED,
SESSION_STATUS_INCOMPLETE,
SESSION_STATUS_RUNNING,
SessionSummary,
create_session_summary,
finalize_session_summary,
now_ns,
update_session_summary,
)
from stormlog.telemetry import (
resolve_distributed_identity,
telemetry_event_from_record,
telemetry_event_to_dict,
)
from stormlog.telemetry_sink import AppendOnlyTelemetrySink, TelemetrySinkConfig
@dataclass
class TrackingResult:
    """Results from real-time memory tracking.

    As produced by MemoryTracker, ``memory_usage`` and ``timestamps`` are
    parallel lists (MB, ``time.time()`` seconds) holding only the samples
    still retained in the tracker's bounded history window, while
    ``peak_memory`` / ``average_memory`` / ``min_memory`` summarize every
    sample observed. The ``history_*`` counters report how many samples,
    events and alerts were retained or evicted from their windows.
    """

    start_time: float
    end_time: float
    memory_usage: List[float] = field(default_factory=list)
    timestamps: List[float] = field(default_factory=list)
    events: List[Dict] = field(default_factory=list)
    alerts_triggered: List[Dict] = field(default_factory=list)
    peak_memory: float = 0.0
    average_memory: float = 0.0
    min_memory: float = float("inf")
    history_window_limit: int = 0
    history_retained_samples: int = 0
    history_dropped_samples: int = 0
    history_retained_events: int = 0
    history_dropped_events: int = 0
    history_retained_alerts: int = 0
    history_dropped_alerts: int = 0

    @property
    def duration(self) -> float:
        """Total tracking duration in seconds (``end_time - start_time``)."""
        return self.end_time - self.start_time

    @property
    def memory_growth_rate(self) -> float:
        """Memory growth rate in MB/second between the first and last sample.

        The memory delta is divided by the time elapsed between those same
        two samples when ``timestamps`` is parallel to ``memory_usage``.
        Dividing by the whole session ``duration`` would underestimate the
        rate, because the retained samples may cover only a trailing part of
        the session once the history window has evicted older ones. Falls
        back to ``duration`` when timestamps are unavailable. Returns 0.0
        with fewer than two samples or a non-positive elapsed time.
        """
        if len(self.memory_usage) < 2:
            return 0.0
        delta_mb = self.memory_usage[-1] - self.memory_usage[0]
        if len(self.timestamps) == len(self.memory_usage):
            elapsed_s = self.timestamps[-1] - self.timestamps[0]
        else:
            elapsed_s = self.duration
        if elapsed_s <= 0:
            return 0.0
        return delta_mb / elapsed_s
@dataclass(frozen=True)
class _TrackingResultData:
    """Immutable snapshot of a tracker's retained history and running aggregates.

    Built in one step from the tracker's internal state so that a
    TrackingResult can be assembled from a single consistent view.
    """

    # Contents of the bounded history windows at snapshot time
    # (memory values in MB).
    retained_memory_usage: List[float]
    retained_timestamps: List[float]
    retained_events: List[Dict[str, Any]]
    retained_alerts: List[Dict[str, Any]]
    # Running aggregates over every sample observed this session, including
    # samples that have already been evicted from the history window.
    total_samples_observed: int
    peak_memory: float
    min_memory: float
    sum_memory: float
    # Number of entries evicted from each bounded history window.
    dropped_samples: int
    dropped_events: int
    dropped_alerts: int
class MemoryTracker:
    """Real-time TensorFlow GPU memory tracker.

    While tracking is active, a daemon thread samples the monitored device's
    memory every ``sampling_interval`` seconds. Samples, events and alerts
    are kept in bounded history windows of ``max_history`` entries each;
    peak/min/average aggregates cover every observed sample, including those
    evicted from the window. Events are optionally mirrored to an
    append-only telemetry sink. Collection failures switch the collector
    into a degraded state with retry backoff instead of stopping tracking.
    """

    def __init__(
        self,
        sampling_interval: float = 1.0,
        alert_threshold_mb: Optional[float] = None,
        device: Optional[str] = None,
        enable_logging: bool = True,
        max_history: int = 10_000,
        job_id: Optional[str] = None,
        rank: Optional[int] = None,
        local_rank: Optional[int] = None,
        world_size: Optional[int] = None,
        telemetry_sink_config: Optional[TelemetrySinkConfig] = None,
    ):
        """
        Initialize memory tracker.

        Args:
            sampling_interval: Time between memory samples in seconds (> 0).
            alert_threshold_mb: Memory threshold for alerts in MB. ``None``
                (or any falsy value such as 0) disables alerts.
            device: TensorFlow device to monitor (e.g. ``'/GPU:0'``). Defaults
                to ``'/GPU:0'`` when a GPU is visible, otherwise ``'/CPU:0'``.
            enable_logging: Whether to log events.
            max_history: Capacity of each history window (samples, events,
                alerts); must be >= 1.
            job_id, rank, local_rank, world_size: Distributed identity
                overrides, combined with ``os.environ`` by
                ``resolve_distributed_identity``.
            telemetry_sink_config: When given, events are also written to an
                ``AppendOnlyTelemetrySink`` built from this config.

        Raises:
            ImportError: If TensorFlow is not installed.
            ValueError: If ``sampling_interval <= 0`` or ``max_history <= 0``.
        """
        if not TF_AVAILABLE:
            raise ImportError("TensorFlow not available. Please install TensorFlow.")
        if sampling_interval <= 0:
            raise ValueError("sampling_interval must be > 0")
        if max_history <= 0:
            raise ValueError("max_history must be >= 1")

        self.sampling_interval = sampling_interval
        self.alert_threshold_mb = alert_threshold_mb
        self.device = device or self._get_default_device()
        self.enable_logging = enable_logging
        self.max_history = max_history
        self._telemetry_sink_config = telemetry_sink_config
        self._telemetry_sink = (
            AppendOnlyTelemetrySink(telemetry_sink_config)
            if telemetry_sink_config is not None
            else None
        )
        self.distributed_identity = resolve_distributed_identity(
            job_id=job_id,
            rank=rank,
            local_rank=local_rank,
            world_size=world_size,
            env=os.environ,
        )
        self.session_source = "stormlog.tensorflow.tracker"
        self._session_summary: Optional[SessionSummary] = None

        # Tracking state
        self.tracking = False
        self.tracking_thread: Optional[threading.Thread] = None
        self._memory_usage: deque[float] = deque(maxlen=max_history)
        self._timestamps: deque[float] = deque(maxlen=max_history)
        self._events: deque[Dict[str, Any]] = deque(maxlen=max_history)
        self._alerts: deque[Dict[str, Any]] = deque(maxlen=max_history)
        self._history_dropped_samples = 0
        self._history_dropped_events = 0
        self._history_dropped_alerts = 0
        self._collector_failure_event_count = 0
        self._total_samples_observed = 0
        self._peak_memory_mb = 0.0
        self._min_memory_mb = float("inf")
        self._sum_memory_mb = 0.0
        self._last_sink_diagnostics: Dict[str, int] = self._empty_sink_diagnostics()
        self._phase_state = PhaseRecorder()

        # Thread synchronization. `_lock` guards the history windows and the
        # running aggregates. `_sink_lock` serializes all telemetry-sink
        # access, which happens both from the caller's thread (start, stop
        # and phase events) and from the sampling thread (samples, flushes).
        # It is re-entrant because sink failure handling
        # (`_disable_telemetry_sink`) runs while it is already held.
        self._lock = threading.Lock()
        self._sink_lock = threading.RLock()
        self._stop_event = threading.Event()
        self._collector_health = CollectorHealthState()
        self._last_successful_memory_mb: Optional[float] = None
        self._session_start_time: Optional[float] = None
        self._session_end_time: Optional[float] = None
        self._collector_retry_backoff_initial_s = 1.0
        self._collector_retry_backoff_factor = 2.0
        self._collector_retry_backoff_cap_s = 30.0

        # Alert callbacks
        self.alert_callbacks: List[Callable[[Dict[str, Any]], None]] = []

        if enable_logging:
            logging.info("TensorFlow Memory Tracker initialized for %s", self.device)

    @staticmethod
    def _empty_sink_diagnostics() -> Dict[str, int]:
        """Return zeroed telemetry-sink diagnostic counters."""
        return {
            "rollover_count": 0,
            "pruned_segment_count": 0,
            "pruned_bytes": 0,
            "final_retained_files": 0,
            "final_retained_bytes": 0,
        }

    @property
    def memory_usage(self) -> List[float]:
        """Copy of the retained memory samples in MB."""
        return list(self._memory_usage)

    @property
    def timestamps(self) -> List[float]:
        """Copy of the retained sample timestamps (``time.time()`` seconds)."""
        return list(self._timestamps)

    @property
    def events(self) -> List[Dict[str, Any]]:
        """Copy of the retained telemetry event records."""
        return list(self._events)

    @property
    def alerts(self) -> List[Dict[str, Any]]:
        """Copy of the retained alert records."""
        return list(self._alerts)

    def _reset_history(self) -> None:
        """Clear all history windows, counters and aggregates (caller holds ``_lock``)."""
        self._memory_usage.clear()
        self._timestamps.clear()
        self._events.clear()
        self._alerts.clear()
        self._history_dropped_samples = 0
        self._history_dropped_events = 0
        self._history_dropped_alerts = 0
        self._collector_failure_event_count = 0
        self._total_samples_observed = 0
        self._peak_memory_mb = 0.0
        self._min_memory_mb = float("inf")
        self._sum_memory_mb = 0.0
        self._last_sink_diagnostics = self._empty_sink_diagnostics()

    def _tracking_result_data(self) -> _TrackingResultData:
        """Snapshot history and aggregates (caller holds ``_lock``)."""
        return _TrackingResultData(
            retained_memory_usage=list(self._memory_usage),
            retained_timestamps=list(self._timestamps),
            retained_events=list(self._events),
            retained_alerts=list(self._alerts),
            total_samples_observed=self._total_samples_observed,
            peak_memory=self._peak_memory_mb,
            min_memory=self._min_memory_mb,
            sum_memory=self._sum_memory_mb,
            dropped_samples=self._history_dropped_samples,
            dropped_events=self._history_dropped_events,
            dropped_alerts=self._history_dropped_alerts,
        )

    def _append_sample_locked(self, memory_mb: float, timestamp: float) -> None:
        """Record one sample and update aggregates (caller holds ``_lock``)."""
        # A full deque evicts its oldest entry on append; count the eviction.
        if len(self._memory_usage) == self.max_history:
            self._history_dropped_samples += 1
        self._memory_usage.append(memory_mb)
        self._timestamps.append(timestamp)
        self._total_samples_observed += 1
        self._peak_memory_mb = max(self._peak_memory_mb, memory_mb)
        self._min_memory_mb = min(self._min_memory_mb, memory_mb)
        self._sum_memory_mb += memory_mb

    def _append_event_locked(self, record: Dict[str, Any]) -> None:
        """Record one event, counting evictions (caller holds ``_lock``)."""
        if len(self._events) == self.max_history:
            self._history_dropped_events += 1
        self._events.append(record)

    def _append_alert_locked(self, alert: Dict[str, Any]) -> None:
        """Record one alert, counting evictions (caller holds ``_lock``)."""
        if len(self._alerts) == self.max_history:
            self._history_dropped_alerts += 1
        self._alerts.append(alert)

    def _ensure_session_summary(self) -> SessionSummary:
        """Create the active tracking session summary if needed."""
        if self._session_summary is None:
            summary = create_session_summary(
                source=self.session_source,
                status=SESSION_STATUS_RUNNING,
                started_at_ns=now_ns(),
                host=socket.gethostname(),
                pid=os.getpid(),
                job_id=self.distributed_identity["job_id"],
                rank=self.distributed_identity["rank"],
                local_rank=self.distributed_identity["local_rank"],
                world_size=self.distributed_identity["world_size"],
            )
            with self._sink_lock:
                sink = self._telemetry_sink
                if sink is not None and hasattr(sink, "start_session"):
                    summary = sink.start_session(summary)
            self._session_summary = summary
        return self._session_summary

    def get_session_summary(self) -> Optional[SessionSummary]:
        """Return the active or most recent TensorFlow tracking session."""
        return self._session_summary

    def _ensure_telemetry_sink(self) -> None:
        """Recreate the sink from its config if it was closed or disabled."""
        with self._sink_lock:
            if self._telemetry_sink is None and self._telemetry_sink_config is not None:
                self._telemetry_sink = AppendOnlyTelemetrySink(self._telemetry_sink_config)

    def _device_id(self) -> int:
        """Best-effort device id extraction (-1 for CPU/unknown devices)."""
        if isinstance(self.device, str):
            if "CPU" in self.device.upper():
                return -1
            if ":" in self.device:
                tail = self.device.rsplit(":", 1)[-1]
                if tail.isdigit():
                    return int(tail)
            if "/GPU" in self.device.upper():
                return 0
        return -1

    def _build_telemetry_event_record(
        self,
        *,
        timestamp: float,
        memory_mb: float,
        event_type: str = "sample",
        context: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Build a normalized telemetry event dict for this tracker.

        Caller metadata is merged with the current collector-health fields
        (health fields win on key collisions).
        """
        sampling_interval_ms = int(round(self.sampling_interval * 1000))
        session_id = self._ensure_session_summary().session_id
        legacy = {
            "session_id": session_id,
            "timestamp": timestamp,
            "type": event_type,
            "memory_mb": memory_mb,
            "device_id": self._device_id(),
            "context": context,
            "metadata": {
                **dict(metadata or {}),
                **self._collector_health.to_dict(),
            },
            "collector": "stormlog.tensorflow.memory_tracker",
            "sampling_interval_ms": sampling_interval_ms,
            "pid": os.getpid(),
            "host": socket.gethostname(),
            "job_id": self.distributed_identity["job_id"],
            "rank": self.distributed_identity["rank"],
            "local_rank": self.distributed_identity["local_rank"],
            "world_size": self.distributed_identity["world_size"],
        }
        event = telemetry_event_from_record(
            legacy,
            default_collector="stormlog.tensorflow.memory_tracker",
            default_sampling_interval_ms=sampling_interval_ms,
            default_session_id=session_id,
        )
        return telemetry_event_to_dict(event)

    def _get_default_device(self) -> str:
        """Get default TensorFlow device ('/GPU:0' if any GPU is visible)."""
        try:
            gpus = tf.config.list_physical_devices("GPU")
            if gpus:
                return "/GPU:0"
            else:
                return "/CPU:0"
        except Exception as exc:
            logging.debug("Default device detection failed: %s", exc)
            return "/CPU:0"

    def _get_current_memory(self) -> float:
        """Get current memory usage in MB.

        GPU devices are read via ``tf.config.experimental.get_memory_info``
        (the ``"current"`` byte count); any other device falls back to this
        process's RSS via psutil.

        Raises:
            RuntimeError: If TensorFlow reports a non-numeric value.
        """
        # TensorFlow accepts device names case-insensitively ("/gpu:0" as
        # well as "/GPU:0"); match the same way _device_id() does so a
        # lowercase name is not silently measured as host RSS.
        if "/GPU:" in self.device.upper():
            gpu_id = self._device_id()
            memory_info = tf.config.experimental.get_memory_info(f"/GPU:{gpu_id}")
            current_bytes = memory_info.get("current", 0)
            if isinstance(current_bytes, (int, float)):
                return float(current_bytes) / (1024 * 1024)
            raise RuntimeError("TensorFlow memory info returned a non-numeric value")

        # CPU memory tracking
        import psutil

        process = psutil.Process()
        return float(process.memory_info().rss) / (1024 * 1024)

    def _set_collector_health(
        self,
        *,
        status: str,
        telemetry_partial: bool,
        last_error: Optional[str] = None,
        consecutive_failures: int = 0,
        next_retry_epoch_s: Optional[float] = None,
    ) -> None:
        """Replace the collector health state with a new immutable value."""
        self._collector_health = CollectorHealthState(
            status=status,
            telemetry_partial=telemetry_partial,
            last_error=last_error,
            consecutive_failures=consecutive_failures,
            next_retry_epoch_s=next_retry_epoch_s,
        )

    def _retry_collection_due(self, now: float) -> bool:
        """True if no retry is scheduled or the scheduled retry time has passed."""
        retry_at = self._collector_health.next_retry_epoch_s
        return retry_at is None or now >= retry_at

    def _status_memory_value(self) -> float:
        """Memory value for status events: last successful sample, else 0.0."""
        return float(self._last_successful_memory_mb or 0.0)

    def _append_event(
        self,
        *,
        timestamp: float,
        memory_mb: float,
        event_type: str,
        context: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Build an event record, store it in history and mirror it to the sink."""
        record = self._build_telemetry_event_record(
            timestamp=timestamp,
            memory_mb=memory_mb,
            event_type=event_type,
            context=context,
            metadata=metadata,
        )
        with self._lock:
            self._append_event_locked(record)
        self._append_to_telemetry_sink(record)

    def _transition_to_failure(self, timestamp: float, exc: BaseException) -> None:
        """Mark the collector unhealthy and schedule a backed-off retry.

        A ``collector_degraded`` event is emitted only on the healthy ->
        unhealthy transition, not on every subsequent failure.
        """
        previous_health = self._collector_health
        consecutive_failures = previous_health.consecutive_failures + 1
        retry_delay_s = collector_retry_delay_seconds(
            consecutive_failures,
            initial_delay_s=self._collector_retry_backoff_initial_s,
            factor=self._collector_retry_backoff_factor,
            max_delay_s=self._collector_retry_backoff_cap_s,
        )
        next_retry_epoch_s = timestamp + retry_delay_s if retry_delay_s > 0 else None
        error_message = str(exc)
        self._set_collector_health(
            status=COLLECTOR_HEALTH_UNHEALTHY,
            telemetry_partial=True,
            last_error=error_message,
            consecutive_failures=consecutive_failures,
            next_retry_epoch_s=next_retry_epoch_s,
        )
        if previous_health.status == COLLECTOR_HEALTH_HEALTHY:
            self._collector_failure_event_count += 1
            self._append_event(
                timestamp=timestamp,
                memory_mb=self._status_memory_value(),
                event_type="collector_degraded",
                context="Collector unavailable; telemetry paused until recovery.",
                metadata={
                    "collector_transition": "degraded",
                    "collector_degraded_from": previous_health.status,
                    "collector_degradation_reason": error_message,
                    "collector_retry_delay_s": retry_delay_s,
                },
            )
            if self.enable_logging:
                logging.warning("Could not get memory usage: %s", error_message)

    def _transition_to_success(self, timestamp: float) -> None:
        """Mark the collector healthy, emitting ``collector_recovered`` if it was not."""
        previous_health = self._collector_health
        previous_error = previous_health.last_error
        previous_failures = previous_health.consecutive_failures
        if previous_health.status != COLLECTOR_HEALTH_HEALTHY:
            self._set_collector_health(
                status=COLLECTOR_HEALTH_HEALTHY,
                telemetry_partial=False,
            )
            # Recovery transitions are counted alongside degradations.
            self._collector_failure_event_count += 1
            self._append_event(
                timestamp=timestamp,
                memory_mb=self._status_memory_value(),
                event_type="collector_recovered",
                context="Collector recovered; full telemetry sampling resumed.",
                metadata={
                    "collector_transition": "recovered",
                    "collector_recovered_from": previous_health.status,
                    "collector_previous_error": previous_error,
                    "collector_previous_failure_count": previous_failures,
                },
            )
            return
        self._set_collector_health(
            status=COLLECTOR_HEALTH_HEALTHY,
            telemetry_partial=False,
        )

    def _run_tracking_iteration(self) -> None:
        """Collect one tracking sample or advance degraded-mode state."""
        current_time = time.time()
        if not self._retry_collection_due(current_time):
            return
        try:
            current_memory = self._get_current_memory()
        except Exception as exc:
            self._transition_to_failure(current_time, exc)
            return
        self._last_successful_memory_mb = current_memory
        self._transition_to_success(current_time)
        with self._lock:
            self._append_sample_locked(current_memory, current_time)
        self._append_event(
            timestamp=current_time,
            memory_mb=current_memory,
            event_type="sample",
        )
        # A falsy threshold (None or 0) means alerts are disabled.
        if self.alert_threshold_mb and current_memory > self.alert_threshold_mb:
            self._trigger_alert(current_memory, current_time)

    def _tracking_loop(self) -> None:
        """Main tracking loop running in background thread."""
        while not self._stop_event.is_set():
            try:
                self._run_tracking_iteration()
                self._flush_telemetry_sink()
                self._stop_event.wait(self.sampling_interval)
            except Exception as e:
                if self.enable_logging:
                    logging.error("Error in tracking loop: %s", e)
                self._stop_event.wait(self.sampling_interval)
        self._flush_telemetry_sink(force=True)

    def _trigger_alert(self, memory_mb: float, timestamp: float) -> None:
        """Record a threshold alert, log it, and invoke alert callbacks."""
        alert = {
            "timestamp": timestamp,
            "memory_mb": memory_mb,
            "threshold_mb": self.alert_threshold_mb,
            "message": f"Memory usage {memory_mb:.1f} MB exceeds threshold {self.alert_threshold_mb:.1f} MB",
        }
        with self._lock:
            self._append_alert_locked(alert)

        if self.enable_logging:
            logging.warning(alert["message"])

        # Callbacks run on the sampling thread; one failing callback must not
        # prevent the others from running.
        for callback in self.alert_callbacks:
            try:
                callback(alert)
            except Exception as e:
                if self.enable_logging:
                    logging.error("Error in alert callback: %s", e)

    def add_alert_callback(self, callback: Callable[[Dict[str, Any]], None]) -> None:
        """Add callback function for memory alerts."""
        self.alert_callbacks.append(callback)

    def start_tracking(self) -> None:
        """Start real-time memory tracking.

        Resets history, opens a new session and emits a ``start`` event
        before launching the sampling thread, so ``start`` is always the
        first event of the session.
        """
        if self.tracking:
            if self.enable_logging:
                logging.warning("Tracking already started")
            return

        self._session_start_time = time.time()
        self._session_end_time = None
        self._session_summary = None
        self._phase_state.reset()
        self._ensure_telemetry_sink()
        self._stop_event.clear()

        # Reset tracking data
        with self._lock:
            self._reset_history()
        self._last_successful_memory_mb = None
        self._set_collector_health(
            status=COLLECTOR_HEALTH_HEALTHY,
            telemetry_partial=False,
        )
        self._ensure_session_summary()

        # Emit the start boundary before the sampler exists: the thread
        # samples immediately on its first iteration, which could otherwise
        # land ahead of "start" in history and in the telemetry sink.
        self._append_event(
            timestamp=self._session_start_time,
            memory_mb=self._status_memory_value(),
            event_type="start",
            context="Memory tracking started",
        )

        self.tracking_thread = threading.Thread(target=self._tracking_loop, daemon=True)
        self.tracking_thread.start()
        self.tracking = True

        if self.enable_logging:
            logging.info(
                "Started memory tracking with %ss interval", self.sampling_interval
            )

    def stop_tracking(self) -> TrackingResult:
        """Stop tracking and return results.

        Returns an empty result (and logs a warning) if tracking was not
        active.
        """
        if not self.tracking:
            if self.enable_logging:
                logging.warning("Tracking not started")
            return self._create_empty_result()

        self.tracking = False
        self._stop_event.set()

        # Wait for tracking thread to finish
        if self.tracking_thread:
            self.tracking_thread.join(timeout=5.0)

        self._session_end_time = time.time()
        self._append_event(
            timestamp=self._session_end_time,
            memory_mb=self._status_memory_value(),
            event_type="stop",
            context="Memory tracking stopped",
        )
        self._close_telemetry_sink()
        if self._session_summary is not None:
            self._session_summary = finalize_session_summary(
                self._session_summary,
                ended_at_ns=now_ns(),
            )
        self._phase_state.reset()

        result = self._create_tracking_result()

        if self.enable_logging:
            logging.info(
                "Stopped memory tracking. Peak usage: %.1f MB", result.peak_memory
            )

        return result

    def _create_tracking_result(self) -> TrackingResult:
        """Create tracking result from collected data."""
        with self._lock:
            result_data = self._tracking_result_data()
        if (
            not result_data.retained_memory_usage
            and not result_data.retained_events
            and not result_data.retained_alerts
        ):
            return self._create_empty_result()

        # Fall back to the sample window's bounds when the session times are
        # unknown (e.g. results requested while tracking is still running).
        session_start = self._session_start_time
        session_end = self._session_end_time
        if session_start is None:
            session_start = (
                result_data.retained_timestamps[0]
                if result_data.retained_timestamps
                else time.time()
            )
        if session_end is None:
            session_end = (
                result_data.retained_timestamps[-1]
                if result_data.retained_timestamps
                else time.time()
            )

        observed = result_data.total_samples_observed
        return TrackingResult(
            start_time=session_start,
            end_time=session_end,
            memory_usage=result_data.retained_memory_usage,
            timestamps=result_data.retained_timestamps,
            events=result_data.retained_events,
            alerts_triggered=result_data.retained_alerts,
            peak_memory=result_data.peak_memory if observed else 0.0,
            average_memory=(result_data.sum_memory / observed if observed else 0.0),
            min_memory=result_data.min_memory if observed else 0.0,
            history_window_limit=self.max_history,
            history_retained_samples=len(result_data.retained_memory_usage),
            history_dropped_samples=result_data.dropped_samples,
            history_retained_events=len(result_data.retained_events),
            history_dropped_events=result_data.dropped_events,
            history_retained_alerts=len(result_data.retained_alerts),
            history_dropped_alerts=result_data.dropped_alerts,
        )

    def _create_empty_result(self) -> TrackingResult:
        """Create empty tracking result."""
        current_time = time.time()
        start_time = self._session_start_time or current_time
        end_time = self._session_end_time or start_time
        return TrackingResult(
            start_time=start_time,
            end_time=end_time,
            memory_usage=[],
            timestamps=[],
            events=[],
            alerts_triggered=[],
            peak_memory=0.0,
            average_memory=0.0,
            min_memory=0.0,
            history_window_limit=self.max_history,
            history_retained_samples=0,
            history_dropped_samples=0,
            history_retained_events=0,
            history_dropped_events=0,
            history_retained_alerts=0,
            history_dropped_alerts=0,
        )

    def get_current_memory(self) -> float:
        """Get current memory usage in MB, or the last successful value on error."""
        try:
            return self._get_current_memory()
        except Exception:
            return float(self._last_successful_memory_mb or 0.0)

    def get_statistics(self) -> dict[str, Any]:
        """Return current tracker health and latest successful memory sample."""
        with self._lock:
            retained_events = len(self._events)
            retained_samples = len(self._memory_usage)
            retained_alerts = len(self._alerts)
            observed = self._total_samples_observed
            peak_memory = self._peak_memory_mb if observed else 0.0
            average_memory = self._sum_memory_mb / observed if observed else 0.0
            min_memory = self._min_memory_mb if observed else 0.0
            collector_failure_event_count = self._collector_failure_event_count
            dropped_samples = self._history_dropped_samples
            dropped_events = self._history_dropped_events
            dropped_alerts = self._history_dropped_alerts

        tracking_start = self._session_start_time
        tracking_end = self._session_end_time
        tracking_duration = (
            (tracking_end or time.time()) - tracking_start
            if isinstance(tracking_start, (int, float))
            else 0.0
        )
        # Only report a "current" value while the collector is healthy; a
        # stale last-good sample would be misleading during degradation.
        current_memory_mb = (
            self._last_successful_memory_mb
            if self._collector_health.status == COLLECTOR_HEALTH_HEALTHY
            else None
        )
        session = self._session_summary

        return {
            "current_memory_mb": current_memory_mb,
            "peak_memory_mb": peak_memory,
            "average_memory_mb": average_memory,
            "min_memory_mb": min_memory,
            "collector_failure_event_count": collector_failure_event_count,
            "total_events": retained_events,
            "tracking_duration_seconds": tracking_duration,
            "history_window_limit": self.max_history,
            "history_retained_samples": retained_samples,
            "history_dropped_samples": dropped_samples,
            "history_retained_events": retained_events,
            "history_dropped_events": dropped_events,
            "history_retained_alerts": retained_alerts,
            "history_dropped_alerts": dropped_alerts,
            **self._last_sink_diagnostics,
            "session_id": session.session_id if session is not None else None,
            "session_status": session.status if session is not None else None,
            **self._collector_health.to_dict(),
        }

    def _append_to_telemetry_sink(self, record: Dict[str, Any]) -> None:
        """Append one record to the sink; disable the sink on failure."""
        with self._sink_lock:
            # Capture once: another thread may disable/close the sink.
            sink = self._telemetry_sink
            if sink is None:
                return
            try:
                sink.append(record)
                self._last_sink_diagnostics = sink.get_diagnostics()
            except Exception as exc:
                self._disable_telemetry_sink("append", exc)

    def _flush_telemetry_sink(self, *, force: bool = False) -> None:
        """Flush the sink; disable the sink on failure."""
        with self._sink_lock:
            sink = self._telemetry_sink
            if sink is None:
                return
            try:
                sink.flush(force=force)
                self._last_sink_diagnostics = sink.get_diagnostics()
            except Exception as exc:
                self._disable_telemetry_sink("flush", exc)

    def _close_telemetry_sink(self) -> None:
        """Close the sink with a completed status and drop the reference."""
        with self._sink_lock:
            sink = self._telemetry_sink
            if sink is None:
                return
            try:
                self._close_sink_with_status(sink, SESSION_STATUS_COMPLETED)
                self._last_sink_diagnostics = sink.get_diagnostics()
            except Exception as exc:
                self._disable_telemetry_sink("close", exc)
            else:
                self._telemetry_sink = None

    def _disable_telemetry_sink(self, operation: str, exc: Exception) -> None:
        """Drop a failed sink, mark the session incomplete, and close it best-effort."""
        with self._sink_lock:
            sink = self._telemetry_sink
            if sink is None:
                return
            self._telemetry_sink = None
            logging.warning(
                "Disabling TensorFlow telemetry sink after %s failure: %s",
                operation,
                exc,
            )
            if self._session_summary is not None:
                self._session_summary = update_session_summary(
                    self._session_summary,
                    status=SESSION_STATUS_INCOMPLETE,
                    ended_at_ns=now_ns(),
                )
            try:
                self._close_sink_with_status(sink, SESSION_STATUS_INCOMPLETE)
                if hasattr(sink, "get_diagnostics"):
                    self._last_sink_diagnostics = sink.get_diagnostics()
            except Exception as close_exc:
                logging.debug(
                    "TensorFlow telemetry sink close failed after %s error: %s",
                    operation,
                    close_exc,
                )

    @staticmethod
    def _close_sink_with_status(sink: Any, status: str) -> None:
        """Close ``sink`` passing ``session_status`` if its ``close`` accepts it."""
        try:
            sink.close(session_status=status)
        except TypeError:
            sink.close()

    def enter_phase(
        self, name: str, *, metadata: Optional[Dict[str, Any]] = None
    ) -> PhaseHandle:
        """Enter one structured TensorFlow tracking phase.

        Emits the phase-enter boundary event and returns a PhaseHandle whose
        close callback emits the matching phase-exit event.

        Raises:
            RuntimeError: If tracking is not active.
        """
        if not self.tracking:
            raise RuntimeError("Tracking must be active before entering a phase.")
        session = self._ensure_session_summary()
        token, boundary = self._phase_state.enter(
            session_id=session.session_id,
            rank=self.distributed_identity["rank"],
            name=name,
            attrs=metadata,
        )
        self._append_event(
            timestamp=time.time(),
            memory_mb=self._status_memory_value(),
            event_type=boundary.event_type,
            context=boundary.context,
            metadata=boundary.metadata,
        )
        return PhaseHandle(
            scope_id=boundary.scope_id,
            name=name,
            path=boundary.path,
            close_callback=lambda: self._emit_phase_exit(token),
        )

    @contextmanager
    def phase(self, name: str, *, metadata: Optional[Dict[str, Any]] = None) -> Any:
        """Context manager that emits structured TensorFlow phase telemetry."""
        handle = self.enter_phase(name, metadata=metadata)
        try:
            yield handle
        finally:
            handle.close()

    def _emit_phase_exit(self, token: PhaseToken) -> None:
        """Emit the phase-exit boundary event for ``token``."""
        boundary = self._phase_state.exit(token)
        self._append_event(
            timestamp=time.time(),
            memory_mb=self._status_memory_value(),
            event_type=boundary.event_type,
            context=boundary.context,
            metadata=boundary.metadata,
        )

    def set_alert_threshold(self, threshold_mb: float) -> None:
        """Update alert threshold."""
        self.alert_threshold_mb = threshold_mb
        if self.enable_logging:
            logging.info("Updated alert threshold to %s MB", threshold_mb)

    def check_alerts(self, window_s: float = 10.0) -> bool:
        """Check if any retained alert was triggered recently.

        Args:
            window_s: Look-back window in seconds; an alert counts as recent
                when it fired less than ``window_s`` seconds ago.
        """
        now = time.time()
        with self._lock:
            return any(now - alert["timestamp"] < window_s for alert in self._alerts)

    def get_tracking_results(self) -> TrackingResult:
        """Get current tracking results without stopping."""
        return self._create_tracking_result()
class MemoryWatchdog:
    """Automatic memory management and cleanup for TensorFlow.

    A daemon thread polls GPU memory (device ``/GPU:0``) every
    ``check_interval`` seconds. It runs a cleanup pass whenever usage exceeds
    ``cleanup_threshold_mb``, or ``max_memory_mb`` (logged as a warning). A
    pass clears the Keras session, runs garbage collection, and then calls
    the registered cleanup callbacks.
    """

    def __init__(
        self,
        max_memory_mb: float = 8000,
        cleanup_threshold_mb: float = 6000,
        check_interval: float = 5.0,
    ):
        """
        Initialize memory watchdog.

        Args:
            max_memory_mb: Maximum memory before forced cleanup
            cleanup_threshold_mb: Memory threshold to trigger cleanup
            check_interval: Time between memory checks in seconds (> 0)

        Raises:
            ImportError: If TensorFlow is not installed.
            ValueError: If ``check_interval <= 0`` (the loop would spin).
        """
        if not TF_AVAILABLE:
            raise ImportError("TensorFlow not available.")
        if check_interval <= 0:
            raise ValueError("check_interval must be > 0")

        self.max_memory_mb = max_memory_mb
        self.cleanup_threshold_mb = cleanup_threshold_mb
        self.check_interval = check_interval

        self.active = False
        self.watchdog_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()

        # Cleanup callbacks
        self.cleanup_callbacks: List[Callable[[], None]] = []

        logging.info("Memory Watchdog initialized with %s MB limit", max_memory_mb)

    def add_cleanup_callback(self, callback: Callable[[], None]) -> None:
        """Add cleanup callback function."""
        self.cleanup_callbacks.append(callback)

    def _get_memory_usage(self) -> float:
        """Get current ``/GPU:0`` memory usage in MB (0.0 if unavailable)."""
        try:
            gpus = tf.config.list_physical_devices("GPU")
            if gpus:
                memory_info = tf.config.experimental.get_memory_info("/GPU:0")
                current_bytes = memory_info.get("current", 0)
                if isinstance(current_bytes, (int, float)):
                    return float(current_bytes) / (1024 * 1024)
                return 0.0
            return 0.0
        except Exception as exc:
            logging.debug("Watchdog could not get GPU memory usage: %s", exc)
            return 0.0

    def _cleanup_memory(self) -> None:
        """Perform one best-effort cleanup pass.

        Each step is isolated so that a failure clearing the TensorFlow
        session does not prevent garbage collection or the custom cleanup
        callbacks from running.
        """
        import gc

        try:
            tf.keras.backend.clear_session()
        except Exception as e:
            logging.error("Error during memory cleanup: %s", e)

        gc.collect()

        for callback in self.cleanup_callbacks:
            try:
                callback()
            except Exception as e:
                logging.error("Error in cleanup callback: %s", e)

        logging.info("Performed memory cleanup")

    def _watchdog_loop(self) -> None:
        """Main watchdog loop; runs until ``_stop_event`` is set."""
        while not self._stop_event.is_set():
            try:
                current_memory = self._get_memory_usage()
                if current_memory > self.max_memory_mb:
                    logging.warning(
                        "Memory usage %.1f MB exceeds limit %s MB - forcing cleanup",
                        current_memory,
                        self.max_memory_mb,
                    )
                    self._cleanup_memory()
                elif current_memory > self.cleanup_threshold_mb:
                    logging.info(
                        "Memory usage %.1f MB above threshold %s MB - performing cleanup",
                        current_memory,
                        self.cleanup_threshold_mb,
                    )
                    self._cleanup_memory()
            except Exception as e:
                # Keep watching: exiting here would silently end monitoring
                # while `active` still reports True, so start() could not
                # restart it either.
                logging.error("Error in watchdog loop: %s", e)
            # Wait for next check (returns early when stop() sets the event).
            self._stop_event.wait(self.check_interval)

    def start(self) -> None:
        """Start memory watchdog."""
        if self.active:
            logging.warning("Watchdog already active")
            return

        self.active = True
        self._stop_event.clear()
        self.watchdog_thread = threading.Thread(target=self._watchdog_loop, daemon=True)
        self.watchdog_thread.start()

        logging.info("Started memory watchdog")

    def stop(self) -> None:
        """Stop memory watchdog (waits up to 5 s for the thread to exit)."""
        if not self.active:
            return

        self.active = False
        self._stop_event.set()

        if self.watchdog_thread:
            self.watchdog_thread.join(timeout=5.0)

        logging.info("Stopped memory watchdog")

    def force_cleanup(self) -> None:
        """Force immediate memory cleanup."""
        self._cleanup_memory()