# Source code for stormlog.telemetry_model

"""Backend-neutral projection over the persisted telemetry event schema."""

from __future__ import annotations

import hashlib
import json
from collections.abc import Mapping as MappingABC
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Any, Iterable, Literal, Mapping, Optional, cast

# Bump this only with the docs, dataclass annotation, serialization tests, and
# compatibility behavior for the projected envelope.
TELEMETRY_PROJECTION_SCHEMA_VERSION: Literal[1] = 1

# Backend names accepted (after strip + lower-casing) from metadata["backend"].
_SUPPORTED_METADATA_SOURCE_KINDS = frozenset(
    ("cpu", "cuda", "rocm", "mps", "tensorflow")
)

# Backend names searched for as substrings of the collector name. The search
# returns on the first hit, so the order of this tuple is significant.
_INFERRED_COLLECTOR_SOURCE_KINDS = ("cuda", "rocm", "mps", "cpu", "tensorflow")

# Fallback severity keyed by lower-cased, stripped event_type; used when the
# metadata carries no explicit severity.
_SEVERITY_BY_EVENT_TYPE = {
    # Level-named event types map onto themselves.
    **{level: level for level in ("critical", "error", "warning")},
    "collector_degraded": "warning",
    # Lifecycle and sampling events are informational.
    **dict.fromkeys(
        (
            "collector_recovered",
            "start",
            "stop",
            "phase_enter",
            "phase_exit",
            "sample",
        ),
        "info",
    ),
}


@dataclass(frozen=True)
class ProjectedTelemetryRecord:
    """Small immutable event envelope shared by live and loaded telemetry.

    Instances are frozen: assigning to an attribute raises
    ``dataclasses.FrozenInstanceError``. The ``resource``, ``attributes`` and
    ``correlation`` mappings take part in ``==`` but are excluded from
    ``__hash__``. Mapping objects are unhashable, so including them would make
    ``hash(record)`` raise ``TypeError``. When omitted, each of those mappings
    defaults to an empty, read-only ``MappingProxyType``. Mappings passed in
    explicitly are stored as given, so direct callers should supply read-only
    mappings to keep the envelope immutable.
    """

    # Version of the projected envelope layout.
    schema_version: Literal[1]
    # Deterministic identifier derived from the source record.
    record_id: str
    # Source event time and observation time, in nanoseconds.
    timestamp_ns: int
    observed_timestamp_ns: int
    session_id: str
    # Backend family the event came from (e.g. "cuda", "cpu", "gpu", "other").
    source_kind: str
    event_type: str
    stage: Optional[str]
    severity: Optional[str]
    severity_text: Optional[str]
    body: Optional[str]
    # Mapping payloads: hash=False keeps the record hashable; the defaults are
    # read-only so an omitted mapping cannot be mutated after construction.
    resource: Mapping[str, Any] = field(
        default_factory=lambda: MappingProxyType({}), hash=False
    )
    attributes: Mapping[str, Any] = field(
        default_factory=lambda: MappingProxyType({}), hash=False
    )
    correlation: Mapping[str, Any] = field(
        default_factory=lambda: MappingProxyType({}), hash=False
    )
def _freeze_value(value: Any) -> Any:
    """Return a read-only deep copy of *value*.

    Mappings become ``MappingProxyType`` views over fresh dicts whose keys are
    converted with ``str``; lists and tuples become tuples; every other value
    is returned unchanged.
    """
    if isinstance(value, MappingABC):
        return MappingProxyType(
            {str(key): _freeze_value(nested) for key, nested in value.items()}
        )
    if isinstance(value, (list, tuple)):
        return tuple(_freeze_value(item) for item in value)
    return value


def _freeze_mapping(payload: Mapping[str, Any]) -> Mapping[str, Any]:
    """Freeze a top-level mapping, recursively freezing its values."""
    return MappingProxyType(
        {str(key): _freeze_value(value) for key, value in payload.items()}
    )


