Skip to content

Commit f215b60

Browse files
lyglst authored and Google-ML-Automation committed
Add continuous checkpoint option in MaxText.
PiperOrigin-RevId: 845008539
1 parent 2b55eb0 commit f215b60

4 files changed

Lines changed: 31 additions & 4 deletions

File tree

src/MaxText/checkpointing.py

Lines changed: 25 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,8 @@
3030
import orbax.checkpoint as ocp
3131
from orbax.checkpoint import v1 as ocp_v1
3232
from orbax.checkpoint._src.arrays import sharding as sharding_utils
33+
from orbax.checkpoint._src.checkpoint_managers import preservation_policy as preservation_policy_lib
34+
from orbax.checkpoint._src.checkpoint_managers import save_decision_policy as save_decision_policy_lib
3335
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
3436
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
3537
# pylint: disable=too-many-positional-arguments
@@ -194,6 +196,8 @@ def create_orbax_checkpoint_manager(
194196
orbax_logger: Any = None, # pytype: disable=attribute-error
195197
use_ocdbt: bool = True,
196198
use_zarr3: bool = True,
199+
enable_continuous_checkpointing: bool = False,
200+
max_num_checkpoints_to_keep: int = 10,
197201
):
198202
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
199203
if not enable_checkpointing:
@@ -214,15 +218,28 @@ def create_orbax_checkpoint_manager(
214218
# local storage checkpoint needs parent directory created
215219
p = epath.Path(checkpoint_dir)
216220
p.mkdir(exist_ok=True, parents=True)
221+
if enable_continuous_checkpointing:
222+
save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
223+
preservation_policy = preservation_policy_lib.LatestN(
224+
max_num_checkpoints_to_keep
225+
)
226+
else:
227+
save_decision_policy = save_decision_policy_lib.FixedIntervalPolicy(
228+
interval=save_interval_steps
229+
)
230+
preservation_policy = preservation_policy_lib.LatestN(
231+
max_num_checkpoints_to_keep
232+
)
217233
manager = CheckpointManager(
218234
p,
219235
item_names=item_names,
220236
item_handlers=item_handlers,
221237
options=CheckpointManagerOptions(
222238
create=True,
223-
save_interval_steps=save_interval_steps,
224239
enable_async_checkpointing=use_async,
225-
),
240+
save_decision_policy=save_decision_policy,
241+
preservation_policy=preservation_policy,
242+
),
226243
logger=orbax_logger,
227244
)
228245

@@ -259,8 +276,12 @@ def create_orbax_emergency_checkpoint_manager(
259276
global_mesh=global_mesh,
260277
abstract_state=abstract_state,
261278
options=emergency_checkpoint_manager.CheckpointManagerOptions(
262-
local=LocalCheckpointOptions(save_interval_steps=local_save_interval_steps),
263-
persistent=PersistentCheckpointOptions(save_interval_steps=persistent_save_interval_steps),
279+
local=LocalCheckpointOptions(
280+
save_interval_steps=local_save_interval_steps
281+
),
282+
persistent=PersistentCheckpointOptions(
283+
save_interval_steps=persistent_save_interval_steps
284+
),
264285
),
265286
logger=orbax_logger,
266287
)

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@ save_checkpoint_on_completion: True
5252
async_checkpointing: True
5353
checkpoint_period: 10_000
5454
max_num_checkpoints_to_keep: None
55+
enable_continuous_checkpointing: False
5556
# enables one replica to read the ckpt then broadcast to the rest
5657
enable_single_replica_ckpt_restoring: False
5758

src/MaxText/configs/types.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -279,6 +279,9 @@ class Checkpointing(BaseModel):
279279
save_checkpoint_on_completion: bool = Field(
280280
True, description="If True, saves a final checkpoint upon training completion."
281281
)
282+
enable_continuous_checkpointing: bool = Field(
283+
False, description="If True, enables continuous checkpointing."
284+
)
282285

283286

284287
class OrbaxStorage(BaseModel):

src/MaxText/train_utils.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,8 @@ def create_training_tools(config, model, mesh):
7373
logger,
7474
use_ocdbt,
7575
use_zarr3,
76+
config.enable_continuous_checkpointing,
77+
config.max_num_checkpoints_to_keep,
7678
)
7779

7880
return init_rng, checkpoint_manager, learning_rate_schedule, tx

0 commit comments

Comments
 (0)