Skip to content

Commit bb61ecb

Browse files
committed
functional
1 parent f5afa91 commit bb61ecb

2 files changed

Lines changed: 9 additions & 3 deletions

File tree

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,11 @@ def load_state_if_possible(
213213
max_logging.log(f"restoring from this run's directory latest step {latest_step}")
214214
try:
215215
if not enable_single_replica_ckpt_restoring:
216-
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
217-
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
216+
if checkpoint_item == " ":
217+
return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
218+
else:
219+
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
220+
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
218221

219222
def map_to_pspec(data):
220223
pspec = data.sharding.spec

src/maxdiffusion/max_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,10 @@ def setup_initial_state(
402402
config.enable_single_replica_ckpt_restoring,
403403
)
404404
if state:
405-
state = state[checkpoint_item]
405+
if checkpoint_item == " ":
406+
state = state
407+
else:
408+
state = state[checkpoint_item]
406409
if not state:
407410
max_logging.log(f"Could not find the item in orbax, creating state...")
408411
init_train_state_partial = functools.partial(

0 commit comments

Comments
 (0)