Skip to content

Commit a3c3fd1

Browse files
Merge pull request #3131 from AI-Hypercomputer:igorts-dev
PiperOrigin-RevId: 871361489
2 parents 38ed3ff + cdec380 commit a3c3fd1

2 files changed

Lines changed: 6 additions & 1 deletion

File tree

src/MaxText/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727

2828
from collections.abc import Sequence
2929

30+
import os
31+
# In order to have any effect on the C++ logging this has to be set before we import anything from jax.
32+
# When jax is imported, its `__init__.py` calls `cloud_tpu_init()`, which also initializes the C++ logger.
33+
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "0")
34+
del os
35+
3036
from jax.sharding import Mesh
3137

3238
from MaxText import pyconfig

src/MaxText/train.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,6 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any]
538538
# TF allocates extraneous GPU memory when using TFDS data
539539
# this leads to CUDA OOMs. WAR for now is to hide GPUs from TF
540540
tf.config.set_visible_devices([], "GPU")
541-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
542541
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
543542
os.environ["LIBTPU_INIT_ARGS"] = (
544543
os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"

0 commit comments

Comments
 (0)