stormlog.jax.analyzer
JAX Memory Analysis.
Advanced analysis tools for JAX memory profiling data. Provides feature
parity with the PyTorch and TensorFlow analyzer modules while adapting
recommendations and heuristics to JAX-specific idioms (XLA compilation,
jax.jit, device memory pools, etc.).
Classes
- MemoryAnalyzer: Advanced analyzer for JAX memory profiling data.
- class stormlog.jax.analyzer.MemoryAnalyzer(sensitivity=0.05, collective_sensitivity='medium', collective_threshold_overrides=None)
Bases: object
Advanced analyzer for JAX memory profiling data.
Mirrors the public interface of the PyTorch and TensorFlow MemoryAnalyzer classes, adapting heuristics and recommendations to JAX’s XLA runtime and device memory model.
- Parameters:
sensitivity (float)
collective_sensitivity (str)
collective_threshold_overrides (Optional[Mapping[str, Any]])
- detect_memory_leaks(results)
Detect potential memory leaks in JAX telemetry.
Uses linear regression over the memory-usage series to detect sustained upward drift.
- Parameters:
results (Any) – Object with a memory_usage attribute (numeric sequence of at least 10 samples).
- Returns:
List of leak-detection dicts with type, severity, description, and slope keys.
- Return type:
List[Dict[str, Any]]
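The regression step can be illustrated with a minimal, self-contained sketch that does not import stormlog. The function names, the "memory_leak" type label, and the rule that compares relative growth per sample against sensitivity are all hypothetical; this is not the library's implementation.

```python
from typing import Any, Dict, List, Sequence


def linear_slope(samples: Sequence[float]) -> float:
    """Ordinary least-squares slope of the samples against their index."""
    n = len(samples)
    mean_x = (n - 1) / 2
    mean_y = sum(samples) / n
    cov = sum((i - mean_x) * (y - mean_y) for i, y in enumerate(samples))
    var = sum((i - mean_x) ** 2 for i in range(n))
    return cov / var


def sketch_detect_leaks(
    memory_usage: Sequence[float], sensitivity: float = 0.05
) -> List[Dict[str, Any]]:
    # The documented method requires at least 10 samples.
    if len(memory_usage) < 10:
        return []
    slope = linear_slope(memory_usage)
    mean_usage = sum(memory_usage) / len(memory_usage)
    if mean_usage <= 0:
        return []
    # Hypothetical rule: compare relative growth per sample with `sensitivity`.
    relative = slope / mean_usage
    if relative <= sensitivity:
        return []
    return [{
        "type": "memory_leak",  # hypothetical label
        "severity": "high" if relative > 2 * sensitivity else "medium",
        "description": f"Usage grows by about {slope:.2f} units per sample",
        "slope": slope,
    }]
```

A steadily rising series such as `[100 + 10 * i for i in range(12)]` yields one finding with a slope of about 10, while a flat series yields none.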
- detect_patterns(results)
Detect allocation patterns in JAX telemetry.
Performs simplified autocorrelation analysis to identify periodic memory usage behaviour (e.g. per-step allocations in a training loop).
- Parameters:
results (Any) – Object with a memory_usage attribute (numeric sequence of at least 10 samples).
- Returns:
List of detected pattern dicts.
- Return type:
List[Dict[str, Any]]
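A simplified autocorrelation search of the kind described above might look like the following sketch. It picks the lag with the strongest positive autocorrelation as the period. The threshold of 0.5, the lag limit of half the series, and the keys "type", "period", and "strength" are assumptions for illustration, not stormlog's actual output format.

```python
from typing import Any, Dict, List, Sequence


def autocorrelation(series: Sequence[float], lag: int) -> float:
    """Autocorrelation at `lag`, normalised by the full-series variance."""
    n = len(series)
    mean = sum(series) / n
    var = sum((x - mean) ** 2 for x in series)
    if var == 0:
        return 0.0
    cov = sum((series[t] - mean) * (series[t + lag] - mean) for t in range(n - lag))
    return cov / var


def sketch_detect_patterns(
    memory_usage: Sequence[float], threshold: float = 0.5
) -> List[Dict[str, Any]]:
    if len(memory_usage) < 10:
        return []
    # Test lags up to half the series so each lag keeps enough overlap.
    best_lag, best_r = 0, 0.0
    for lag in range(1, len(memory_usage) // 2 + 1):
        r = autocorrelation(memory_usage, lag)
        if r > best_r:
            best_lag, best_r = lag, r
    if best_r < threshold:
        return []
    return [{"type": "periodic_allocation", "period": best_lag, "strength": best_r}]
```

For a training-loop-like series `[100, 200, 300, 200] * 5`, the strongest lag is 4 (one step), with a strength of 0.8.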
- analyze_fragmentation(profile_result)
Analyse memory fragmentation patterns.
Computes fragmentation as 1 − (used / reserved) across profiling snapshots.
- Parameters:
profile_result (Any) – Profiling result with a snapshots attribute where each snapshot exposes device_memory_mb and device_memory_reserved_mb.
- Returns:
Dictionary with fragmentation_score, trend, max_fragmentation, and min_fragmentation.
- Return type:
Dict[str, float]
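The per-snapshot formula is straightforward to reproduce. The sketch below applies 1 − used / reserved to each snapshot, using the two documented attribute names. How stormlog aggregates these values is not documented, so treating fragmentation_score as the mean ratio and trend as the last ratio minus the first is an assumption, as is the Snapshot dataclass.

```python
from dataclasses import dataclass
from typing import Dict, Sequence


@dataclass
class Snapshot:
    # Hypothetical container; the field names follow the documented attributes.
    device_memory_mb: float
    device_memory_reserved_mb: float


def sketch_fragmentation(snapshots: Sequence[Snapshot]) -> Dict[str, float]:
    ratios = [
        1.0 - s.device_memory_mb / s.device_memory_reserved_mb
        for s in snapshots
        if s.device_memory_reserved_mb > 0  # skip snapshots with nothing reserved
    ]
    if not ratios:
        return {"fragmentation_score": 0.0, "trend": 0.0,
                "max_fragmentation": 0.0, "min_fragmentation": 0.0}
    return {
        # Assumption: score is the mean ratio; trend is last minus first.
        "fragmentation_score": sum(ratios) / len(ratios),
        "trend": ratios[-1] - ratios[0],
        "max_fragmentation": max(ratios),
        "min_fragmentation": min(ratios),
    }
```

With 800, 600, and 500 MB used out of 1000 MB reserved, the ratios are 0.2, 0.4, and 0.5, so the positive trend signals fragmentation getting worse over the run.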
- analyze_efficiency(profile_result)
Analyse memory usage efficiency.
Returns a score on a 0.0–1.0 scale (1.0 = excellent). The score starts at 1.0 and is reduced by penalties for high peak memory, high growth rate, fragmentation, and detected leaks.
- Parameters:
profile_result (Any) – Profiling result with peak_memory_mb and, optionally, memory_growth_rate, snapshots, and memory_usage.
- Returns:
Efficiency score in [0.0, 1.0].
- Return type:
float
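The start-at-1.0-and-subtract scheme can be sketched as follows. Every threshold and penalty size here (the 80% budget check, the 0.1 growth cutoff, the 0.3 fragmentation weight, the 0.1 per leak, and the hypothetical peak_budget_mb parameter) is invented for illustration. Only the overall shape, penalties for peak, growth, fragmentation, and leaks clamped to [0.0, 1.0], comes from the description above.

```python
from typing import Optional


def sketch_efficiency(
    peak_memory_mb: float,
    memory_growth_rate: Optional[float] = None,
    fragmentation: Optional[float] = None,
    leak_count: int = 0,
    peak_budget_mb: float = 8192.0,  # hypothetical device budget
) -> float:
    score = 1.0
    # All thresholds and penalty sizes below are hypothetical choices.
    if peak_memory_mb > 0.8 * peak_budget_mb:
        score -= 0.2
    if memory_growth_rate is not None and memory_growth_rate > 0.1:
        score -= 0.2
    if fragmentation is not None:
        score -= 0.3 * fragmentation
    score -= 0.1 * leak_count
    # Clamp to the documented [0.0, 1.0] range.
    return max(0.0, min(1.0, score))
```

Clamping at the end keeps the score in range even when several penalties stack, for example a run with high fragmentation and many leaks.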
- correlate_with_performance(profile_result)
Correlate memory usage with performance metrics.
Analyses per-function efficiency based on memory consumption and execution duration.
- Parameters:
profile_result (Any) – Profiling result with function_profiles mapping function names to dicts containing calls, total_memory_delta, and total_duration.
- Returns:
Dictionary with memory_duration_correlation, function_efficiency, and recommendations.
- Return type:
Dict[str, Any]
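The sketch below is one plausible reading of the per-function analysis. It normalises the documented calls, total_memory_delta, and total_duration fields to per-call values and computes a Pearson correlation between memory and duration. The use of Pearson correlation, the memory_per_call/duration_per_call keys, and the recommendation rule are assumptions, not confirmed stormlog behaviour.

```python
import math
from typing import Any, Dict, List, Mapping


def pearson(xs: List[float], ys: List[float]) -> float:
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy) if sx and sy else 0.0


def sketch_correlate(profiles: Mapping[str, Mapping[str, float]]) -> Dict[str, Any]:
    efficiency: Dict[str, Dict[str, float]] = {}
    for name, p in profiles.items():
        calls = max(p["calls"], 1)  # avoid dividing by zero
        efficiency[name] = {
            "memory_per_call": p["total_memory_delta"] / calls,
            "duration_per_call": p["total_duration"] / calls,
        }
    mem = [e["memory_per_call"] for e in efficiency.values()]
    dur = [e["duration_per_call"] for e in efficiency.values()]
    corr = pearson(mem, dur) if len(mem) >= 2 else 0.0
    # Hypothetical rule: flag the function with the largest memory per call.
    recommendations: List[str] = []
    if efficiency:
        worst = max(efficiency, key=lambda k: efficiency[k]["memory_per_call"])
        recommendations.append(f"Review memory use in {worst}")
    return {
        "memory_duration_correlation": corr,
        "function_efficiency": efficiency,
        "recommendations": recommendations,
    }
```

Normalising per call first keeps a frequently called, cheap function from looking worse than a rarely called, expensive one.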
- score_optimization(profile_result, events=None)
Generate an overall optimisation score with recommendations.
Combines memory efficiency, fragmentation, and per-function performance scores into a single summary.
- Parameters:
profile_result (Any) – JAX profiling result object.
events (List | None) – Optional telemetry event series for gap analysis. When provided, the result includes gap_analysis and collective_attribution sections.
- Returns:
Dictionary with overall_score, categories, top_recommendations, and priority_actions.
- Return type:
Dict[str, Any]
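Combining category scores into one summary could look like the following sketch. The equal weighting, the inversion of fragmentation so that higher is better, the top-3 cut, and the "below 0.5 becomes a priority action" rule are all hypothetical, since the real weights are not documented. Only the output keys match the documented return value.

```python
from typing import Any, Dict, List


def sketch_overall_score(
    efficiency: float,
    fragmentation: float,
    performance: float,
    recommendations: List[str],
) -> Dict[str, Any]:
    # Fragmentation is "higher is worse", so invert it to match the others.
    categories = {
        "memory_efficiency": efficiency,
        "fragmentation": 1.0 - fragmentation,
        "performance": performance,
    }
    # Hypothetical equal weighting of the three categories.
    overall = sum(categories.values()) / len(categories)
    # Hypothetical rule: any category below 0.5 becomes a priority action.
    priority = [f"Improve {name}" for name, s in categories.items() if s < 0.5]
    return {
        "overall_score": overall,
        "categories": categories,
        "top_recommendations": recommendations[:3],
        "priority_actions": priority,
    }
```

Inverting fragmentation before averaging keeps every category on the same "1.0 is best" scale as the efficiency score.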
- analyze_memory_gaps(events, *, phase_resolver=None)
Classify allocator-vs-device hidden memory gaps over time.
- Parameters:
events (List) – Chronologically ordered telemetry samples.
phase_resolver (Any | None) – Optional PhaseReplayIndex for phase attribution.
- Returns:
Prioritised list of gap findings, sorted by severity (descending) and then confidence (descending). Returns an empty list when the gap_analysis sub-package is not available.
- Return type:
List
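A minimal sketch of gap classification treats the "hidden" gap as device-level usage minus allocator-reserved memory. It groups contiguous samples above a threshold into findings and applies the documented ordering (severity descending, then confidence descending). The Sample fields, the 256 MB threshold, the severity cutoff, and the run-length confidence rule are hypothetical and do not reflect stormlog's telemetry schema.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple

SEVERITY_RANK = {"high": 2, "medium": 1, "low": 0}


@dataclass
class Sample:
    # Hypothetical event fields; the real telemetry schema may differ.
    timestamp: float
    allocator_reserved_mb: float
    device_used_mb: float


def _finding(run: List[Tuple[float, float]], threshold_mb: float) -> Dict[str, Any]:
    peak = max(gap for _, gap in run)
    return {
        "start": run[0][0],
        "end": run[-1][0],
        "peak_gap_mb": peak,
        "severity": "high" if peak > 4 * threshold_mb else "medium",
        # Hypothetical: longer runs are treated as more trustworthy.
        "confidence": min(1.0, len(run) / 5),
    }


def sketch_gap_findings(
    events: Sequence[Sample], threshold_mb: float = 256.0
) -> List[Dict[str, Any]]:
    findings: List[Dict[str, Any]] = []
    run: List[Tuple[float, float]] = []
    for e in events:
        gap = e.device_used_mb - e.allocator_reserved_mb
        if gap > threshold_mb:
            run.append((e.timestamp, gap))
        elif run:
            findings.append(_finding(run, threshold_mb))
            run = []
    if run:  # close a run that lasts until the final sample
        findings.append(_finding(run, threshold_mb))
    # Documented ordering: severity descending, then confidence descending.
    findings.sort(key=lambda f: (-SEVERITY_RANK[f["severity"]], -f["confidence"]))
    return findings
```

Because the sort key negates both ranks, a short but severe spike is reported ahead of a longer, milder one.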
- analyze_collective_attribution(events, *, phase_resolver=None)
Attribute hidden-memory spikes to collective communication phases.
- Parameters:
events (List) – Chronologically ordered telemetry samples.
phase_resolver (Any | None) – Optional PhaseReplayIndex for phase attribution.
- Returns:
List of CollectiveAttributionResult objects. Returns an empty list when the collective_attribution sub-package is not available.
- Return type:
List
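At its simplest, attributing spikes to collective phases means checking which phase interval contains each spike's timestamp. In the sketch below, phases are plain (name, start, end) tuples rather than a PhaseReplayIndex, and the Attribution dataclass is a hypothetical stand-in for CollectiveAttributionResult. Both are assumptions made so the example runs on its own.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple


@dataclass
class Attribution:
    # Hypothetical stand-in for CollectiveAttributionResult.
    timestamp: float
    hidden_mb: float
    phase: Optional[str]  # None when no collective phase covers the spike


def sketch_attribute(
    spikes: Sequence[Tuple[float, float]],       # (timestamp, hidden_mb)
    phases: Sequence[Tuple[str, float, float]],  # (name, start, end)
) -> List[Attribution]:
    results = []
    for ts, hidden in spikes:
        # First phase whose closed interval contains the spike timestamp.
        phase = next((name for name, start, end in phases if start <= ts <= end), None)
        results.append(Attribution(ts, hidden, phase))
    return results
```

The linear scan is fine for a handful of phases. A long trace would favour a sorted index with binary search, which is presumably the kind of lookup a phase resolver provides.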