"""JAX Stormlog CLI"""
import argparse
import json
import logging
import sys
import time
from pathlib import Path
from typing import Any, Dict
from stormlog.telemetry import telemetry_event_from_record, telemetry_event_to_dict
from stormlog.telemetry_sink import TelemetrySinkConfig
# Weights & Biases support is optional: use the real integration when it
# imports cleanly, otherwise install inert stand-ins so the CLI still loads.
try:
    from stormlog.wandb_integration import (
        add_wandb_arguments,
        ensure_wandb_available,
        export_diagnose_bundle_to_wandb,
        export_tracking_run_to_wandb,
        wandb_config_from_namespace,
    )
except ImportError:
    WANDB_AVAILABLE = False

    def add_wandb_arguments(parser: Any) -> None:
        """Fallback that registers no W&B options on ``parser``."""
        return None

    def wandb_config_from_namespace(args: Any) -> Any:  # type: ignore[misc]
        """Fallback that always returns a config whose ``enabled`` is False."""

        class DummyConfig:
            enabled = False

        disabled_config = DummyConfig()
        return disabled_config

else:
    WANDB_AVAILABLE = True
from .jax_env import configure_jax_logging
from .utils import format_memory, get_system_info
configure_jax_logging()
jax: Any
try:
import jax as _jax
jax = _jax
JAX_AVAILABLE = True
except ImportError:
JAX_AVAILABLE = False
jax = None
if JAX_AVAILABLE:
from .diagnose import run_diagnose
from .tracker import MemoryTracker
[docs]
def setup_logging(verbose: bool = False) -> None:
    """Configure root logging for the CLI.

    Args:
        verbose: Use the DEBUG level when true, INFO otherwise.
    """
    if verbose:
        threshold = logging.DEBUG
    else:
        threshold = logging.INFO

    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=threshold,
    )
def _normalize_telemetry_events(
    records: list[dict[str, Any]], sampling_interval_ms: int
) -> list[dict[str, Any]]:
    """Convert raw tracker records into telemetry event dictionaries.

    Every record is passed through ``telemetry_event_from_record`` (with this
    CLI's default collector name and the given sampling interval as defaults)
    and the resulting event is converted with ``telemetry_event_to_dict``.

    Args:
        records: Raw event records, in the order they should appear.
        sampling_interval_ms: Default sampling interval handed to the converter.

    Returns:
        One converted dictionary per input record, in input order.
    """
    return [
        telemetry_event_to_dict(
            telemetry_event_from_record(
                raw_record,
                default_collector="stormlog.jax.memory_tracker",
                default_sampling_interval_ms=sampling_interval_ms,
            )
        )
        for raw_record in records
    ]
def _build_telemetry_sink_config(
    args: argparse.Namespace,
) -> TelemetrySinkConfig | None:
    """Build the telemetry sink configuration from parsed CLI arguments.

    Args:
        args: Parsed arguments; every telemetry option is read with a default
            so namespaces lacking them are accepted.

    Returns:
        ``None`` when no sink directory was given, otherwise a
        ``TelemetrySinkConfig`` with the megabyte options converted to bytes.
    """
    sink_dir = getattr(args, "telemetry_sink_dir", None)
    if not sink_dir:
        return None

    bytes_per_mb = 1024 * 1024
    flush_seconds = float(getattr(args, "telemetry_flush_seconds", 2.0))
    rollover_bytes = int(getattr(args, "telemetry_rollover_mb", 64)) * bytes_per_mb
    retained_files = int(getattr(args, "telemetry_retention_files", 8))
    retained_bytes = (
        int(getattr(args, "telemetry_retention_total_mb", 512)) * bytes_per_mb
    )

    return TelemetrySinkConfig(
        root_dir=Path(sink_dir),
        flush_every_seconds=flush_seconds,
        rollover_max_bytes=rollover_bytes,
        retention_max_files=retained_files,
        retention_max_total_bytes=retained_bytes,
    )
