Skip to content

Commit fb96299

Browse files
Merge pull request #3051 from AI-Hypercomputer:xfgu-storage-fix
PiperOrigin-RevId: 863330304
2 parents 31d0b8c + 345cd0b commit fb96299

2 files changed

Lines changed: 26 additions & 2 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ def _load_full_state_from_path(
143143
enable_orbax_v1,
144144
checkpoint_conversion_fn,
145145
source_checkpoint_layout,
146+
checkpoint_storage_concurrent_gb,
147+
use_ocdbt,
148+
use_zarr3,
146149
):
147150
"""Load full state from checkpoint at specified path.
148151
@@ -155,6 +158,9 @@ def _load_full_state_from_path(
155158
maxtext-supported state.
156159
source_checkpoint_layout: String representation of the checkpoint layout of
157160
the source checkpoint.
161+
checkpoint_storage_concurrent_gb: concurrent GB for checkpoint byte I/O.
162+
use_ocdbt: Whether to use OCDBT format.
163+
use_zarr3: Whether to use Zarr3 format.
158164
159165
Returns:
160166
The loaded state.
@@ -184,7 +190,13 @@ def combine_sharding(sds, shardings):
184190
else:
185191
# Original v0 logic.
186192
p = epath.Path(path)
187-
return ocp.StandardCheckpointer().restore(p, abstract_unboxed_pre_state)
193+
handler = ocp.PyTreeCheckpointHandler(
194+
restore_concurrent_gb=checkpoint_storage_concurrent_gb,
195+
save_concurrent_gb=checkpoint_storage_concurrent_gb,
196+
use_ocdbt=use_ocdbt,
197+
use_zarr3=use_zarr3,
198+
)
199+
return ocp.Checkpointer(handler).restore(p, abstract_unboxed_pre_state)
188200

189201

190202
def create_orbax_checkpoint_manager(
@@ -198,6 +210,7 @@ def create_orbax_checkpoint_manager(
198210
use_zarr3: bool = True,
199211
enable_continuous_checkpointing: bool = False,
200212
max_num_checkpoints_to_keep: int = 10,
213+
checkpoint_storage_concurrent_gb: int = 96,
201214
):
202215
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
203216
if not enable_checkpointing:
@@ -209,7 +222,14 @@ def create_orbax_checkpoint_manager(
209222
# Base configuration for all dataset types
210223
item_names = ("items",)
211224
# we need to use ocdbt and zarr3 to control max file size in the checkpoint
212-
item_handlers = {"items": PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)}
225+
item_handlers = {
226+
"items": PyTreeCheckpointHandler(
227+
restore_concurrent_gb=checkpoint_storage_concurrent_gb,
228+
save_concurrent_gb=checkpoint_storage_concurrent_gb,
229+
use_ocdbt=use_ocdbt,
230+
use_zarr3=use_zarr3,
231+
)
232+
}
213233

214234
if dataset_type == "grain":
215235
item_names += ("iter",)
@@ -596,6 +616,9 @@ def map_to_pspec(data):
596616
enable_orbax_v1=enable_orbax_v1,
597617
checkpoint_conversion_fn=checkpoint_conversion_fn,
598618
source_checkpoint_layout=source_checkpoint_layout,
619+
checkpoint_storage_concurrent_gb=checkpoint_storage_concurrent_gb,
620+
use_ocdbt=use_ocdbt,
621+
use_zarr3=use_zarr3,
599622
)
600623
return {"items": restored_state}, None
601624
else:

src/maxtext/utils/train_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def create_training_tools(config, model, mesh):
7575
use_zarr3,
7676
config.enable_continuous_checkpointing,
7777
config.max_num_checkpoints_to_keep,
78+
config.checkpoint_storage_concurrent_gb,
7879
)
7980

8081
return init_rng, checkpoint_manager, learning_rate_schedule, tx

0 commit comments

Comments (0)