Source code for stormlog.tensorflow.context_profiler

"""TensorFlow Context Profiling"""

import functools
import itertools
import threading
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, Union, cast

from .tf_env import configure_tensorflow_logging

# Runs before the TensorFlow import below.
# NOTE(review): presumably so the logging settings are in place when
# TensorFlow initialises -- confirm against tf_env.
configure_tensorflow_logging()

# TensorFlow is an optional dependency: record whether the import worked so
# the public entry points can raise ImportError instead of failing on `tf`.
try:
    import tensorflow as tf

    TF_AVAILABLE = True
except ImportError:
    TF_AVAILABLE = False
    # Keep the name bound so references to `tf` resolve to None.
    tf = None

from .profiler import TFMemoryProfiler

# Global profiler instance shared by the module-level helpers; stays None
# until get_global_profiler() creates one or set_global_profiler() installs one.
_global_profiler: Optional[TFMemoryProfiler] = None
# Held by the helpers in this module while they read or rebind _global_profiler.
_profiler_lock = threading.Lock()
# Type variable for "any callable"; lets profile_function declare that it
# returns the same callable type it was given.
F = TypeVar("F", bound=Callable[..., Any])


def get_global_profiler() -> TFMemoryProfiler:
    """Return the shared module-level profiler, creating it on first use.

    The check and the creation both happen while ``_profiler_lock`` is held.

    Returns:
        The current global ``TFMemoryProfiler`` instance.
    """
    global _global_profiler
    with _profiler_lock:
        current = _global_profiler
        if current is None:
            # First request: build the profiler and publish it.
            current = TFMemoryProfiler()
            _global_profiler = current
        return current
def set_global_profiler(profiler: TFMemoryProfiler) -> None:
    """Install ``profiler`` as the shared module-level profiler.

    Args:
        profiler: Instance that subsequent ``get_global_profiler()`` calls
            will return.
    """
    global _global_profiler
    _profiler_lock.acquire()
    try:
        _global_profiler = profiler
    finally:
        _profiler_lock.release()
def profile_function(
    func: Optional[F] = None,
    *,
    profiler: Optional[TFMemoryProfiler] = None,
    name: Optional[str] = None,
) -> Union[Callable[[F], F], F]:
    """
    Decorator that routes calls to a function through a profiler.

    Works both bare (``@profile_function``) and with keyword arguments
    (``@profile_function(name="step")``).

    Args:
        func: Function to profile; ``None`` when the decorator is used with
            keyword arguments.
        profiler: Profiler instance (the global one is looked up on every
            call if this is None)
        name: Custom name for profiling (defaults to the function's
            ``__name__``)

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator.
    """

    def decorator(f: F) -> F:
        label = name or f.__name__

        # Thin pass-through whose __name__ carries the profiling label; this
        # is the callable actually handed to the profiler.
        @functools.wraps(f)
        def profiled_target(*args: Any, **kwargs: Any) -> Any:
            return f(*args, **kwargs)

        profiled_target.__name__ = label

        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Resolved per call, so a global profiler installed after
            # decoration time is still picked up.
            active_profiler = profiler or get_global_profiler()
            instrumented = active_profiler.profile_function(profiled_target)
            return instrumented(*args, **kwargs)

        return cast(F, wrapper)

    if func is not None:
        return decorator(func)
    return decorator
@contextmanager
def profile_context(
    name: str = "context", profiler: Optional[TFMemoryProfiler] = None
) -> Iterator[None]:
    """
    Context manager that profiles the enclosed code block.

    Args:
        name: Name for the profiling context
        profiler: Profiler instance (uses global if None)

    Yields:
        None; the body of the ``with`` statement runs inside the profiler's
        own ``profile_context(name)``.
    """
    active_profiler = profiler or get_global_profiler()
    scope = active_profiler.profile_context(name)
    with scope:
        yield