def _resolve_wandb_config(args: argparse.Namespace) -> Any:
    """Derive the W&B configuration for a command and check it is usable.

    Args:
        args: Parsed CLI arguments handed to ``wandb_config_from_namespace``.

    Returns:
        The configuration object, or ``None`` after printing the error to
        stderr when export is enabled but ``ensure_wandb_available`` raises
        ``ImportError``.
    """
    config = wandb_config_from_namespace(args)
    if config.enabled:
        # Only an enabled export needs the availability check.
        try:
            ensure_wandb_available(config)
        except ImportError as exc:
            print(f"Error: {exc}", file=sys.stderr)
            return None
    return config
def _warn_wandb_export_failure(command_name: str, exc: Exception) -> None:
    """Write a warning to stderr saying a command's W&B export was skipped.

    Args:
        command_name: Command label placed at the start of the warning.
        exc: The exception that made the export fail; its text is appended.
    """
    warning = f"Warning: {command_name} W&B export skipped: {exc}"
    print(warning, file=sys.stderr)
[docs]
def cmd_info(args: argparse.Namespace) -> int:
    """Display system and device information.

    Prints platform details from ``get_system_info()``, then the JAX backend
    summary, then one section per device reported by the backend.

    Args:
        args: Parsed CLI arguments; this command does not read any of them.

    Returns:
        Always 0.
    """
    print("JAX Stormlog - System Information")
    print("=" * 50)

    system_info = get_system_info()
    print(f"Platform: {system_info['platform']}")
    print(f"Python Version: {system_info['python_version']}")
    print(f"CPU Count: {system_info['cpu_count']}")
    if "total_memory_gb" in system_info:
        print(f"Total System Memory: {system_info['total_memory_gb']:.2f} GB")
    if "available_memory_gb" in system_info:
        print(f"Available Memory: {system_info['available_memory_gb']:.2f} GB")

    # ``dict.get(key, default)`` only applies the default when the key is
    # missing; ``or`` also covers a key that is present but None.
    backend_info = system_info.get("backend") or {}
    device_count = backend_info.get("device_count") or 0

    print("\nJAX Backend Information:")
    print("-" * 30)
    print(f"Runtime Backend: {backend_info.get('runtime_backend', 'Unknown')}")
    print(f"Is XLA GPU Build: {backend_info.get('is_gpu_build', False)}")
    print(f"Is Apple Silicon: {backend_info.get('is_apple_silicon', False)}")
    print(f"JAX Metal Installed: {backend_info.get('jax_metal_installed', False)}")
    print(f"Available Devices: {device_count}")

    if device_count > 0:
        print("\nDevice Information:")
        print("-" * 20)

        from .utils import get_device_info

        for i in range(device_count):
            device_info = get_device_info(i)
            print(f"\nDevice {i}:")
            print(f" Name: {device_info.get('kind', 'Unknown')}")

            # A device without memory statistics may carry None here rather
            # than omitting the key; treat both as "no stats".
            stats = device_info.get("memory_stats") or {}
            allocated_bytes = stats.get("bytes_in_use", 0)
            reserved_bytes = stats.get("bytes_reserved")
            peak_bytes = stats.get("peak_bytes_in_use", 0)
            limit_bytes = stats.get("bytes_limit", 0)

            print(f" Allocated Memory: {format_memory(allocated_bytes)}")
            # Reserved memory is only shown when the stats report it.
            if reserved_bytes is not None:
                print(f" Reserved Memory: {format_memory(reserved_bytes)}")
            print(f" Peak Memory: {format_memory(peak_bytes)}")
            if limit_bytes:
                print(f" Device Limit: {format_memory(limit_bytes)}")

    return 0