def _thaw_value(value: Any) -> Any:
    """Convert frozen values back to plain containers.

    Mappings become new dicts (keys converted with ``str``), tuples become
    lists, and every other value is returned unchanged.
    """
    if isinstance(value, MappingABC):
        return {str(key): _thaw_value(nested) for key, nested in value.items()}
    if isinstance(value, tuple):
        return [_thaw_value(item) for item in value]
    return value


def _jsonable_mapping(payload: Mapping[str, Any]) -> dict[str, Any]:
    """Round-trip *payload* through JSON so it holds only JSON-native types.

    Values the ``json`` encoder cannot handle are converted with ``str``.
    """
    result = json.loads(json.dumps(_thaw_value(payload), sort_keys=True, default=str))
    return cast(dict[str, Any], result)


def _record_id(record: Mapping[str, Any]) -> str:
    """Derive a deterministic ``telemetry-<hex>`` id from the record content.

    The id depends only on the record's content, not on its key order.
    """
    digest = hashlib.sha256(
        json.dumps(_jsonable_mapping(record), sort_keys=True).encode("utf-8")
    ).hexdigest()
    # The 32-character prefix keeps a compact 128-bit identity while preserving
    # deterministic grouping for the projected envelope.
    return f"telemetry-{digest[:32]}"


def _source_kind(record: Mapping[str, Any], metadata: Mapping[str, Any]) -> str:
    """Resolve the backend family for *record*.

    Precedence:

    1. A supported ``metadata["backend"]`` value (stripped and lower-cased).
    2. The first known backend name found as a substring of
       ``record["collector"]``.
    3. ``"gpu"`` for a non-negative integer ``device_id``.
    4. Otherwise ``"other"``.
    """
    backend = metadata.get("backend")
    if isinstance(backend, str) and backend.strip():
        normalized_backend = backend.strip().lower()
        if normalized_backend in _SUPPORTED_METADATA_SOURCE_KINDS:
            return normalized_backend
    collector = record.get("collector")
    if isinstance(collector, str):
        lowered = collector.lower()
        for candidate in _INFERRED_COLLECTOR_SOURCE_KINDS:
            if candidate in lowered:
                return candidate
    device_id = record.get("device_id")
    # bool is an int subclass; reject it (as the timestamp checks do) so a
    # stray True/False is not mistaken for a device index.
    if (
        isinstance(device_id, int)
        and not isinstance(device_id, bool)
        and device_id >= 0
    ):
        return "gpu"
    return "other"


def _phase_stage(metadata: Mapping[str, Any]) -> Optional[str]:
    """Return the active phase name from ``metadata["phase_scope"]``, if any.

    The scope's non-blank ``name`` wins; otherwise the last non-blank string
    element of its ``path`` (list or tuple) is used.
    """
    phase_scope = metadata.get("phase_scope")
    if not isinstance(phase_scope, MappingABC):
        return None
    name = phase_scope.get("name")
    if isinstance(name, str) and name.strip():
        return name
    path = phase_scope.get("path")
    # Accept tuples too: _freeze_value turns sequences into tuples, so
    # already-frozen metadata must resolve the same way as loaded JSON lists.
    if isinstance(path, (list, tuple)) and path:
        tail = path[-1]
        if isinstance(tail, str) and tail.strip():
            return tail
    return None


def _severity(record: Mapping[str, Any], metadata: Mapping[str, Any]) -> Optional[str]:
    """Return the explicit metadata severity, else one derived from event_type."""
    explicit = metadata.get("severity")
    if isinstance(explicit, str) and explicit.strip():
        return explicit.strip().lower()
    event_type = record.get("event_type")
    if isinstance(event_type, str):
        return _SEVERITY_BY_EVENT_TYPE.get(event_type.strip().lower())
    return None


def _resource(record: Mapping[str, Any], source_kind: str) -> dict[str, Any]:
    """Collect the emitter identity (process, host, device, rank) fields."""
    return {
        "source_kind": source_kind,
        "collector": record.get("collector"),
        "host": record.get("host"),
        "pid": record.get("pid"),
        "device_id": record.get("device_id"),
        "job_id": record.get("job_id"),
        "rank": record.get("rank"),
        "local_rank": record.get("local_rank"),
        "world_size": record.get("world_size"),
    }


