Skip to content

Commit 8fc3626

Browse files
committed
Merge branch 'conversion-script' of https://github.com/AI-Hypercomputer/maxdiffusion into conversion-script
2 parents 36242d2 + e805034 commit 8fc3626

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

src/maxdiffusion/max_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,10 @@ def setup_initial_state(
402402
config.enable_single_replica_ckpt_restoring,
403403
)
404404
if state:
405-
state = state[checkpoint_item]
405+
if checkpoint_item == " ":
406+
state = state
407+
else:
408+
state = state[checkpoint_item]
406409
if not state:
407410
max_logging.log(f"Could not find the item in orbax, creating state...")
408411
init_train_state_partial = functools.partial(
@@ -609,4 +612,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
609612
initialize_jax_for_gpu()
610613
max_logging.log("Jax distributed system initialized on GPU!")
611614
else:
612-
jax.distributed.initialize()
615+
jax.distributed.initialize()

0 commit comments

Comments
 (0)