[docs]
def cmd_monitor(args: argparse.Namespace) -> int:
    """Monitor device memory usage in real-time.

    Starts a ``MemoryTracker``, refreshes a one-line reading about once per
    second until ``args.duration`` elapses or Ctrl+C is pressed, then prints a
    summary and optionally writes it to ``args.output`` as JSON.

    Args:
        args: Parsed CLI arguments. Reads ``interval``, ``duration``,
            ``threshold``, ``device`` and ``output``; ``max_history`` is
            optional and defaults to 10000.

    Returns:
        0 on completion, 1 when JAX is not available.
    """
    if not JAX_AVAILABLE:
        print("Error: JAX not available")
        return 1

    print("Starting JAX memory monitoring...")
    print(f"Sampling interval: {args.interval} seconds")
    print(
        f"Duration: {args.duration} seconds"
        if args.duration
        else "Duration: Indefinite"
    )
    if args.threshold:
        print(f"Alert threshold: {args.threshold} MB")
    print("Press Ctrl+C to stop\n")

    max_history = int(getattr(args, "max_history", 10000))
    tracker = MemoryTracker(
        sampling_interval=args.interval,
        alert_threshold_mb=args.threshold,
        device_index=args.device,
        enable_logging=True,
        max_history=max_history,
    )

    try:
        tracker.start_tracking()
        start_time = time.time()
        while True:
            refresh_delay = 1.0
            if args.duration:
                remaining = args.duration - (time.time() - start_time)
                if remaining <= 0:
                    break
                # Never sleep past the requested duration; a fixed one-second
                # sleep would overshoot by up to a second.
                refresh_delay = min(refresh_delay, remaining)

            current_memory = tracker.get_current_memory()
            print(
                f"\rCurrent memory usage: {current_memory:.1f} MB", end="", flush=True
            )
            time.sleep(refresh_delay)
    except KeyboardInterrupt:
        print("\n\nStopping monitoring...")
    finally:
        # The summary lives in ``finally`` so it is produced however the
        # monitoring loop ended.
        results = tracker.stop_tracking()

        print("\nMonitoring Results:")
        print("-" * 20)
        print(f"Peak Memory: {results.peak_memory_bytes / (1024 * 1024):.1f} MB")
        print(f"Average Memory: {results.average_memory_bytes / (1024 * 1024):.1f} MB")
        print(f"Duration: {results.duration:.1f} seconds")
        print(f"Samples Collected: {len(results.memory_usage)}")
        dropped_samples = getattr(results, "history_dropped_samples", 0)
        if dropped_samples:
            print(f"Dropped Samples: {dropped_samples}")
        if results.alert_count:
            print(f"Alerts Triggered: {results.alert_count}")

        if args.output:
            # Save results
            output_data = {
                "peak_memory": results.peak_memory_bytes / (1024 * 1024),
                "average_memory": results.average_memory_bytes / (1024 * 1024),
                "duration": results.duration,
                "memory_usage": results.memory_usage,
                "timestamps": results.timestamps,
                "alerts": results.alert_count,
                "history_window_limit": getattr(results, "history_window_limit", 0),
                "history_retained_samples": getattr(
                    results, "history_retained_samples", 0
                ),
                "history_dropped_samples": getattr(
                    results, "history_dropped_samples", 0
                ),
            }
            output_path = Path(args.output)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with output_path.open("w", encoding="utf-8") as f:
                json.dump(output_data, f, indent=2)
            print(f"Results saved to {args.output}")

    return 0
