@@ -143,6 +143,9 @@ def _load_full_state_from_path(
143143 enable_orbax_v1 ,
144144 checkpoint_conversion_fn ,
145145 source_checkpoint_layout ,
146+ checkpoint_storage_concurrent_gb ,
147+ use_ocdbt ,
148+ use_zarr3 ,
146149):
147150 """Load full state from checkpoint at specified path.
148151
@@ -155,6 +158,9 @@ def _load_full_state_from_path(
155158 maxtext-supported state.
156159 source_checkpoint_layout: String representation of the checkpoint layout of
157160 the source checkpoint.
161+ checkpoint_storage_concurrent_gb: Limit, in GB, on concurrent checkpoint byte I/O; passed to the Orbax handler as both restore_concurrent_gb and save_concurrent_gb.
162+ use_ocdbt: Whether to use OCDBT format.
163+ use_zarr3: Whether to use Zarr3 format.
158164
159165 Returns:
160166 The loaded state.
@@ -184,7 +190,13 @@ def combine_sharding(sds, shardings):
184190 else :
185191 # Original v0 logic.
186192 p = epath .Path (path )
187- return ocp .StandardCheckpointer ().restore (p , abstract_unboxed_pre_state )
193+ handler = ocp .PyTreeCheckpointHandler (
194+ restore_concurrent_gb = checkpoint_storage_concurrent_gb ,
195+ save_concurrent_gb = checkpoint_storage_concurrent_gb ,
196+ use_ocdbt = use_ocdbt ,
197+ use_zarr3 = use_zarr3 ,
198+ )
199+ return ocp .Checkpointer (handler ).restore (p , abstract_unboxed_pre_state )
188200
189201
190202def create_orbax_checkpoint_manager (
@@ -198,6 +210,7 @@ def create_orbax_checkpoint_manager(
198210 use_zarr3 : bool = True ,
199211 enable_continuous_checkpointing : bool = False ,
200212 max_num_checkpoints_to_keep : int = 10 ,
213+ checkpoint_storage_concurrent_gb : int = 96 ,
201214):
202215 """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
203216 if not enable_checkpointing :
@@ -209,7 +222,14 @@ def create_orbax_checkpoint_manager(
209222 # Base configuration for all dataset types
210223 item_names = ("items" ,)
211224 # we need to use ocdbt and zarr3 to control max file size in the checkpoint
212- item_handlers = {"items" : PyTreeCheckpointHandler (use_ocdbt = use_ocdbt , use_zarr3 = use_zarr3 )}
225+ item_handlers = {
226+ "items" : PyTreeCheckpointHandler (
227+ restore_concurrent_gb = checkpoint_storage_concurrent_gb ,
228+ save_concurrent_gb = checkpoint_storage_concurrent_gb ,
229+ use_ocdbt = use_ocdbt ,
230+ use_zarr3 = use_zarr3 ,
231+ )
232+ }
213233
214234 if dataset_type == "grain" :
215235 item_names += ("iter" ,)
@@ -596,6 +616,9 @@ def map_to_pspec(data):
596616 enable_orbax_v1 = enable_orbax_v1 ,
597617 checkpoint_conversion_fn = checkpoint_conversion_fn ,
598618 source_checkpoint_layout = source_checkpoint_layout ,
619+ checkpoint_storage_concurrent_gb = checkpoint_storage_concurrent_gb ,
620+ use_ocdbt = use_ocdbt ,
621+ use_zarr3 = use_zarr3 ,
599622 )
600623 return {"items" : restored_state }, None
601624 else :
0 commit comments