Skip to content

Commit 389e383

Browse files
committed
fix: resolve OOM in to_huggingface.py by setting simulated_cpu_devices_count to 1
1 parent f70f5c8 commit 389e383

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

src/maxtext/checkpoint_conversion/to_huggingface.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,6 @@ def main(argv: Sequence[str]) -> None:
200200
Args:
201201
argv: Command-line arguments, which are parsed by `pyconfig`.
202202
"""
203-
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
204-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
205-
206-
jax.config.update("jax_platforms", "cpu")
207-
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"
208-
209203
# Initialize maxtext config
210204
config = pyconfig.initialize(argv)
211205
assert (
@@ -298,4 +292,10 @@ def main(argv: Sequence[str]) -> None:
298292

299293

300294
if __name__ == "__main__":
295+
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
296+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
297+
298+
jax.config.update("jax_platforms", "cpu")
299+
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=1"
300+
301301
app.run(main)

0 commit comments

Comments
 (0)