Skip to content

Commit c6f3bc2

Browse files
Merge pull request #3192 from kryvokhyzha:fix/to-hf-oom-error
PiperOrigin-RevId: 872564224
2 parents 3fbe3be + 389e383 commit c6f3bc2

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

src/maxtext/checkpoint_conversion/to_huggingface.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,6 @@ def main(argv: Sequence[str]) -> None:
200200
Args:
201201
argv: Command-line arguments, which are parsed by `pyconfig`.
202202
"""
203-
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
204-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
205-
206-
jax.config.update("jax_platforms", "cpu")
207-
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"
208-
209203
# Initialize maxtext config
210204
config = pyconfig.initialize(argv)
211205
assert (
@@ -298,4 +292,10 @@ def main(argv: Sequence[str]) -> None:
298292

299293

300294
if __name__ == "__main__":
295+
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
296+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
297+
298+
jax.config.update("jax_platforms", "cpu")
299+
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=1"
300+
301301
app.run(main)

0 commit comments

Comments
 (0)