[docs]
def cmd_track(args: argparse.Namespace) -> int:
    """Start background memory tracking.

    Builds a ``MemoryTracker`` from the CLI options, prints a status line every
    five seconds until interrupted with Ctrl+C, then stops the tracker, writes
    the JSON results, prints the closing summary and, when enabled, exports
    the run to W&B.

    Args:
        args: Parsed CLI arguments. ``interval``, ``threshold``, ``device``,
            ``profile`` and ``output`` are read directly; all other options
            are read with defaults.

    Returns:
        0 on completion, 1 when JAX is unavailable or the W&B configuration
        could not be resolved.
    """
    if not JAX_AVAILABLE:
        print("Error: JAX not available")
        return 1

    wandb_config = _resolve_wandb_config(args)
    if wandb_config is None:
        return 1

    print("Starting background memory tracking...")

    bytes_per_mb = 1024 * 1024
    sink_config = _build_telemetry_sink_config(args)
    recorder_enabled = bool(getattr(args, "oom_flight_recorder", False))
    dump_dir = str(getattr(args, "oom_dump_dir", "oom_dumps"))
    dump_limit = int(getattr(args, "oom_max_dumps", 5))
    dump_budget_mb = int(getattr(args, "oom_max_total_mb", 256))
    history_limit = int(getattr(args, "max_history", 10000))

    tracker_options = {
        "sampling_interval": args.interval,
        "alert_threshold_mb": args.threshold,
        "device_index": args.device,
        "enable_logging": True,
        "job_id": getattr(args, "job_id", None),
        "rank": getattr(args, "rank", None),
        "local_rank": getattr(args, "local_rank", None),
        "world_size": getattr(args, "world_size", None),
        "telemetry_sink_config": sink_config,
        "save_device_profile_on_stop": args.profile,
        "max_history": history_limit,
        "enable_oom_flight_recorder": recorder_enabled,
        "oom_dump_dir": dump_dir,
        "oom_buffer_size": getattr(args, "oom_buffer_size", None),
        "oom_max_dumps": dump_limit,
        "oom_max_total_mb": dump_budget_mb,
    }
    tracker = MemoryTracker(**tracker_options)

    if sink_config is not None:
        print(f"Append-only telemetry sink: {sink_config.root_dir}")
    if recorder_enabled:
        print("OOM flight recorder enabled:")
        print(f" Dump directory: {dump_dir}")
        print(f" Buffer size: {tracker.oom_buffer_size} events")
        print(f" Max dumps: {dump_limit}")
        print(f" Max total size: {dump_budget_mb} MB")

    def alert_callback(alert: Dict[str, Any]) -> None:
        print(f"\n⚠️ MEMORY ALERT: {alert['message']}")

    tracker.add_alert_callback(alert_callback)

    try:
        tracker.start_tracking()
        print("Tracking started. Press Ctrl+C to stop and save results.")
        with tracker.capture_oom(
            context="jaxmemprof.track",
            metadata={"command": "track", "runtime_backend": "jax"},
        ):
            while True:
                time.sleep(5.0)
                stats = tracker.get_statistics()
                memory_mb = stats.get("current_memory_mb")
                health = str(stats.get("collector_health_status", "healthy"))

                # Assemble the status line from segments joined by " | ".
                if isinstance(memory_mb, (int, float)):
                    segments = [f"Current memory: {float(memory_mb):.1f} MB"]
                else:
                    segments = ["Current memory: unavailable"]
                segments.append(f"Health: {health}")

                retry_at = stats.get("collector_next_retry_epoch_s")
                if isinstance(retry_at, (int, float)):
                    wait_seconds = max(float(retry_at) - time.time(), 0.0)
                    segments.append(f"Retry In: {wait_seconds:.1f}s")
                print(" | ".join(segments))
    except KeyboardInterrupt:
        print("\nStopping tracking...")
    finally:
        results = tracker.stop_tracking()

        if args.output:
            interval_ms = int(round(args.interval * 1000))
            output_data = {
                "peak_memory": results.peak_memory_bytes / bytes_per_mb,
                "average_memory": results.average_memory_bytes / bytes_per_mb,
                "duration": results.duration,
                "memory_usage": results.memory_usage,
                "timestamps": results.timestamps,
                "alerts": results.alert_count,
                "events": _normalize_telemetry_events(
                    results.telemetry_events,
                    sampling_interval_ms=interval_ms,
                ),
            }
            # History bookkeeping counters, each defaulting to 0 when absent.
            history_fields = (
                "history_window_limit",
                "history_retained_samples",
                "history_dropped_samples",
                "history_retained_events",
                "history_dropped_events",
                "history_retained_alerts",
                "history_dropped_alerts",
            )
            output_data.update(
                {field: getattr(results, field, 0) for field in history_fields}
            )

            destination = Path(args.output)
            destination.parent.mkdir(parents=True, exist_ok=True)
            with destination.open("w", encoding="utf-8") as handle:
                json.dump(output_data, handle, indent=2)
            print(f"Results saved to {args.output}")

        final_stats = tracker.get_statistics()
        peak_mb = results.peak_memory_bytes / bytes_per_mb
        print(f"\nTracking completed. Peak memory: {peak_mb:.1f} MB")

        lost_samples = int(final_stats.get("history_dropped_samples", 0))
        lost_events = int(final_stats.get("history_dropped_events", 0))
        if lost_samples or lost_events:
            print(f"History truncation: samples={lost_samples}, events={lost_events}")

        if "collector_health_status" in final_stats:
            final_health = final_stats.get("collector_health_status", "healthy")
            print(f"Collector health: {final_health}")
        if final_stats.get("collector_last_error"):
            print(f"Last collector error: {final_stats.get('collector_last_error')}")

        profile_path = results.device_memory_profile_path
        if profile_path:
            print(f"Device memory profile saved to: {profile_path}")
        if tracker.last_oom_dump_path:
            print(f"OOM flight recorder dump saved to: {tracker.last_oom_dump_path}")

        if WANDB_AVAILABLE and wandb_config.enabled:
            # An export failure is reported as a warning, never as a command failure.
            try:
                export_tracking_run_to_wandb(
                    wandb_config,
                    command_name="jaxmemprof-track",
                    session_summary=tracker.get_session_summary(),
                    stats=final_stats,
                    events=results.telemetry_events,
                    output_path=args.output,
                    telemetry_sink_dir=getattr(args, "telemetry_sink_dir", None),
                    oom_dump_path=tracker.last_oom_dump_path,
                )
                print("W&B export completed.")
            except Exception as exc:
                _warn_wandb_export_failure("jaxmemprof track", exc)
        elif wandb_config.enabled:
            print("Warning: W&B is not available.", file=sys.stderr)

    return 0
