Source code for stormlog._wandb.tracking

"""Tracking-session W&B export helpers."""

from __future__ import annotations

import math
from pathlib import Path
from typing import Any, Mapping, Sequence

from ..session import SessionSummary
from .attribution import log_attribution_outputs
from .core import (
    WandbExportConfig,
    coerce_existing_dir,
    coerce_existing_file,
    log_directory_artifact,
    log_file_artifact,
    materialize_html_file,
    resolve_run,
    session_slug,
    session_summary_fields,
    update_summary,
)
from .dashboard import tracking_dashboard_html

# Event types treated as alerts in this module: they populate the alerts table,
# are never dropped when the timeline is down-sampled, and are forwarded to the
# dashboard renderer.
_ALERT_EVENT_TYPES = frozenset({"warning", "critical", "error", "peak"})
# Upper bound on the number of timeline rows kept by ``sample_timeline_rows``.
_TIMELINE_MAX_POINTS = 250


def export_tracking_run_to_wandb(
    config: WandbExportConfig,
    *,
    command_name: str,
    session_summary: SessionSummary | None,
    stats: Mapping[str, Any],
    events: Sequence[Any],
    output_path: str | Path | None = None,
    telemetry_sink_dir: str | Path | None = None,
    oom_dump_path: str | Path | None = None,
    attribution_bundle_dir: str | Path | None = None,
) -> None:
    """Export one completed tracking session to W&B.

    Does nothing when ``config.enabled`` is false. Otherwise a run is obtained
    from ``resolve_run`` and, in order: summary fields are written, the sampled
    timeline is logged step by step, the alerts table is logged (when
    ``config.log_tables``), the timeline table/plots/dashboard are logged, the
    on-disk outputs are uploaded as artifacts (when ``config.log_artifacts``),
    and attribution outputs are logged (when ``config.log_attribution``).

    Args:
        config: Export switches (``enabled``, ``log_tables``,
            ``log_artifacts``, ``log_attribution``) plus whatever
            ``resolve_run`` needs.
        command_name: Forwarded to ``resolve_run``.
        session_summary: Session metadata, or ``None``; used for the run, the
            artifact-name slug and the summary fields.
        stats: Tracker statistics; see ``tracking_metrics`` and
            ``tracking_summary_fields`` for the keys that are read.
        events: Tracking events (mappings or attribute-style objects).
        output_path: Optional tracker output file.
        telemetry_sink_dir: Optional telemetry sink directory.
        oom_dump_path: Optional OOM dump directory; also used as the
            attribution root when ``attribution_bundle_dir`` does not resolve.
        attribution_bundle_dir: Optional attribution bundle directory.

    Any exception raised while exporting propagates to the caller; a run
    flagged as managed by ``resolve_run`` is finished first.
    """
    if not config.enabled:
        return

    # Sampling happens before a run exists so a failure here cannot leak one.
    timeline_rows = tracking_timeline_rows(events)
    wandb, run, managed = resolve_run(
        config,
        command_name=command_name,
        session_summary=session_summary,
    )
    try:
        # Everything after resolve_run() lives inside the try block so that a
        # managed run is always finished, even if path coercion or slug
        # generation raises.
        safe_session = session_slug(session_summary)
        output_file = coerce_existing_file(output_path)
        sink_dir = coerce_existing_dir(telemetry_sink_dir)
        oom_dir = coerce_existing_dir(oom_dump_path)
        # Fall back to the OOM dump directory as the attribution root.
        attribution_dir = coerce_existing_dir(attribution_bundle_dir) or oom_dir

        update_summary(
            run,
            tracking_metrics(stats)
            | {"stormlog_chart_point_count": len(timeline_rows)}
            | session_summary_fields(session_summary)
            | tracking_summary_fields(stats, output_path=output_path),
        )
        log_tracking_time_series(run, timeline_rows)
        if config.log_tables:
            log_alerts_table(wandb, run, events)
        # NOTE(review): the visualizations are assumed to be logged regardless
        # of ``config.log_tables`` -- confirm against the intended nesting.
        update_summary(
            run,
            log_tracking_visualizations(
                wandb,
                run,
                timeline_rows,
                session_slug=safe_session,
                dashboard_root=_tracking_dashboard_root(output_file, sink_dir),
                allow_artifact_logging=config.log_artifacts,
            ),
        )

        if config.log_artifacts:
            if output_file is not None:
                log_file_artifact(
                    wandb,
                    run,
                    artifact_name=f"stormlog-track-output-{safe_session}",
                    artifact_type="stormlog-track-output",
                    path=output_file,
                )
            if sink_dir is not None:
                log_directory_artifact(
                    wandb,
                    run,
                    artifact_name=f"stormlog-telemetry-sink-{safe_session}",
                    artifact_type="stormlog-telemetry-sink",
                    path=sink_dir,
                )
            if oom_dir is not None:
                log_directory_artifact(
                    wandb,
                    run,
                    artifact_name=f"stormlog-oom-dump-{safe_session}",
                    artifact_type="stormlog-oom-dump",
                    path=oom_dir,
                )
            # Upload the attribution bundle separately only when it was given
            # explicitly and is not the directory already uploaded as the OOM
            # dump.
            if (
                attribution_bundle_dir is not None
                and attribution_dir is not None
                and attribution_dir != oom_dir
            ):
                log_directory_artifact(
                    wandb,
                    run,
                    artifact_name=f"stormlog-attribution-bundle-{safe_session}",
                    artifact_type="stormlog-attribution-bundle",
                    path=attribution_dir,
                )

        if config.log_attribution and attribution_dir is not None:
            update_summary(
                run,
                log_attribution_outputs(
                    wandb,
                    run,
                    root=attribution_dir,
                    session_slug=safe_session,
                    allow_artifact_logging=config.log_artifacts,
                ),
            )
    finally:
        if managed:
            run.finish()