class ProfiledLayer:
    """Wrapper for TensorFlow layers with automatic profiling.

    Constructing the wrapper replaces ``layer.call`` in place with a version
    that runs inside a ``layer_<name>`` profiling context, so the layer is
    profiled whether it is invoked through this wrapper or directly.
    Attributes not found on the wrapper are looked up on the wrapped layer.
    """

    def __init__(
        self,
        layer: Any,
        profiler: Optional[TFMemoryProfiler] = None,
        name: Optional[str] = None,
    ) -> None:
        """
        Initialize profiled layer.

        Args:
            layer: TensorFlow layer to profile
            profiler: Profiler instance (uses global if None)
            name: Custom name for profiling (defaults to ``layer.name``, or
                the layer's class name when it has no ``name`` attribute)

        Raises:
            ImportError: If TensorFlow could not be imported.
        """
        if not TF_AVAILABLE:
            raise ImportError("TensorFlow not available")

        self.layer = layer
        self.profiler = profiler or get_global_profiler()
        self.name = name or getattr(layer, "name", layer.__class__.__name__)

        # Wrap the call method: keep the original so the profiled version can
        # forward to it, then patch the layer itself.
        self._original_call = layer.call
        layer.call = self._profiled_call

    def _profiled_call(self, *args: Any, **kwargs: Any) -> Any:
        """Profiled version of layer call."""
        with self.profiler.profile_context(f"layer_{self.name}"):
            return self._original_call(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        """Delegate attribute access to wrapped layer.

        Raises:
            AttributeError: If the wrapper has no ``layer`` yet, or the
                wrapped layer does not have the attribute.
        """
        # __getattr__ only runs after normal lookup failed. Reading
        # ``self.layer`` here when "layer" is not in the instance dict (e.g.
        # an instance created via __new__ by copy/pickle) would re-enter this
        # method forever, so go through __dict__ and fail cleanly instead.
        try:
            wrapped = self.__dict__["layer"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(wrapped, name)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Make the wrapper callable."""
        return self.layer(*args, **kwargs)
def profile_model(model: Any, profiler: Optional[TFMemoryProfiler] = None) -> Any:
    """
    Attach profiling to every layer of a TensorFlow model.

    Each entry of ``model.layers`` is wrapped in a ``ProfiledLayer`` named
    ``<layer.name>_<index>``; the wrappers are not kept, only their effect on
    the layers.

    Args:
        model: TensorFlow model
        profiler: Profiler instance (uses global if None)

    Returns:
        The same model object that was passed in.

    Raises:
        ImportError: If TensorFlow could not be imported.
    """
    if not TF_AVAILABLE:
        raise ImportError("TensorFlow not available")

    active_profiler = profiler or get_global_profiler()

    # Profile each layer
    position = 0
    for layer in model.layers:
        label = f"{layer.name}_{position}"
        ProfiledLayer(layer, active_profiler, label)
        position += 1

    return model
class TensorFlowProfiler:
    """High-level TensorFlow profiling interface.

    Owns a ``TFMemoryProfiler`` and installs it as the module's global
    profiler when constructed.
    """

    def __init__(self, device: Optional[str] = None) -> None:
        """Initialize TensorFlow profiler.

        Args:
            device: Passed through to ``TFMemoryProfiler(device=...)``.
        """
        self.profiler = TFMemoryProfiler(device=device)
        set_global_profiler(self.profiler)

    def profile_training(
        self,
        model: Any,
        dataset: Any,
        epochs: int = 1,
        steps_per_epoch: Optional[int] = None,
    ) -> None:
        """
        Profile model training.

        Contexts are nested as ``training`` > ``epoch_<e>`` > ``step_<s>``.

        Args:
            model: TensorFlow model
            dataset: Training dataset (iterated once per epoch)
            epochs: Number of epochs
            steps_per_epoch: Maximum steps per epoch; ``None`` or ``0`` runs
                through the whole dataset

        Raises:
            ImportError: If TensorFlow could not be imported.
        """
        if not TF_AVAILABLE:
            raise ImportError("TensorFlow not available")

        # A falsy limit means "no limit". Clamp negatives to zero steps,
        # which is what the previous `step_count >= steps_per_epoch` test
        # produced for them.
        step_limit: Optional[int] = None
        if steps_per_epoch:
            step_limit = max(steps_per_epoch, 0)

        # Profile the entire training process
        with self.profiler.profile_context("training"):
            for epoch in range(epochs):
                with self.profiler.profile_context(f"epoch_{epoch}"):
                    # islice stops *before* pulling batch number
                    # step_limit + 1, so no extra element is fetched from the
                    # dataset just to discover the limit was reached.
                    batches = itertools.islice(dataset, step_limit)
                    for step_count, batch in enumerate(batches):
                        with self.profiler.profile_context(f"step_{step_count}"):
                            # Assume the model has a train_step method or similar
                            if hasattr(model, "train_step"):
                                model.train_step(batch)
                            else:
                                self._generic_train_step(model, batch)

    def _generic_train_step(self, model: Any, batch: Any) -> None:
        """Run one forward/backward pass for models without ``train_step``."""
        with tf.GradientTape() as tape:
            if isinstance(batch, tuple):
                x, y = batch
                predictions = model(x, training=True)
                loss = model.compiled_loss(y, predictions)
            else:
                # No separate targets: the batch itself is used as the target.
                predictions = model(batch, training=True)
                loss = model.compiled_loss(batch, predictions)

        gradients = tape.gradient(loss, model.trainable_variables)
        model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    def profile_inference(self, model: Any, data: Any, batch_size: int = 32) -> None:
        """
        Profile model inference.

        Objects exposing a ``batch`` method are batched with
        ``data.batch(batch_size)`` and each batch gets its own
        ``inference_batch_<i>`` context. Anything else is sliced manually
        along its first axis inside an outer ``inference`` context.

        Args:
            model: TensorFlow model
            data: Input data
            batch_size: Batch size for inference

        Raises:
            ImportError: If TensorFlow could not be imported.
        """
        if not TF_AVAILABLE:
            raise ImportError("TensorFlow not available")

        # Batch the data if needed
        if hasattr(data, "batch"):
            batched_data = data.batch(batch_size)
            for i, batch in enumerate(batched_data):
                with self.profiler.profile_context(f"inference_batch_{i}"):
                    model(batch, training=False)
            return

        with self.profiler.profile_context("inference"):
            # Assume data is a tensor or numpy array
            import numpy as np

            if isinstance(data, np.ndarray):
                data = tf.constant(data)

            # Create batches manually. num_samples comes from tf.shape, so
            # the derived values are tensors rather than Python ints.
            # NOTE(review): range() over it presumably relies on eager
            # execution -- confirm.
            num_samples = tf.shape(data)[0]
            num_batches = (num_samples + batch_size - 1) // batch_size

            for i in range(num_batches):
                start_idx = i * batch_size
                end_idx = min((i + 1) * batch_size, num_samples)
                batch = data[start_idx:end_idx]

                with self.profiler.profile_context(f"inference_batch_{i}"):
                    model(batch, training=False)

    def get_results(self) -> Any:
        """Get profiling results from the underlying profiler."""
        return self.profiler.get_results()

    def reset(self) -> None:
        """Reset the underlying profiler's state."""
        self.profiler.reset()
# Convenience functions for common use cases
def profile_keras_training(
    model: Any,
    x_train: Any,
    y_train: Any,
    epochs: int = 1,
    batch_size: int = 32,
    validation_data: Optional[Any] = None,
    profiler: Optional[TFMemoryProfiler] = None,
) -> None:
    """
    Profile Keras model training, batch by batch.

    Everything runs inside a ``keras_training`` context. Each epoch gets an
    ``epoch_<e>`` context holding a ``training_batches`` context (one
    ``train_step`` context per batch) and, when validation data is given, a
    ``validation`` context.

    Args:
        model: Keras model
        x_train: Training data
        y_train: Training labels
        epochs: Number of epochs
        batch_size: Batch size
        validation_data: Validation data tuple (x_val, y_val)
        profiler: Profiler instance (uses global if None)

    Raises:
        ImportError: If TensorFlow could not be imported.
    """
    if not TF_AVAILABLE:
        raise ImportError("TensorFlow not available")

    active_profiler = profiler or get_global_profiler()

    with active_profiler.profile_context("keras_training"):
        # Create dataset
        slices = tf.data.Dataset.from_tensor_slices((x_train, y_train))
        batched = slices.batch(batch_size)

        # Profile training
        for epoch in range(epochs):
            with active_profiler.profile_context(f"epoch_{epoch}"):
                # Training
                with active_profiler.profile_context("training_batches"):
                    for features, labels in batched:
                        with active_profiler.profile_context("train_step"):
                            model.train_on_batch(features, labels)

                # Validation
                if not validation_data:
                    continue
                x_val, y_val = validation_data
                with active_profiler.profile_context("validation"):
                    model.evaluate(x_val, y_val, verbose=0)
def clear_global_profiler() -> None:
    """Reset the global profiler (if any) and drop the module's reference to it."""
    global _global_profiler
    with _profiler_lock:
        active = _global_profiler
        if active:
            active.reset()
        # After this, get_global_profiler() builds a fresh instance.
        _global_profiler = None
def clear_profiles() -> None:
    """Reset the global profiler's data while keeping the instance installed."""
    _profiler_lock.acquire()
    try:
        active = _global_profiler
        if active:
            active.reset()
    finally:
        _profiler_lock.release()
def get_profile_summaries(limit: Optional[int] = None) -> List[Dict[str, Any]]:
    """Return aggregated profiling summaries for recent functions/contexts.

    One dict is produced per entry of the global profiler's
    ``function_profiles``, ordered by the timestamp of its last snapshot,
    newest first (entries without a timestamp sort as ``0.0``).

    Args:
        limit: Maximum number of summaries to return; a falsy value
            (``None`` or ``0``) returns all of them.

    Returns:
        A list of summary dicts, empty when there is no global profiler or it
        has no recorded profiles.
    """
    # Only the reference is read under the lock; the profile data is walked
    # afterwards.
    with _profiler_lock:
        active = _global_profiler

    if not active or not active.function_profiles:
        return []

    def _summarize(profile_name: str, stats: Any) -> Dict[str, Any]:
        history = stats.get("snapshots") or []
        newest = history[-1] if history else None
        return {
            "name": profile_name,
            "calls": stats.get("calls", 0),
            "total_duration": stats.get("total_duration", 0.0),
            "total_memory_used": stats.get("total_memory_used", 0.0),
            "peak_memory": stats.get("peak_memory", 0.0),
            "last_timestamp": getattr(newest, "timestamp", None),
        }

    summaries = sorted(
        (
            _summarize(profile_name, stats)
            for profile_name, stats in active.function_profiles.items()
        ),
        key=lambda summary: summary.get("last_timestamp") or 0.0,
        reverse=True,
    )
    return summaries[:limit] if limit else summaries