[docs]
def _print_diagnose_findings(artifact_dir: Any, exit_code: int) -> None:
    """Print the findings line(s) derived from the bundle's JSON files.

    Args:
        artifact_dir: Bundle directory; must support ``/`` path joining.
        exit_code: Exit code reported by the diagnose run.

    Raises:
        OSError: If a bundle file cannot be read.
        ValueError: If a bundle file is not valid UTF-8 JSON.
    """
    manifest_path = artifact_dir / "manifest.json"
    if not manifest_path.exists():
        return

    with open(manifest_path, encoding="utf-8") as f:
        manifest = json.load(f)

    # Guard the shape of the parsed JSON so an unexpected top-level value
    # (list, string, ...) cannot crash the summary.
    if isinstance(manifest, dict) and manifest.get("risk_detected"):
        summary_path = artifact_dir / "diagnostic_summary.json"
        if summary_path.exists():
            with open(summary_path, encoding="utf-8") as f:
                summary = json.load(f)
            flags = summary.get("risk_flags") if isinstance(summary, dict) else None
            if isinstance(flags, dict):
                raised = [name for name, is_set in flags.items() if is_set]
                if raised:
                    print(f"Findings: {', '.join(raised)}")

    if exit_code == 0:
        print("Findings: no memory risk detected")


