# Source code for stormlog.tensorflow.cli

"""TensorFlow Stormlog CLI"""

import argparse
import json
import logging
import sys
import time
from pathlib import Path
from typing import Any, Dict

from .tf_env import configure_tensorflow_logging

# Runs before the TensorFlow import below; presumably so the logging
# configuration is already in place when TensorFlow initializes — confirm
# against tf_env.
configure_tensorflow_logging()

# TensorFlow is optional at import time: TF_AVAILABLE records whether the
# import succeeded, and `tf` is bound to None when it did not.
try:
    import tensorflow as tf

    TF_AVAILABLE = True
except ImportError:
    TF_AVAILABLE = False
    tf = None

from stormlog.telemetry import telemetry_event_from_record, telemetry_event_to_dict
from stormlog.telemetry_sink import TelemetrySinkConfig
from stormlog.wandb_integration import (
    add_wandb_arguments,
    ensure_wandb_available,
    export_diagnose_bundle_to_wandb,
    export_tracking_run_to_wandb,
    wandb_config_from_namespace,
)

from .analyzer import MemoryAnalyzer
from .diagnose import run_diagnose
from .tracker import MemoryTracker
from .utils import format_memory, generate_summary_report, get_system_info
from .visualizer import MemoryVisualizer


def setup_logging(verbose: bool = False) -> None:
    """Configure logging for the CLI via ``logging.basicConfig``.

    Args:
        verbose: When true, log at DEBUG level; otherwise log at INFO level.
    """
    if verbose:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO

    logging.basicConfig(
        datefmt="%Y-%m-%d %H:%M:%S",
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=log_level,
    )
def _normalize_telemetry_events( records: list[dict[str, Any]], sampling_interval_ms: int ) -> list[dict[str, Any]]: normalized: list[dict[str, Any]] = [] for record in records: event = telemetry_event_from_record( record, default_collector="stormlog.tensorflow.memory_tracker", default_sampling_interval_ms=sampling_interval_ms, ) normalized.append(telemetry_event_to_dict(event)) return normalized def _build_telemetry_sink_config( args: argparse.Namespace, ) -> TelemetrySinkConfig | None: sink_dir = getattr(args, "telemetry_sink_dir", None) if not sink_dir: return None return TelemetrySinkConfig( root_dir=Path(sink_dir), flush_every_seconds=float(getattr(args, "telemetry_flush_seconds", 2.0)), rollover_max_bytes=int(getattr(args, "telemetry_rollover_mb", 64)) * 1024 * 1024, retention_max_files=int(getattr(args, "telemetry_retention_files", 8)), retention_max_total_bytes=int( getattr(args, "telemetry_retention_total_mb", 512) ) * 1024 * 1024, ) def _resolve_wandb_config(args: argparse.Namespace) -> Any: config = wandb_config_from_namespace(args) if not config.enabled: return config try: ensure_wandb_available(config) except ImportError as exc: print(f"Error: {exc}", file=sys.stderr) return None return config def _warn_wandb_export_failure(command_name: str, exc: Exception) -> None: print(f"Warning: {command_name} W&B export skipped: {exc}", file=sys.stderr)
def cmd_info(args: argparse.Namespace) -> int:
    """Print a system, GPU, and TensorFlow backend report to stdout.

    Args:
        args: Parsed CLI arguments; not read by this command.

    Returns:
        Always 0.
    """
    print("TensorFlow Stormlog - System Information")
    print("=" * 50)

    info = get_system_info()
    for label, key in (
        ("Platform", "platform"),
        ("Python Version", "python_version"),
        ("TensorFlow Version", "tensorflow_version"),
        ("CPU Count", "cpu_count"),
    ):
        print(f"{label}: {info[key]}")

    if "total_memory_gb" in info:
        print(f"Total System Memory: {info['total_memory_gb']:.2f} GB")
        print(f"Available Memory: {info['available_memory_gb']:.2f} GB")

    print("\nGPU Information:")
    print("-" * 20)

    gpu = info.get("gpu", {})
    backend = info.get("backend", {})

    if not gpu.get("available", False):
        # No GPU usable by the runtime: report what was detected and why.
        hardware_seen = "Yes" if backend.get("hardware_gpu_detected", False) else "No"
        print(f"GPU Hardware Detected: {hardware_seen}")
        print("GPU Available to TensorFlow Runtime: No")
        if "error" in gpu:
            print(f"Error: {gpu['error']}")
        if backend.get("is_apple_silicon", False):
            if not backend.get("tensorflow_metal_installed", False):
                print(
                    "Hint: install tensorflow-metal to enable Metal GPU runtime on Apple Silicon."
                )
    else:
        print("GPU Available: Yes")
        print(f"GPU Count: {gpu['count']}")
        # total_memory is scaled by 1024 * 1024 before formatting.
        total_bytes = gpu["total_memory"] * 1024 * 1024
        print(f"Total GPU Memory: {format_memory(total_bytes)}")

        for index, device in enumerate(gpu.get("devices", [])):
            print(f"\nGPU {index}:")
            print(f" Name: {device.get('name', 'Unknown')}")
            print(f" Current Memory: {device.get('current_memory_mb', 0):.1f} MB")
            print(f" Peak Memory: {device.get('peak_memory_mb', 0):.1f} MB")

    if backend:
        print("\nTensorFlow Backend Diagnostics:")
        print("-" * 30)
        for label, key, fallback in (
            ("Hardware GPU Detected", "hardware_gpu_detected", False),
            ("Runtime Backend", "runtime_backend", "cpu"),
            ("Runtime GPU Count", "runtime_gpu_count", 0),
            ("Apple Silicon", "is_apple_silicon", False),
            ("tensorflow-metal Installed", "tensorflow_metal_installed", False),
            ("CUDA Build", "is_cuda_build", False),
            ("ROCm Build", "is_rocm_build", False),
            ("TensorRT Build", "is_tensorrt_build", False),
        ):
            print(f"{label}: {backend.get(key, fallback)}")

    # The build-info section needs the tensorflow module itself.
    if not TF_AVAILABLE:
        return 0

    print("\nTensorFlow Build Information:")
    print("-" * 30)
    try:
        build = tf.sysconfig.get_build_info()
        # Backend diagnostics win over the raw build info when both exist.
        cuda_build = backend.get(
            "is_cuda_build", build.get("is_cuda_build", "Unknown")
        )
        print(f"CUDA Build: {cuda_build}")
        print(f"CUDA Version: {build.get('cuda_version', 'Unknown')}")
        print(f"cuDNN Version: {build.get('cudnn_version', 'Unknown')}")
    except Exception as e:
        print(f"Could not get build info: {e}")

    return 0