def tracking_metrics(stats: Mapping[str, Any]) -> dict[str, Any]: metric_names = { "stormlog_peak_memory_bytes": "peak_memory", "stormlog_total_events": "total_events", "stormlog_alert_count": "alert_count", "stormlog_current_memory_allocated_bytes": "current_memory_allocated", "stormlog_current_memory_reserved_bytes": "current_memory_reserved", "stormlog_memory_utilization_percent": "memory_utilization_percent", "stormlog_total_allocations": "total_allocations", "stormlog_total_deallocations": "total_deallocations", "stormlog_total_allocation_bytes": "total_allocation_bytes", "stormlog_total_deallocation_bytes": "total_deallocation_bytes", "stormlog_tracking_duration_seconds": "tracking_duration_seconds", "stormlog_allocations_per_second": "allocations_per_second", "stormlog_bytes_allocated_per_second": "bytes_allocated_per_second", "stormlog_history_retained_events": "history_retained_events", "stormlog_history_dropped_events": "history_dropped_events", "stormlog_sink_rollover_count": "rollover_count", "stormlog_sink_pruned_segment_count": "pruned_segment_count", "stormlog_sink_pruned_bytes": "pruned_bytes", "stormlog_sink_retained_files": "final_retained_files", "stormlog_sink_retained_bytes": "final_retained_bytes", } metrics: dict[str, Any] = {} for wandb_key, stats_key in metric_names.items(): value = stats.get(stats_key) if isinstance(value, (int, float, bool)) and not isinstance(value, complex): metrics[wandb_key] = value return metrics def tracking_summary_fields( stats: Mapping[str, Any], *, output_path: str | Path | None, ) -> dict[str, Any]: fields: dict[str, Any] = {} for source_key, target_key in ( ("backend", "stormlog_backend"), ("collector_health_status", "stormlog_collector_health_status"), ("collector_last_error", "stormlog_collector_last_error"), ("session_status", "stormlog_session_status"), ): value = stats.get(source_key) if value is not None: fields[target_key] = value output_file = coerce_existing_file(output_path) if output_file is not 
None: fields["stormlog_output_file"] = output_file.name return fields def log_alerts_table(wandb: Any, run: Any, events: Sequence[Any]) -> None: rows: list[list[Any]] = [] for event in events: event_type = event_value(event, "event_type") or event_value(event, "type") if event_type not in _ALERT_EVENT_TYPES: continue rows.append( [ event_timestamp_seconds(event), event_type, event_value(event, "context"), event_int_value(event, "memory_allocated", "allocator_allocated_bytes"), event_int_value(event, "memory_reserved", "allocator_reserved_bytes"), event_int_value(event, "memory_change", "allocator_change_bytes"), event_value(event, "job_id"), event_value(event, "rank"), ] ) if not rows: return run.log( { "stormlog_alerts": wandb.Table( columns=[ "timestamp_s", "event_type", "context", "memory_allocated_bytes", "memory_reserved_bytes", "memory_change_bytes", "job_id", "rank", ], data=rows[-250:], ) } ) def log_tracking_time_series(run: Any, rows: Sequence[Mapping[str, Any]]) -> None: for row in rows: payload = { "stormlog_timeline_elapsed_seconds": row["elapsed_seconds"], "stormlog_timeline_allocated_bytes": row["allocated_bytes"], "stormlog_timeline_reserved_bytes": row["reserved_bytes"], "stormlog_timeline_change_bytes": row["change_bytes"], "stormlog_timeline_device_used_bytes": row["device_used_bytes"], "stormlog_timeline_utilization_percent": row["utilization_percent"], } filtered_payload = { key: value for key, value in payload.items() if value is not None } if filtered_payload: run.log(filtered_payload) def log_tracking_visualizations( wandb: Any, run: Any, rows: Sequence[Mapping[str, Any]], *, session_slug: str, dashboard_root: Path | None, allow_artifact_logging: bool, ) -> dict[str, Any]: if not rows: return {} run.log( { "stormlog_memory_timeline_table": wandb.Table( columns=[ "sample_index", "elapsed_seconds", "event_type", "memory_allocated_bytes", "memory_reserved_bytes", "memory_change_bytes", "device_used_bytes", "utilization_percent", "context", 
"rank", ], data=[ [ row["sample_index"], row["elapsed_seconds"], row["event_type"], row["allocated_bytes"], row["reserved_bytes"], row["change_bytes"], row["device_used_bytes"], row["utilization_percent"], row["context"], row["rank"], ] for row in rows ], ) } ) plot_api = getattr(wandb, "plot", None) line_series = getattr(plot_api, "line_series", None) if callable(line_series): elapsed = [float(row["elapsed_seconds"]) for row in rows] keys, ys = _memory_plot_series(rows) if keys: run.log( { "stormlog_memory_timeline_plot": line_series( xs=elapsed, ys=ys, keys=keys, title="Stormlog Memory Timeline", xname="Elapsed Seconds", ) } ) utilization_series = _series_for_plot(rows, "utilization_percent") if any(not math.isnan(value) for value in utilization_series): run.log( { "stormlog_memory_utilization_plot": line_series( xs=elapsed, ys=[utilization_series], keys=["utilization_percent"], title="Stormlog Memory Utilization", xname="Elapsed Seconds", ) } ) dashboard_html = tracking_dashboard_html(rows, alert_event_types=_ALERT_EVENT_TYPES) run.log({"stormlog_tracking_dashboard": wandb.Html(dashboard_html)}) if not allow_artifact_logging: return {} dashboard_path = materialize_html_file( html_text=dashboard_html, file_name="stormlog_tracking_dashboard.html", output_root=dashboard_root, ) log_file_artifact( wandb, run, artifact_name=f"stormlog-tracking-dashboard-{session_slug}", artifact_type="stormlog-tracking-dashboard", path=dashboard_path, ) return {"stormlog_tracking_dashboard_file": dashboard_path.name} def tracking_timeline_rows(events: Sequence[Any]) -> list[dict[str, Any]]: timeline_rows: list[dict[str, Any]] = [] first_timestamp: float | None = None for event in events: timestamp_s = event_timestamp_seconds(event) if timestamp_s is None: continue if first_timestamp is None: first_timestamp = timestamp_s allocated = event_int_value( event, "memory_allocated", "allocator_allocated_bytes" ) reserved = event_int_value(event, "memory_reserved", 
"allocator_reserved_bytes") change = event_int_value(event, "memory_change", "allocator_change_bytes") device_used = event_int_value(event, "device_used", "device_used_bytes") device_total = event_int_value(event, "device_total", "device_total_bytes") if device_used is None: candidates = [value for value in (allocated, reserved) if value is not None] device_used = max(candidates) if candidates else None utilization_percent: float | None = None if ( isinstance(device_used, int) and isinstance(device_total, int) and device_total > 0 ): utilization_percent = (float(device_used) / float(device_total)) * 100.0 timeline_rows.append( { "sample_index": len(timeline_rows), "elapsed_seconds": timestamp_s - first_timestamp, "event_type": str( event_value(event, "event_type") or event_value(event, "type") or "sample" ), "allocated_bytes": allocated, "reserved_bytes": reserved, "change_bytes": change, "device_used_bytes": device_used, "utilization_percent": utilization_percent, "context": event_value(event, "context"), "rank": event_value(event, "rank"), } ) return sample_timeline_rows(timeline_rows) def sample_timeline_rows(rows: Sequence[dict[str, Any]]) -> list[dict[str, Any]]: if len(rows) <= _TIMELINE_MAX_POINTS: return list(rows) pinned_by_index: dict[int, dict[str, Any]] = {} for row in rows: sample_index = row.get("sample_index") if ( isinstance(sample_index, int) and row.get("event_type") in _ALERT_EVENT_TYPES ): pinned_by_index[sample_index] = row last_row = rows[-1] last_index = last_row.get("sample_index") if isinstance(last_index, int): pinned_by_index[last_index] = last_row pinned_rows = sorted( pinned_by_index.values(), key=lambda row: int(row["sample_index"]), ) remaining_budget = _TIMELINE_MAX_POINTS - len(pinned_rows) if remaining_budget <= 0: return pinned_rows[-_TIMELINE_MAX_POINTS:] pinned_indices = set(pinned_by_index) unpinned_rows = [ row for row in rows if row.get("sample_index") not in pinned_indices ] stride = max(1, math.ceil(len(unpinned_rows) / 
remaining_budget)) sampled_rows = unpinned_rows[::stride][:remaining_budget] return sorted( [*pinned_rows, *sampled_rows], key=lambda row: int(row["sample_index"]), )[:_TIMELINE_MAX_POINTS] def _memory_plot_series( rows: Sequence[Mapping[str, Any]], ) -> tuple[list[str], list[list[float]]]: keys: list[str] = [] series: list[list[float]] = [] for row_key, label in ( ("allocated_bytes", "allocated_bytes"), ("reserved_bytes", "reserved_bytes"), ("device_used_bytes", "device_used_bytes"), ): values = _series_for_plot(rows, row_key) if any(not math.isnan(value) for value in values): keys.append(label) series.append(values) return keys, series def _series_for_plot(rows: Sequence[Mapping[str, Any]], key: str) -> list[float]: values: list[float] = [] for row in rows: value = row.get(key) if isinstance(value, (int, float)) and not isinstance(value, bool): values.append(float(value)) else: values.append(math.nan) return values def event_value(event: Any, name: str) -> Any: if isinstance(event, Mapping): return event.get(name) return getattr(event, name, None) def event_int_value(event: Any, *names: str) -> int | None: for name in names: value = event_value(event, name) if isinstance(value, int) and not isinstance(value, bool): return int(value) return None def event_timestamp_seconds(event: Any) -> float | None: value = event_value(event, "timestamp") if isinstance(value, (int, float)) and not isinstance(value, bool): return float(value) value_ns = event_value(event, "timestamp_ns") if isinstance(value_ns, int) and not isinstance(value_ns, bool): return float(value_ns) / 1_000_000_000.0 return None def _tracking_dashboard_root( output_file: Path | None, sink_dir: Path | None, ) -> Path | None: if output_file is not None: return output_file.parent if sink_dir is not None: return sink_dir return None