def cmd_diagnose(args: argparse.Namespace) -> int:
    """Produce a portable diagnostic bundle.

    Validates the options, runs ``run_diagnose``, prints a structured summary
    (artifact location, status, findings) and optionally exports the bundle
    to W&B.

    Args:
        args: Parsed CLI arguments. Reads ``output``, ``device``, ``duration``
            and ``interval``, plus whatever the W&B configuration consumes.

    Returns:
        The exit code from ``run_diagnose`` (0 is reported as OK, 2 as
        MEMORY_RISK, anything else as FAILED), or 1 when JAX is unavailable,
        an option is invalid, the W&B configuration cannot be resolved, or
        ``run_diagnose`` raises ``OSError``.
    """
    if not JAX_AVAILABLE:
        print("Error: JAX not available")
        return 1
    if args.duration < 0:
        print("Error: --duration must be >= 0", file=sys.stderr)
        return 1
    if args.interval <= 0:
        print("Error: --interval must be > 0", file=sys.stderr)
        return 1

    wandb_config = _resolve_wandb_config(args)
    if wandb_config is None:
        return 1

    command_line = " ".join(sys.argv)
    try:
        artifact_dir, exit_code = run_diagnose(
            output=args.output,
            device_index=args.device,
            duration=args.duration,
            interval=args.interval,
            command_line=command_line,
        )
    except OSError as exc:
        print(f"Error: {exc}", file=sys.stderr)
        return 1

    # Structured stdout summary
    print(f"Artifact: {artifact_dir}")
    status = {0: "OK", 2: "MEMORY_RISK"}.get(exit_code, "FAILED")
    print(f"Status: {status} (exit_code={exit_code})")

    # The findings line is best-effort: a missing, unreadable or malformed
    # bundle file must not change the command's outcome, but the reason is
    # logged instead of being silently discarded.
    try:
        _print_diagnose_findings(artifact_dir, exit_code)
    except (OSError, ValueError) as exc:
        logging.getLogger(__name__).debug(
            "Skipping diagnose findings summary: %s", exc
        )

    if WANDB_AVAILABLE and wandb_config.enabled:
        # An export failure is reported as a warning, never as a command failure.
        try:
            export_diagnose_bundle_to_wandb(
                wandb_config,
                command_name="jaxmemprof-diagnose",
                artifact_dir=artifact_dir,
            )
            print("W&B export completed.")
        except Exception as exc:
            _warn_wandb_export_failure("jaxmemprof diagnose", exc)
    elif wandb_config.enabled:
        print("Warning: W&B is not available.", file=sys.stderr)

    return exit_code
[docs]
def cmd_analyze(args: argparse.Namespace) -> int:
    """Analyze a saved tracking result JSON file.

    Loads the file named by ``args.input``, prints the headline metrics and,
    when ``args.plot`` is set, renders a memory timeline plot.

    Args:
        args: Parsed CLI arguments. Reads ``input``, ``plot`` and ``output``.

    Returns:
        0 on success; 1 when the input file is missing, cannot be loaded, is
        not a JSON object, or holds a non-numeric headline metric.
    """
    input_path = Path(args.input)
    if not input_path.exists():
        print(f"Error: Input file {args.input} not found", file=sys.stderr)
        return 1

    # CLI boundary: any load failure is reported to the user, not raised.
    try:
        with input_path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        print(f"Error: Failed to load results from {args.input}: {e}", file=sys.stderr)
        return 1

    # Valid JSON need not be an object; everything below relies on ``.get``.
    if not isinstance(data, dict):
        print(
            f"Error: Failed to load results from {args.input}: "
            "expected a JSON object at the top level",
            file=sys.stderr,
        )
        return 1

    # Validate the headline metrics before printing anything so a bad file
    # yields one clean error instead of a traceback from float formatting.
    metrics: Dict[str, Any] = {}
    for key in ("peak_memory", "average_memory", "duration"):
        value = data.get(key)
        if value is None:
            # JSON null is treated like a missing key.
            value = 0.0
        if not isinstance(value, (int, float)):
            print(
                f"Error: Field '{key}' in {args.input} is not numeric",
                file=sys.stderr,
            )
            return 1
        metrics[key] = value

    print(f"Analyzing JAX tracking results: {args.input}")
    print("=" * 50)
    print(f"Peak Memory: {metrics['peak_memory']:.2f} MB")
    print(f"Average Memory: {metrics['average_memory']:.2f} MB")
    print(f"Duration: {metrics['duration']:.2f} seconds")
    print(f"Alerts Triggered: {data.get('alerts', 0)}")

    if args.plot:
        from .visualizer import MemoryVisualizer

        visualizer = MemoryVisualizer()

        # Wrap data in a simple object for the visualizer
        class ResultWrapper:
            def __init__(self, d: Dict[str, Any]):
                self.memory_usage = d.get("memory_usage", [])
                self.timestamps = d.get("timestamps", [])

        wrapper = ResultWrapper(data)
        plot_name = args.plot if isinstance(args.plot, str) else "memory_timeline.png"
        if args.output:
            output_dir = Path(args.output)
            output_dir.mkdir(parents=True, exist_ok=True)
            plot_path = str(output_dir / plot_name)
        else:
            plot_path = plot_name

        visualizer.plot_memory_timeline(wrapper, save_path=plot_path)
        print(f"Memory timeline plot saved to {plot_path}")

    return 0