def cmd_monitor(args: argparse.Namespace) -> int:
    """Monitor GPU memory usage in real-time.

    Starts a ``MemoryTracker``, refreshes a one-line memory readout about once
    per second until ``args.duration`` elapses or Ctrl+C is pressed, then
    prints a summary and optionally writes it to ``args.output`` as JSON.

    Args:
        args: Parsed CLI arguments. Reads ``interval``, ``duration``,
            ``threshold``, ``device`` and ``output``.

    Returns:
        1 when TensorFlow could not be imported, otherwise 0.
    """
    if not TF_AVAILABLE:
        print("Error: TensorFlow not available")
        return 1

    print("Starting TensorFlow memory monitoring...")
    print(f"Sampling interval: {args.interval} seconds")
    print(
        f"Duration: {args.duration} seconds"
        if args.duration
        else "Duration: Indefinite"
    )
    if args.threshold:
        print(f"Alert threshold: {args.threshold} MB")
    print("Press Ctrl+C to stop\n")

    tracker = MemoryTracker(
        sampling_interval=args.interval,
        alert_threshold_mb=args.threshold,
        device=args.device,
        enable_logging=True,
    )

    try:
        tracker.start_tracking()
        start_time = time.time()

        while True:
            # A falsy duration (None or 0) means "run until interrupted".
            remaining = None
            if args.duration:
                remaining = args.duration - (time.time() - start_time)
                if remaining <= 0:
                    break

            current_memory = tracker.get_current_memory()
            print(
                f"\rCurrent memory usage: {current_memory:.1f} MB", end="", flush=True
            )

            # Refresh about once per second, but never sleep past the
            # requested duration (a fixed 1s sleep overshot it by up to 1s).
            time.sleep(1.0 if remaining is None else min(1.0, remaining))

    except KeyboardInterrupt:
        print("\n\nStopping monitoring...")

    finally:
        # NOTE(review): block nesting was reconstructed from a
        # whitespace-collapsed listing; the reporting below is assumed to live
        # inside `finally` — confirm against the original file.
        results = tracker.stop_tracking()

        print("\nMonitoring Results:")
        print("-" * 20)
        print(f"Peak Memory: {results.peak_memory:.1f} MB")
        print(f"Average Memory: {results.average_memory:.1f} MB")
        print(f"Duration: {results.duration:.1f} seconds")
        print(f"Samples Collected: {len(results.memory_usage)}")

        dropped_samples = int(getattr(results, "history_dropped_samples", 0))
        if dropped_samples:
            print(f"Dropped Samples: {dropped_samples}")

        if results.alerts_triggered:
            print(f"Alerts Triggered: {len(results.alerts_triggered)}")

        if args.output:
            # Save results. The history_* counters fall back to the lengths of
            # the in-memory lists when the result object does not expose them.
            output_data = {
                "peak_memory": results.peak_memory,
                "average_memory": results.average_memory,
                "duration": results.duration,
                "memory_usage": results.memory_usage,
                "timestamps": results.timestamps,
                "alerts": results.alerts_triggered,
                "history_window_limit": int(
                    getattr(results, "history_window_limit", len(results.memory_usage))
                ),
                "history_retained_samples": int(
                    getattr(
                        results, "history_retained_samples", len(results.memory_usage)
                    )
                ),
                "history_dropped_samples": int(
                    getattr(results, "history_dropped_samples", 0)
                ),
                "history_retained_alerts": int(
                    getattr(
                        results,
                        "history_retained_alerts",
                        len(results.alerts_triggered),
                    )
                ),
                "history_dropped_alerts": int(
                    getattr(results, "history_dropped_alerts", 0)
                ),
            }

            output_path = Path(args.output)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with output_path.open("w", encoding="utf-8") as f:
                json.dump(output_data, f, indent=2)
            print(f"Results saved to {args.output}")

    return 0
