We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 36242d2 + e805034 commit 8fc3626Copy full SHA for 8fc3626
1 file changed
src/maxdiffusion/max_utils.py
@@ -402,7 +402,10 @@ def setup_initial_state(
402
config.enable_single_replica_ckpt_restoring,
403
)
404
if state:
405
- state = state[checkpoint_item]
+ if checkpoint_item == " ":
406
+ state = state
407
+ else:
408
+ state = state[checkpoint_item]
409
if not state:
410
max_logging.log(f"Could not find the item in orbax, creating state...")
411
init_train_state_partial = functools.partial(
@@ -609,4 +612,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
609
612
initialize_jax_for_gpu()
610
613
max_logging.log("Jax distributed system initialized on GPU!")
611
614
else:
- jax.distributed.initialize()
615
+ jax.distributed.initialize()
0 commit comments