[docs]
def _add_monitor_arguments(monitor_parser: argparse.ArgumentParser) -> None:
    """Register the options of the ``monitor`` subcommand."""
    monitor_parser.add_argument(
        "--interval",
        type=float,
        default=1.0,
        help="Sampling interval in seconds (default: 1.0)",
    )
    monitor_parser.add_argument(
        "--duration",
        type=float,
        help="Monitoring duration in seconds (default: indefinite)",
    )
    monitor_parser.add_argument(
        "--threshold", type=float, help="Memory alert threshold in MB"
    )
    monitor_parser.add_argument(
        "--device",
        type=int,
        default=0,
        help="JAX device index to monitor (default: 0)",
    )
    monitor_parser.add_argument("--output", help="Output file for results")
    monitor_parser.add_argument(
        "--max-history",
        type=int,
        default=10000,
        help="Maximum number of historical samples to keep in memory (default: 10000)",
    )


def _add_track_arguments(track_parser: argparse.ArgumentParser) -> None:
    """Register the options of the ``track`` subcommand, W&B options last."""
    track_parser.add_argument(
        "--interval",
        type=float,
        default=1.0,
        help="Sampling interval in seconds (default: 1.0)",
    )
    track_parser.add_argument(
        "--threshold",
        type=float,
        default=4000,
        help="Memory alert threshold in MB (default: 4000)",
    )
    track_parser.add_argument(
        "--device",
        type=int,
        default=0,
        help="JAX device index to monitor (default: 0)",
    )
    track_parser.add_argument(
        "--profile",
        action="store_true",
        help="Export a JAX device memory profile upon tracking completion",
    )
    track_parser.add_argument(
        "--job-id",
        default=None,
        help="Distributed job identifier override (default: infer from env)",
    )
    track_parser.add_argument(
        "--rank",
        type=int,
        default=None,
        help="Global distributed rank override (default: infer from env)",
    )
    track_parser.add_argument(
        "--local-rank",
        type=int,
        default=None,
        help="Local distributed rank override (default: infer from env)",
    )
    track_parser.add_argument(
        "--world-size",
        type=int,
        default=None,
        help="Distributed world size override (default: infer from env)",
    )
    track_parser.add_argument(
        "--output", required=True, help="Output file for tracking results"
    )
    track_parser.add_argument(
        "--telemetry-sink-dir",
        default=None,
        help="Directory for append-only telemetry sink segments",
    )
    track_parser.add_argument(
        "--telemetry-flush-seconds",
        type=float,
        default=2.0,
        help="Maximum seconds between telemetry sink flushes (default: 2.0)",
    )
    track_parser.add_argument(
        "--telemetry-rollover-mb",
        type=int,
        default=64,
        help="Telemetry sink segment rollover size in MB (default: 64)",
    )
    track_parser.add_argument(
        "--telemetry-retention-files",
        type=int,
        default=8,
        help="Maximum retained telemetry sink segments (default: 8)",
    )
    track_parser.add_argument(
        "--telemetry-retention-total-mb",
        type=int,
        default=512,
        help="Maximum retained telemetry sink size in MB (default: 512)",
    )
    track_parser.add_argument(
        "--max-history",
        type=int,
        default=10000,
        help="Maximum number of historical samples to keep in memory (default: 10000)",
    )
    track_parser.add_argument(
        "--oom-flight-recorder",
        action="store_true",
        help="Enable automatic OOM flight recorder dump artifacts",
    )
    track_parser.add_argument(
        "--oom-dump-dir",
        default="oom_dumps",
        help="Directory used to write OOM dump bundles (default: oom_dumps)",
    )
    track_parser.add_argument(
        "--oom-buffer-size",
        type=int,
        default=None,
        help="Ring buffer size for OOM event dumps (default: max history)",
    )
    track_parser.add_argument(
        "--oom-max-dumps",
        type=int,
        default=5,
        help="Maximum number of retained OOM dump bundles (default: 5)",
    )
    track_parser.add_argument(
        "--oom-max-total-mb",
        type=int,
        default=256,
        help="Maximum retained OOM dump storage in MB (default: 256)",
    )
    add_wandb_arguments(track_parser)