def cmd_track(args: argparse.Namespace) -> int:
    """Track memory in the background until interrupted, then save results.

    Prints a status line every five seconds while tracking. On stop, writes
    the results to ``args.output`` as JSON, prints a closing summary, and
    exports the run to W&B when that integration is enabled.

    Args:
        args: Parsed CLI arguments for the ``track`` sub-command.

    Returns:
        1 when TensorFlow is unavailable or the W&B configuration could not be
        resolved, otherwise 0.
    """
    if not TF_AVAILABLE:
        print("Error: TensorFlow not available")
        return 1

    wandb_config = _resolve_wandb_config(args)
    if wandb_config is None:
        return 1

    print("Starting background memory tracking...")

    sink_config = _build_telemetry_sink_config(args)
    tracker = MemoryTracker(
        sampling_interval=args.interval,
        alert_threshold_mb=args.threshold,
        device=args.device,
        enable_logging=True,
        job_id=getattr(args, "job_id", None),
        rank=getattr(args, "rank", None),
        local_rank=getattr(args, "local_rank", None),
        world_size=getattr(args, "world_size", None),
        telemetry_sink_config=sink_config,
    )
    if sink_config is not None:
        print(f"Append-only telemetry sink: {sink_config.root_dir}")

    # Surface alerts on the console as they are raised.
    def on_alert(alert: Dict[str, Any]) -> None:
        print(f"\n⚠️ MEMORY ALERT: {alert['message']}")

    tracker.add_alert_callback(on_alert)

    try:
        tracker.start_tracking()
        print("Tracking started. Press Ctrl+C to stop and save results.")

        while True:
            time.sleep(5.0)  # Check every 5 seconds

            # Periodic status line: memory, collector health, optional retry.
            stats = tracker.get_statistics()
            memory_now = stats.get("current_memory_mb")
            health = str(stats.get("collector_health_status", "healthy"))

            if isinstance(memory_now, (int, float)):
                segments = [f"Current memory: {float(memory_now):.1f} MB"]
            else:
                segments = ["Current memory: unavailable"]
            segments.append(f"Health: {health}")

            next_retry = stats.get("collector_next_retry_epoch_s")
            if isinstance(next_retry, (int, float)):
                wait_s = max(float(next_retry) - time.time(), 0.0)
                segments.append(f"Retry In: {wait_s:.1f}s")

            print(" | ".join(segments))

    except KeyboardInterrupt:
        print("\nStopping tracking...")

    finally:
        # NOTE(review): block nesting was reconstructed from a
        # whitespace-collapsed listing; everything up to the W&B export is
        # assumed to live inside `finally` — confirm against the original file.
        results = tracker.stop_tracking()

        if args.output:
            interval_ms = int(round(args.interval * 1000))
            output_data = {
                "peak_memory": results.peak_memory,
                "average_memory": results.average_memory,
                "duration": results.duration,
                "memory_usage": results.memory_usage,
                "timestamps": results.timestamps,
                "alerts": results.alerts_triggered,
                "events": _normalize_telemetry_events(
                    results.events,
                    sampling_interval_ms=interval_ms,
                ),
            }

            # History counters fall back to the in-memory list lengths (or 0)
            # when the result object does not expose them.
            history_fallbacks = (
                ("history_window_limit", len(results.memory_usage)),
                ("history_retained_samples", len(results.memory_usage)),
                ("history_dropped_samples", 0),
                ("history_retained_events", len(results.events)),
                ("history_dropped_events", 0),
                ("history_retained_alerts", len(results.alerts_triggered)),
                ("history_dropped_alerts", 0),
            )
            for field_name, fallback in history_fallbacks:
                output_data[field_name] = int(getattr(results, field_name, fallback))

            destination = Path(args.output)
            destination.parent.mkdir(parents=True, exist_ok=True)
            with destination.open("w", encoding="utf-8") as handle:
                json.dump(output_data, handle, indent=2)
            print(f"Results saved to {args.output}")

        final_stats = tracker.get_statistics()
        print(f"\nTracking completed. Peak memory: {results.peak_memory:.1f} MB")

        lost_samples = int(final_stats.get("history_dropped_samples", 0))
        lost_events = int(final_stats.get("history_dropped_events", 0))
        if lost_samples or lost_events:
            print(f"History truncation: samples={lost_samples}, events={lost_events}")

        if "collector_health_status" in final_stats:
            final_health = final_stats.get("collector_health_status", "healthy")
            print(f"Collector health: {final_health}")

        last_error = final_stats.get("collector_last_error")
        if last_error:
            print(f"Last collector error: {last_error}")

        if wandb_config.enabled:
            # Export is best-effort: a failure is reported as a warning only.
            try:
                export_tracking_run_to_wandb(
                    wandb_config,
                    command_name="tfmemprof-track",
                    session_summary=tracker.get_session_summary(),
                    stats=final_stats,
                    events=results.events,
                    output_path=args.output,
                    telemetry_sink_dir=getattr(args, "telemetry_sink_dir", None),
                    oom_dump_path=None,
                )
                print("W&B export completed.")
            except Exception as exc:
                _warn_wandb_export_failure("tfmemprof track", exc)

    return 0
