File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -213,8 +213,11 @@ def load_state_if_possible(
213213 max_logging .log (f"restoring from this run's directory latest step { latest_step } " )
214214 try :
215215 if not enable_single_replica_ckpt_restoring :
216- item = {checkpoint_item : orbax .checkpoint .args .PyTreeRestore (item = abstract_unboxed_pre_state )}
217- return checkpoint_manager .restore (latest_step , args = orbax .checkpoint .args .Composite (** item ))
216+ if checkpoint_item == " " :
217+ return checkpoint_manager .restore (latest_step , args = ocp .args .StandardRestore (abstract_unboxed_pre_state ))
218+ else :
219+ item = {checkpoint_item : orbax .checkpoint .args .PyTreeRestore (item = abstract_unboxed_pre_state )}
220+ return checkpoint_manager .restore (latest_step , args = orbax .checkpoint .args .Composite (** item ))
218221
219222 def map_to_pspec (data ):
220223 pspec = data .sharding .spec
Original file line number Diff line number Diff line change @@ -402,7 +402,10 @@ def setup_initial_state(
402402 config .enable_single_replica_ckpt_restoring ,
403403 )
404404 if state :
405- state = state [checkpoint_item ]
405+ if checkpoint_item == " " :
406+ state = state
407+ else :
408+ state = state [checkpoint_item ]
406409 if not state :
407410 max_logging .log (f"Could not find the item in orbax, creating state..." )
408411 init_train_state_partial = functools .partial (
You can’t perform that action at this time.
0 commit comments