Skip to content

Commit cdec380

Browse files
committed
Set TF_CPP_MIN_LOG_LEVEL=0 in __init__.py instead of train.py. Setting this env variable after import jax is called has no effect, because the C++ logging system is already initialized.
1 parent 496ed40 commit cdec380

2 files changed

Lines changed: 6 additions & 1 deletion

File tree

src/MaxText/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727

2828
from collections.abc import Sequence
2929

30+
import os
31+
# In order to have any effect on the C++ logging this has to be set before we import anything from jax.
32+
# When jax is imported, its `__init__.py` calls `cloud_tpu_init()`, which also initializes the C++ logger.
33+
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "0")
34+
del os
35+
3036
from jax.sharding import Mesh
3137

3238
from MaxText import pyconfig

src/MaxText/train.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,6 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any]
523523
# TF allocates extraneous GPU memory when using TFDS data
524524
# this leads to CUDA OOMs. WAR for now is to hide GPUs from TF
525525
tf.config.set_visible_devices([], "GPU")
526-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
527526
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
528527
os.environ["LIBTPU_INIT_ARGS"] = (
529528
os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"

0 commit comments

Comments
 (0)