stormlog.jax.visualizer

JAX Memory Visualization

Classes

MemoryVisualizer([style, figure_size]): JAX memory visualization and dashboards.

class stormlog.jax.visualizer.MemoryVisualizer(style='default', figure_size=(12, 8))

Bases: object

Memory visualizations and dashboards for JAX profiling results, including device memory timelines and per-function comparisons.

Parameters:
  • style (str)

  • figure_size (Tuple[int, int])
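A minimal sketch of constructing the visualizer with the documented keyword arguments. It assumes stormlog may be an optional dependency in the calling project, so the hypothetical helper `make_visualizer` returns `None` instead of raising when `stormlog.jax.visualizer` (or one of its own dependencies) cannot be imported.

```python
from typing import Any, Optional, Tuple


def make_visualizer(
    style: str = "default",
    figure_size: Tuple[int, int] = (12, 8),
) -> Optional[Any]:
    """Hypothetical helper: build a MemoryVisualizer, or return None if unavailable."""
    try:
        # Imported lazily so callers without stormlog/JAX installed can still run.
        from stormlog.jax.visualizer import MemoryVisualizer
    except ImportError:
        return None
    # Keyword names and defaults mirror the documented constructor signature.
    return MemoryVisualizer(style=style, figure_size=figure_size)
```

For example, `make_visualizer(figure_size=(16, 9))` requests a figure size other than the default `(12, 8)`.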

plot_memory_timeline(results, interactive=False, save_path=None)

Plot device memory usage over time.

Parameters:
  • results (Any)

  • interactive (bool)

  • save_path (str | None)

Return type:

None
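A sketch of saving a timeline to disk instead of only displaying it. The helper name `save_timeline`, the output file name, and the `.png` extension are assumptions: this page does not state which file formats `save_path` accepts or what kind of object `results` must be, so `results` is passed through unchanged.

```python
from pathlib import Path
from typing import Any


def save_timeline(
    visualizer: Any,
    results: Any,
    out_dir: str,
    filename: str = "memory_timeline.png",  # hypothetical name; supported formats not documented
) -> Path:
    """Hypothetical helper: write a device memory timeline under out_dir."""
    target = Path(out_dir) / filename
    # Make sure the destination directory exists before the plot is saved.
    target.parent.mkdir(parents=True, exist_ok=True)
    # interactive=False is the documented default; it is passed explicitly for clarity.
    visualizer.plot_memory_timeline(results, interactive=False, save_path=str(target))
    return target
```

Here `visualizer` is expected to be a `MemoryVisualizer`, but any object with a compatible `plot_memory_timeline` method works, which also makes the helper easy to exercise with a stub in tests.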

plot_function_comparison(function_profiles, save_path=None)

Plot a comparison of memory usage across profiled functions or contexts.

Parameters:
  • function_profiles (Dict[str, Dict[str, Any]])

  • save_path (str | None)

Return type:

None
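Judging by its type hint, `function_profiles` presumably maps each function or context name to a dictionary of that entry's profiling data. The inner keys the method reads are not documented on this page, so the hypothetical helper `compare_functions` below only checks the outer shape given by the type hint before delegating to `plot_function_comparison`.

```python
from typing import Any, Dict, Optional


def compare_functions(
    visualizer: Any,
    function_profiles: Dict[str, Dict[str, Any]],
    save_path: Optional[str] = None,
) -> None:
    """Hypothetical helper: validate the documented outer shape, then plot."""
    for name, profile in function_profiles.items():
        # Outer keys should be names (str); values should be per-entry dicts.
        if not isinstance(name, str):
            raise TypeError(f"profile key {name!r} is not a str")
        if not isinstance(profile, dict):
            raise TypeError(f"profile for {name!r} is not a dict")
    visualizer.plot_function_comparison(function_profiles, save_path=save_path)
```

Failing early with a `TypeError` gives a clearer message than an error raised deep inside the plotting code; the metric names placed in each inner dictionary still have to match whatever the visualizer expects.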