def _attributes(
    record: Mapping[str, Any],
    metadata: Mapping[str, Any],
) -> dict[str, Any]:
    """Merge metadata with the dotted sampling and memory attribute keys.

    The record-derived keys are applied last, so they override any metadata
    entries with the same name. Missing record fields appear as ``None``.
    """
    attributes = dict(metadata)
    attributes.update(
        {
            "sampling.interval_ms": record.get("sampling_interval_ms"),
            "memory.allocator.allocated_bytes": record.get("allocator_allocated_bytes"),
            "memory.allocator.reserved_bytes": record.get("allocator_reserved_bytes"),
            "memory.allocator.active_bytes": record.get("allocator_active_bytes"),
            "memory.allocator.inactive_bytes": record.get("allocator_inactive_bytes"),
            "memory.allocator.change_bytes": record.get("allocator_change_bytes"),
            "memory.device.used_bytes": record.get("device_used_bytes"),
            "memory.device.free_bytes": record.get("device_free_bytes"),
            "memory.device.total_bytes": record.get("device_total_bytes"),
        }
    )
    return attributes


def _correlation(
    record: Mapping[str, Any], metadata: Mapping[str, Any]
) -> dict[str, Any]:
    """Collect session/job/rank ids plus any non-None phase scope ids."""
    correlation = {
        "session_id": record.get("session_id"),
        "job_id": record.get("job_id"),
        "rank": record.get("rank"),
        "local_rank": record.get("local_rank"),
        "world_size": record.get("world_size"),
    }
    phase_scope = metadata.get("phase_scope")
    if isinstance(phase_scope, MappingABC):
        for key in ("scope_id", "parent_scope_id", "sequence"):
            value = phase_scope.get(key)
            if value is not None:
                correlation[f"phase.{key}"] = value
    return correlation
def project_telemetry_mapping(
    record: Mapping[str, Any],
    *,
    observed_timestamp_ns: int | None = None,
) -> ProjectedTelemetryRecord:
    """Project a normalized telemetry mapping into the projected envelope.

    Args:
        record: Existing normalized telemetry record, normally a
            `TelemetryEvent v3` dictionary. A ``metadata`` value of ``None``
            is treated the same as absent (empty) metadata.
        observed_timestamp_ns: Optional observation timestamp. Defaults to the
            source timestamp when no separate observation time is available.

    Returns:
        Backend-neutral projected telemetry record whose mapping fields are
        deeply frozen.

    Raises:
        ValueError: If ``session_id`` or ``event_type`` is missing or blank,
            ``timestamp_ns`` is not an integer, ``observed_timestamp_ns`` is
            given but is not a non-negative integer, or ``metadata`` is
            neither ``None`` nor a mapping.
    """
    session_id = record.get("session_id")
    if not isinstance(session_id, str) or not session_id.strip():
        raise ValueError("projected telemetry record requires session_id")

    timestamp_ns = record.get("timestamp_ns")
    # bool is an int subclass; reject it explicitly.
    if not isinstance(timestamp_ns, int) or isinstance(timestamp_ns, bool):
        raise ValueError("projected telemetry record requires integer timestamp_ns")

    if observed_timestamp_ns is not None:
        if not isinstance(observed_timestamp_ns, int) or isinstance(
            observed_timestamp_ns, bool
        ):
            raise ValueError(
                "projected telemetry record requires integer observed_timestamp_ns"
            )
        if observed_timestamp_ns < 0:
            raise ValueError(
                "projected telemetry record requires non-negative observed_timestamp_ns"
            )

    event_type = record.get("event_type")
    if not isinstance(event_type, str) or not event_type.strip():
        raise ValueError("projected telemetry record requires event_type")

    metadata_value = record.get("metadata")
    # JSON-loaded events may carry `"metadata": null`; treat that like a
    # missing key instead of rejecting an otherwise valid record.
    if metadata_value is None:
        metadata_value = {}
    if not isinstance(metadata_value, MappingABC):
        raise ValueError("projected telemetry record metadata must be a mapping")
    # Shallow copy so helpers never see later mutations of the caller's dict.
    metadata = dict(metadata_value)

    resolved_source_kind = _source_kind(record, metadata)
    severity_text = metadata.get("severity_text")
    body = record.get("context")
    return ProjectedTelemetryRecord(
        schema_version=TELEMETRY_PROJECTION_SCHEMA_VERSION,
        record_id=_record_id(record),
        timestamp_ns=timestamp_ns,
        observed_timestamp_ns=(
            timestamp_ns if observed_timestamp_ns is None else observed_timestamp_ns
        ),
        session_id=session_id,
        source_kind=resolved_source_kind,
        event_type=event_type,
        stage=_phase_stage(metadata),
        severity=_severity(record, metadata),
        severity_text=str(severity_text) if severity_text is not None else None,
        body=str(body) if body is not None else None,
        resource=_freeze_mapping(_resource(record, resolved_source_kind)),
        attributes=_freeze_mapping(_attributes(record, metadata)),
        correlation=_freeze_mapping(_correlation(record, metadata)),
    )
