stormlog.jax.utils

Utility functions for JAX memory profiling.

This module provides helper functions for JAX device discovery, memory formatting, system information, and environment validation.

Functions

detect_jax_backend()

Return the active JAX backend name.

format_memory(bytes_value)

Format a memory size as a human-readable string.

get_backend_info()

Return backend diagnostics for JAX.

get_device_info([device_index])

Return device kind, platform, and live memory statistics.

get_system_info()

Return a full system and JAX environment report.

jax_is_available()

Return True when JAX is importable.

validate_jax_environment()

Validate the JAX environment for memory profiling.

stormlog.jax.utils.jax_is_available()

Return True when JAX is importable.

Return type:

bool
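The check below is a minimal sketch of how an importability test like this can be written, assuming it only needs to treat ImportError as "not available". The helper names module_is_importable and jax_is_available_sketch are hypothetical and are not part of stormlog.

```python
import importlib


def module_is_importable(name: str) -> bool:
    """Return True when the module `name` can be imported."""
    try:
        importlib.import_module(name)
    except ImportError:  # also covers ModuleNotFoundError
        return False
    return True


def jax_is_available_sketch() -> bool:
    # Hypothetical stand-in for jax_is_available(); the real implementation may differ.
    return module_is_importable("jax")
```

Using importlib.import_module actually runs the import, which confirms the package loads; importlib.util.find_spec would be cheaper but only checks that the module can be located.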

stormlog.jax.utils.detect_jax_backend()

Return the active JAX backend name.

Returns one of ‘gpu’, ‘tpu’, or ‘cpu’. Returns ‘cpu’ as a fallback if JAX is not installed or backend detection fails.

Return type:

str
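A minimal sketch of the documented fallback behaviour, assuming the backend name is read from jax.default_backend(). The name detect_backend_sketch is hypothetical, and mapping any unrecognised platform name to 'cpu' is an assumption rather than confirmed stormlog behaviour.

```python
_KNOWN_BACKENDS = ("gpu", "tpu", "cpu")


def detect_backend_sketch() -> str:
    """Return 'gpu', 'tpu', or 'cpu', falling back to 'cpu' on any failure."""
    try:
        import jax  # deferred so the function still works when JAX is absent
        backend = jax.default_backend()
    except Exception:
        # Covers JAX not being installed (ImportError) and backend initialisation errors.
        return "cpu"
    # Assumption: platforms outside the three documented names are reported as 'cpu'.
    return backend if backend in _KNOWN_BACKENDS else "cpu"
```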

stormlog.jax.utils.get_device_info(device_index=0)

Return device kind, platform, and live memory statistics.

Parameters:

device_index (int) – Index into jax.local_devices() (default 0).

Returns:

Dictionary with keys kind, platform, device_id, process_index, memory_stats (raw dict from device.memory_stats()), and client.

Return type:

Dict[str, Any]
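The sketch below shows how the documented keys could be gathered from a single device object. It assumes the attribute names of JAX device objects (device_kind, platform, id, process_index, client, and the memory_stats() method); describe_device is a hypothetical helper, and the stand-in device uses made-up values so the example runs without JAX installed.

```python
from types import SimpleNamespace
from typing import Any, Dict


def describe_device(device: Any) -> Dict[str, Any]:
    """Collect the get_device_info()-style fields from one device object (sketch)."""
    return {
        "kind": device.device_kind,
        "platform": device.platform,
        "device_id": device.id,
        "process_index": device.process_index,
        # memory_stats() may return None on backends without allocator statistics.
        "memory_stats": device.memory_stats(),
        "client": device.client,
    }


# Stand-in object with made-up values, so the sketch runs without JAX.
fake_device = SimpleNamespace(
    device_kind="cpu",
    platform="cpu",
    id=0,
    process_index=0,
    client=None,
    memory_stats=lambda: {"bytes_in_use": 0},
)
info = describe_device(fake_device)
```

With JAX installed, describe_device(jax.local_devices()[0]) corresponds to the default device_index=0.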

stormlog.jax.utils.get_backend_info()

Return backend diagnostics for JAX.

Returns a dictionary with the JAX runtime backend classification and platform details.

Return type:

Dict[str, Any]

stormlog.jax.utils.get_system_info()

Return a full system and JAX environment report.

Includes JAX version, device list, platform, Python version, CPU count, and system memory statistics.

Return type:

Dict[str, Any]
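A rough sketch of assembling a report with these categories from the standard library, importing JAX and psutil only if they are present. The key names, the use of psutil for system memory, and system_report_sketch itself are assumptions for illustration; the structure of the real report may differ.

```python
import os
import platform
from typing import Any, Dict


def system_report_sketch() -> Dict[str, Any]:
    """Collect system and JAX details; every key name here is hypothetical."""
    report: Dict[str, Any] = {
        "python_version": platform.python_version(),
        "os_platform": platform.platform(),
        "cpu_count": os.cpu_count(),
        "jax_version": None,
        "devices": [],
        "system_memory": None,
    }
    try:
        import jax
        report["jax_version"] = jax.__version__
        report["devices"] = [str(d) for d in jax.devices()]
    except Exception:
        pass  # JAX missing or no usable backend: keep the defaults
    try:
        import psutil  # optional third-party dependency (assumption)
        vm = psutil.virtual_memory()
        report["system_memory"] = {"total": vm.total, "available": vm.available}
    except ImportError:
        pass
    return report
```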

stormlog.jax.utils.format_memory(bytes_value)

Format a memory size as a human-readable string.

Delegates to stormlog.utils.format_bytes() when it is available; otherwise falls back to a standalone implementation.

Parameters:

bytes_value (int | float | None) – Memory size in bytes to format.

Return type:

str
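For reference, a standalone formatter of the kind the fallback might provide could look like the sketch below. It assumes binary (1024-based) units, two decimal places, and an "N/A" string for None; format_memory_sketch is hypothetical, and stormlog's actual unit labels and precision may differ.

```python
from typing import Optional, Union

_UNITS = ("B", "KB", "MB", "GB", "TB", "PB")


def format_memory_sketch(bytes_value: Optional[Union[int, float]]) -> str:
    """Format a byte count using 1024-based units (sketch)."""
    if bytes_value is None:
        return "N/A"  # hypothetical placeholder for a missing value
    value = float(bytes_value)
    # Divide by 1024 until the value fits the current unit; PB is the ceiling.
    for unit in _UNITS[:-1]:
        if abs(value) < 1024.0:
            return f"{value:.2f} {unit}"
        value /= 1024.0
    return f"{value:.2f} {_UNITS[-1]}"
```

For example, format_memory_sketch(1536) returns '1.50 KB' and format_memory_sketch(3 * 1024**3) returns '3.00 GB'.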

stormlog.jax.utils.validate_jax_environment()

Validate the JAX environment for memory profiling.

Returns a dictionary with validation results and a list of any issues found.

Return type:

Dict[str, Any]
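A minimal sketch of a validation routine that returns a result flag plus a list of issues. The key names 'valid' and 'issues', the two checks performed, and validate_environment_sketch are assumptions; the checks stormlog actually runs are not documented here.

```python
from typing import Any, Dict, List


def validate_environment_sketch() -> Dict[str, Any]:
    """Return {'valid': bool, 'issues': [...]}; key names are hypothetical."""
    issues: List[str] = []
    try:
        import jax
    except ImportError:
        issues.append("JAX is not installed")
        return {"valid": False, "issues": issues}

    try:
        devices = jax.local_devices()
    except Exception as exc:  # backend initialisation can fail at runtime
        issues.append(f"JAX device discovery failed: {exc}")
        devices = []

    if not devices:
        issues.append("No local JAX devices were found")

    return {"valid": not issues, "issues": issues}
```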