def _add_diagnose_arguments(diagnose_parser: argparse.ArgumentParser) -> None:
    """Register the options of the ``diagnose`` subcommand, W&B options last."""
    diagnose_parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output directory for the artifact bundle (default: cwd)",
    )
    diagnose_parser.add_argument(
        "--device",
        type=int,
        default=0,
        help="JAX device index to monitor (default: 0)",
    )
    diagnose_parser.add_argument(
        "--duration",
        type=float,
        default=5.0,
        help="Seconds to run tracker for telemetry (default: 5, use 0 to skip)",
    )
    diagnose_parser.add_argument(
        "--interval",
        type=float,
        default=0.5,
        help="Sampling interval for timeline (default: 0.5)",
    )
    add_wandb_arguments(diagnose_parser)


def _add_analyze_arguments(analyze_parser: argparse.ArgumentParser) -> None:
    """Register the options of the ``analyze`` subcommand."""
    analyze_parser.add_argument(
        "--input", required=True, help="Input JSON file with tracking results"
    )
    analyze_parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output directory for analysis results",
    )
    analyze_parser.add_argument(
        "--plot",
        nargs="?",
        const="memory_timeline.png",
        help="Generate a memory usage plot (default: memory_timeline.png)",
    )


def _build_parser() -> argparse.ArgumentParser:
    """Create the top-level parser with every subcommand registered."""
    parser = argparse.ArgumentParser(
        description="JAX Stormlog CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Cookbook:
https://stormlog.readthedocs.io/en/latest/cookbook/index.html
""",
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Enable verbose logging"
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Info command
    subparsers.add_parser("info", help="Display system and device information")

    # Monitor command
    _add_monitor_arguments(
        subparsers.add_parser("monitor", help="Monitor device memory usage")
    )

    # Track command
    _add_track_arguments(
        subparsers.add_parser("track", help="Background memory tracking")
    )

    # Diagnose command
    _add_diagnose_arguments(
        subparsers.add_parser("diagnose", help="Diagnose OOM dumps and memory issues")
    )

    # Analyze command
    _add_analyze_arguments(
        subparsers.add_parser("analyze", help="Analyze saved tracking results")
    )

    return parser


def main(argv: list[str] | None = None) -> int:
    """Main CLI entry point.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the previous no-argument
            signature did; passing a list allows programmatic invocation.

    Returns:
        The selected command's exit code, 0 after printing help when no
        command was given, or 1 for an unrecognised command.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    setup_logging(args.verbose)

    if not args.command:
        parser.print_help()
        return 0

    # Built per call so the handlers are looked up by name at call time.
    handlers = {
        "info": cmd_info,
        "monitor": cmd_monitor,
        "track": cmd_track,
        "diagnose": cmd_diagnose,
        "analyze": cmd_analyze,
    }
    handler = handlers.get(args.command)
    if handler is None:
        print(f"Unknown command: {args.command}")
        return 1
    return handler(args)
# Script entry: the CLI's return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())