def projected_record_to_dict(record: ProjectedTelemetryRecord) -> dict[str, Any]:
    """Serialize a projected telemetry record to a deterministic dictionary.

    Keys always appear in envelope order. Scalar fields are copied as-is; the
    ``resource``, ``attributes`` and ``correlation`` mappings are converted
    back into fresh plain dicts (with tuples turned into lists) via
    ``_thaw_value``, so callers may mutate the result freely.
    """
    scalar_fields = (
        "schema_version",
        "record_id",
        "timestamp_ns",
        "observed_timestamp_ns",
        "session_id",
        "source_kind",
        "event_type",
        "stage",
        "severity",
        "severity_text",
        "body",
    )
    payload: dict[str, Any] = {name: getattr(record, name) for name in scalar_fields}
    for mapping_field in ("resource", "attributes", "correlation"):
        payload[mapping_field] = _thaw_value(getattr(record, mapping_field))
    return payload
def _unique_thawed_mappings(
    mappings: Iterable[Mapping[str, Any]],
) -> list[dict[str, Any]]:
    """Thaw *mappings* into plain dicts and drop duplicates, keeping first-seen order.

    Two mappings are duplicates when their JSON-normalized, key-sorted
    encodings are equal (values JSON cannot encode are compared via ``str``).
    """
    seen: set[str] = set()
    unique: list[dict[str, Any]] = []
    for mapping in mappings:
        # _jsonable_mapping thaws internally, so the key can be computed
        # straight from the frozen mapping without a second thaw.
        key = json.dumps(_jsonable_mapping(mapping), sort_keys=True)
        if key in seen:
            continue
        seen.add(key)
        unique.append(cast(dict[str, Any], _thaw_value(mapping)))
    return unique


def unique_projected_resources(
    records: Iterable[ProjectedTelemetryRecord],
) -> list[dict[str, Any]]:
    """Return stable unique resource dictionaries for projected telemetry records.

    Order follows first occurrence in *records*; each entry is a fresh plain dict.
    """
    return _unique_thawed_mappings(record.resource for record in records)


def unique_projected_correlations(
    records: Iterable[ProjectedTelemetryRecord],
) -> list[dict[str, Any]]:
    """Return stable unique correlation dictionaries for projected telemetry records.

    Order follows first occurrence in *records*; each entry is a fresh plain dict.
    """
    return _unique_thawed_mappings(record.correlation for record in records)
# Public API of this module; everything else is an internal helper.
__all__ = [
    "TELEMETRY_PROJECTION_SCHEMA_VERSION",
    "ProjectedTelemetryRecord",
    "project_telemetry_mapping",
    "projected_record_to_dict",
    "unique_projected_correlations",
    "unique_projected_resources",
]