File tree Expand file tree Collapse file tree
src/maxtext/checkpoint_conversion Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -200,12 +200,6 @@ def main(argv: Sequence[str]) -> None:
200200 Args:
201201 argv: Command-line arguments, which are parsed by `pyconfig`.
202202 """
203- jax .config .update ("jax_default_prng_impl" , "unsafe_rbg" )
204- os .environ ["TF_CPP_MIN_LOG_LEVEL" ] = "0"
205-
206- jax .config .update ("jax_platforms" , "cpu" )
207- os .environ ["XLA_FLAGS" ] = "--xla_force_host_platform_device_count=16"
208-
209203 # Initialize maxtext config
210204 config = pyconfig .initialize (argv )
211205 assert (
@@ -298,4 +292,10 @@ def main(argv: Sequence[str]) -> None:
298292
299293
300294if __name__ == "__main__" :
295+ jax .config .update ("jax_default_prng_impl" , "unsafe_rbg" )
296+ os .environ ["TF_CPP_MIN_LOG_LEVEL" ] = "0"
297+
298+ jax .config .update ("jax_platforms" , "cpu" )
299+ os .environ ["XLA_FLAGS" ] = "--xla_force_host_platform_device_count=1"
300+
301301 app .run (main )
You can’t perform that action at this time.
0 commit comments