Source code for stormlog.jax.jax_env
"""JAX environment configuration for Stormlog.
Suppresses verbose JAX/XLA logging and configures the JAX runtime
environment before any ``import jax`` call. Every module in the
``stormlog.jax`` package should call :func:`configure_jax_logging`
at import time, **before** importing ``jax`` itself.
"""
from __future__ import annotations
import os
# Module-level latch: flipped to True by the first call to
# configure_jax_logging() so that the repeated calls made by every
# ``stormlog.jax`` submodule at import time return immediately.
_CONFIGURED: bool = False


def configure_jax_logging() -> None:
    """Quiet JAX/XLA logging by seeding environment variables.

    Idempotent — safe to call multiple times. Only the first call does any
    work. After that the module-level latch short-circuits, so variables
    removed from the environment later are *not* re-seeded.

    Values are written with :meth:`os.environ.setdefault`, so anything the
    user (or the launching process) has already exported wins and is never
    overridden. JAX and XLA inspect these variables on first import. This
    function must therefore run before ``import jax`` for it to have any
    effect.

    Variables seeded (only when not already present):

    * ``JAX_LOG_COMPILES`` → ``"0"``: keep JAX's per-JIT-compilation log
      messages switched off.
    * ``TF_CPP_MIN_LOG_LEVEL`` → ``"2"``: filter both INFO *and* WARNING
      messages from the TensorFlow-style (TSL) C++ logging used by the
      XLA/jaxlib native runtime. ERROR and FATAL messages are still emitted.
    """
    global _CONFIGURED
    if _CONFIGURED:
        return
    # setdefault: respect any explicit user configuration already present.
    os.environ.setdefault("JAX_LOG_COMPILES", "0")
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
    _CONFIGURED = True