def cmd_analyze(args: argparse.Namespace) -> int:
    """Analyze profiling results stored in a JSON file.

    Loads the file named by ``args.input``, prints a basic summary, and then
    runs the optional stages selected by ``args.detect_leaks``,
    ``args.optimize``, ``args.visualize`` and ``args.report``.

    Args:
        args: Parsed CLI arguments for the ``analyze`` sub-command.

    Returns:
        1 when the input is missing, unreadable, not valid JSON, not a JSON
        object, or has mismatched ``memory_usage``/``timestamps`` lengths;
        otherwise 0.
    """
    if not args.input:
        print("Error: Input file required for analysis")
        return 1

    if not Path(args.input).exists():
        print(f"Error: Input file {args.input} not found")
        return 1

    print(f"Analyzing results from {args.input}...")

    # Load results. JSONDecodeError and UnicodeDecodeError are both ValueError
    # subclasses, so a malformed file yields an error message, not a traceback.
    try:
        with open(args.input, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        print(f"Error: Could not read input file {args.input}: {exc}")
        return 1

    if not isinstance(data, dict):
        print("Error: Input file must contain a JSON object")
        return 1

    # Create a simple result object for analysis
    class AnalysisResult:
        def __init__(self, data: Dict[str, Any]) -> None:
            memory_usage = data.get("memory_usage") or []
            self.peak_memory_mb = data.get("peak_memory", 0)
            self.average_memory_mb = data.get("average_memory", 0)
            self.min_memory_mb = min(memory_usage, default=0)

            # A rise between consecutive samples counts as an allocation,
            # a drop as a deallocation.
            steps = list(zip(memory_usage, memory_usage[1:]))
            self.total_allocations = sum(1 for prev, cur in steps if cur > prev)
            self.total_deallocations = sum(1 for prev, cur in steps if cur < prev)
            self.duration = data.get("duration", 0)

            # Create fake snapshots for analysis. A missing or null
            # "timestamps" entry falls back to the sample indices.
            self.snapshots = []
            timestamps = data.get("timestamps")
            if timestamps is None:
                timestamps = list(range(len(memory_usage)))
            if len(memory_usage) != len(timestamps):
                raise ValueError(
                    "Invalid input: 'memory_usage' and 'timestamps' must have equal length"
                )
            for i, (mem, ts) in enumerate(zip(memory_usage, timestamps, strict=True)):
                snapshot = type(
                    "Snapshot",
                    (),
                    {
                        "timestamp": ts,
                        "name": f"sample_{i}",
                        "gpu_memory_mb": mem,
                        "cpu_memory_mb": 0,
                        "gpu_memory_reserved_mb": mem * 1.1,  # Estimate
                        "gpu_utilization": min(100, mem / 1000 * 100),
                        "num_tensors": 0,
                    },
                )()
                self.snapshots.append(snapshot)

    try:
        result = AnalysisResult(data)
    except ValueError as exc:
        print(f"Error: {exc}")
        return 1

    # Basic analysis
    print("\nBasic Analysis:")
    print("-" * 15)
    print(f"Peak Memory: {format_memory(result.peak_memory_mb * 1024 * 1024)}")
    print(f"Average Memory: {format_memory(result.average_memory_mb * 1024 * 1024)}")
    print(f"Duration: {result.duration:.2f} seconds")
    print(f"Memory Allocations: {result.total_allocations}")
    print(f"Memory Deallocations: {result.total_deallocations}")

    if args.detect_leaks:
        print("\nMemory Leak Analysis:")
        print("-" * 22)

        analyzer = MemoryAnalyzer()

        # Create tracking result for leak detection
        class TrackingResult:
            def __init__(self, data: Dict[str, Any]) -> None:
                # `or []` keeps a null entry from breaking len() below,
                # matching how AnalysisResult treats "memory_usage".
                self.memory_usage = data.get("memory_usage") or []
                self.timestamps = data.get("timestamps") or []
                self.memory_growth_rate = 0
                if len(self.memory_usage) > 1 and result.duration > 0:
                    self.memory_growth_rate = (
                        self.memory_usage[-1] - self.memory_usage[0]
                    ) / result.duration

        tracking_result = TrackingResult(data)
        leaks = analyzer.detect_memory_leaks(tracking_result)

        if leaks:
            print("⚠️ Potential memory leaks detected:")
            for leak in leaks:
                print(
                    f" - {leak['type']}: {leak['description']} (Severity: {leak['severity']})"
                )
        else:
            print("✅ No memory leaks detected")

    if args.optimize:
        print("\nOptimization Analysis:")
        print("-" * 22)

        analyzer = MemoryAnalyzer()
        optimization = analyzer.score_optimization(result)

        print(f"Overall Score: {optimization['overall_score']:.1f}/10")
        print("\nCategory Scores:")
        for category, score in optimization["categories"].items():
            print(f" {category}: {score:.1f}/10")

        if optimization["top_recommendations"]:
            print("\nTop Recommendations:")
            for i, rec in enumerate(optimization["top_recommendations"], 1):
                print(f" {i}. {rec}")

    if args.visualize:
        print("\nGenerating visualizations...")
        visualizer = MemoryVisualizer()

        # Plotting is best-effort: a failure is reported, not raised.
        try:
            visualizer.plot_memory_timeline(result, save_path="memory_timeline.png")
            print("✅ Timeline plot saved as memory_timeline.png")
        except Exception as e:
            print(f"❌ Could not generate timeline plot: {e}")

    if args.report:
        print("\nGenerating comprehensive report...")
        report = generate_summary_report(result)

        # Explicit utf-8 so non-ASCII report text does not depend on the
        # platform's locale encoding.
        with open(args.report, "w", encoding="utf-8") as f:
            f.write(report)
        print(f"✅ Report saved to {args.report}")

    return 0
def cmd_diagnose(args: argparse.Namespace) -> int:
    """Produce a portable diagnostic bundle.

    Validates the arguments, runs ``run_diagnose``, prints a structured
    summary to stdout, and exports the bundle to W&B when that integration is
    enabled.

    Args:
        args: Parsed CLI arguments. Reads ``duration``, ``interval``,
            ``output`` and ``device``, plus the W&B options.

    Returns:
        0 (OK), 1 (failure), or 2 (memory risk).
    """
    if args.duration < 0:
        print("Error: --duration must be >= 0", file=sys.stderr)
        return 1
    if args.interval <= 0:
        print("Error: --interval must be > 0", file=sys.stderr)
        return 1

    wandb_config = _resolve_wandb_config(args)
    if wandb_config is None:
        return 1

    command_line = " ".join(sys.argv)
    try:
        artifact_dir, exit_code = run_diagnose(
            output=args.output,
            device=args.device,
            duration=args.duration,
            interval=args.interval,
            command_line=command_line,
        )
    except OSError as exc:
        # Report the cause instead of exiting with a bare status code.
        print(f"Error: diagnose failed: {exc}", file=sys.stderr)
        return 1

    # Structured stdout summary
    print(f"Artifact: {artifact_dir}")
    if exit_code == 0:
        status = "OK"
    elif exit_code == 2:
        status = "MEMORY_RISK"
    else:
        status = "FAILED"
    print(f"Status: {status} (exit_code={exit_code})")

    # The findings line is best-effort: an unreadable or malformed bundle file
    # must not change the exit code, so those errors are deliberately ignored.
    # NOTE(review): nesting reconstructed from a whitespace-collapsed listing;
    # the "no memory risk" line is assumed to sit at the `try` level — confirm.
    try:
        manifest_path = artifact_dir / "manifest.json"
        if manifest_path.exists():
            with open(manifest_path, encoding="utf-8") as f:
                manifest = json.load(f)
            if manifest.get("risk_detected"):
                summary_path = artifact_dir / "diagnostic_summary.json"
                if summary_path.exists():
                    with open(summary_path, encoding="utf-8") as f:
                        summary = json.load(f)
                    flags = summary.get("risk_flags", {})
                    parts = [k for k, v in flags.items() if v]
                    if parts:
                        print(f"Findings: {', '.join(parts)}")
        # status is "OK" exactly when exit_code is 0, so one check suffices.
        if exit_code == 0:
            print("Findings: no memory risk detected")
    except (OSError, json.JSONDecodeError):
        pass

    if wandb_config.enabled:
        # Export is best-effort: a failure is reported as a warning only.
        try:
            export_diagnose_bundle_to_wandb(
                wandb_config,
                command_name="tfmemprof-diagnose",
                artifact_dir=artifact_dir,
            )
            print("W&B export completed.")
        except Exception as exc:
            _warn_wandb_export_failure("tfmemprof diagnose", exc)

    return exit_code
def main(argv: list[str] | None = None) -> int:
    """Main CLI entry point.

    Args:
        argv: Optional argument list to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before this parameter existed.

    Returns:
        The selected command's exit code; 0 after printing help when no
        command was given; 1 for an unrecognized command.
    """
    parser = argparse.ArgumentParser(
        description="TensorFlow Stormlog CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=""" Cookbook: https://stormlog.readthedocs.io/en/latest/cookbook/index.html """,
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Enable verbose logging"
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Info command (takes no options of its own)
    subparsers.add_parser("info", help="Display system and GPU information")

    # Monitor command
    monitor_parser = subparsers.add_parser("monitor", help="Monitor GPU memory usage")
    monitor_parser.add_argument(
        "--interval",
        type=float,
        default=1.0,
        help="Sampling interval in seconds (default: 1.0)",
    )
    monitor_parser.add_argument(
        "--duration",
        type=float,
        help="Monitoring duration in seconds (default: indefinite)",
    )
    monitor_parser.add_argument(
        "--threshold", type=float, help="Memory alert threshold in MB"
    )
    monitor_parser.add_argument(
        "--device",
        default="/GPU:0",
        help="TensorFlow device to monitor (default: /GPU:0)",
    )
    monitor_parser.add_argument("--output", help="Output file for results")

    # Track command
    track_parser = subparsers.add_parser("track", help="Background memory tracking")
    track_parser.add_argument(
        "--interval",
        type=float,
        default=1.0,
        help="Sampling interval in seconds (default: 1.0)",
    )
    track_parser.add_argument(
        "--threshold",
        type=float,
        default=4000,
        help="Memory alert threshold in MB (default: 4000)",
    )
    track_parser.add_argument(
        "--device",
        default="/GPU:0",
        help="TensorFlow device to monitor (default: /GPU:0)",
    )
    track_parser.add_argument(
        "--job-id",
        default=None,
        help="Distributed job identifier override (default: infer from env)",
    )
    track_parser.add_argument(
        "--rank",
        type=int,
        default=None,
        help="Global distributed rank override (default: infer from env)",
    )
    track_parser.add_argument(
        "--local-rank",
        type=int,
        default=None,
        help="Local distributed rank override (default: infer from env)",
    )
    track_parser.add_argument(
        "--world-size",
        type=int,
        default=None,
        help="Distributed world size override (default: infer from env)",
    )
    track_parser.add_argument(
        "--output", required=True, help="Output file for tracking results"
    )
    track_parser.add_argument(
        "--telemetry-sink-dir",
        default=None,
        help="Directory for append-only telemetry sink segments",
    )
    track_parser.add_argument(
        "--telemetry-flush-seconds",
        type=float,
        default=2.0,
        help="Maximum seconds between telemetry sink flushes (default: 2.0)",
    )
    track_parser.add_argument(
        "--telemetry-rollover-mb",
        type=int,
        default=64,
        help="Telemetry sink segment rollover size in MB (default: 64)",
    )
    track_parser.add_argument(
        "--telemetry-retention-files",
        type=int,
        default=8,
        help="Maximum retained telemetry sink segments (default: 8)",
    )
    track_parser.add_argument(
        "--telemetry-retention-total-mb",
        type=int,
        default=512,
        help="Maximum retained telemetry sink size in MB (default: 512)",
    )
    add_wandb_arguments(track_parser)

    # Analyze command
    analyze_parser = subparsers.add_parser("analyze", help="Analyze profiling results")
    analyze_parser.add_argument(
        "--input", required=True, help="Input file with profiling results"
    )
    analyze_parser.add_argument(
        "--detect-leaks", action="store_true", help="Detect memory leaks"
    )
    analyze_parser.add_argument(
        "--optimize", action="store_true", help="Generate optimization recommendations"
    )
    analyze_parser.add_argument(
        "--visualize", action="store_true", help="Generate visualization plots"
    )
    analyze_parser.add_argument("--report", help="Generate comprehensive report file")

    # Diagnose command
    diagnose_parser = subparsers.add_parser(
        "diagnose",
        help="Produce a portable diagnostic bundle for debugging memory failures",
    )
    diagnose_parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output directory for the artifact bundle (default: cwd)",
    )
    diagnose_parser.add_argument(
        "--device",
        type=str,
        default="/GPU:0",
        help="TensorFlow device to monitor (default: /GPU:0)",
    )
    diagnose_parser.add_argument(
        "--duration",
        type=float,
        default=5.0,
        help="Seconds to run tracker for telemetry (default: 5, use 0 to skip)",
    )
    diagnose_parser.add_argument(
        "--interval",
        type=float,
        default=0.5,
        help="Sampling interval for timeline (default: 0.5)",
    )
    add_wandb_arguments(diagnose_parser)

    # parse_args(None) reads sys.argv[1:], so a bare main() behaves as before.
    args = parser.parse_args(argv)

    setup_logging(args.verbose)

    if not args.command:
        parser.print_help()
        return 0

    # Execute command
    handlers = {
        "info": cmd_info,
        "monitor": cmd_monitor,
        "track": cmd_track,
        "analyze": cmd_analyze,
        "diagnose": cmd_diagnose,
    }
    handler = handlers.get(args.command)
    if handler is None:
        print(f"Unknown command: {args.command}")
        return 1
    return handler(args)
# Script entry point: exit the process with the status code returned by main().
if __name__ == "__main__":
    sys.exit(main())