3030import orbax .checkpoint as ocp
3131from orbax .checkpoint import v1 as ocp_v1
3232from orbax .checkpoint ._src .arrays import sharding as sharding_utils
33+ from orbax .checkpoint ._src .checkpoint_managers import preservation_policy as preservation_policy_lib
34+ from orbax .checkpoint ._src .checkpoint_managers import save_decision_policy as save_decision_policy_lib
3335import orbax .checkpoint .experimental .emergency .checkpoint_manager as emergency_checkpoint_manager
3436import orbax .checkpoint .experimental .emergency .replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
3537# pylint: disable=too-many-positional-arguments
@@ -194,6 +196,8 @@ def create_orbax_checkpoint_manager(
194196 orbax_logger : Any = None , # pytype: disable=attribute-error
195197 use_ocdbt : bool = True ,
196198 use_zarr3 : bool = True ,
199+ enable_continuous_checkpointing : bool = False ,
200+ max_num_checkpoints_to_keep : int = 10 ,
197201):
198202 """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
199203 if not enable_checkpointing :
@@ -214,15 +218,28 @@ def create_orbax_checkpoint_manager(
214218 # local storage checkpoint needs parent directory created
215219 p = epath .Path (checkpoint_dir )
216220 p .mkdir (exist_ok = True , parents = True )
221+ if enable_continuous_checkpointing :
222+ save_decision_policy = save_decision_policy_lib .ContinuousCheckpointingPolicy ()
223+ preservation_policy = preservation_policy_lib .LatestN (
224+ max_num_checkpoints_to_keep
225+ )
226+ else :
227+ save_decision_policy = save_decision_policy_lib .FixedIntervalPolicy (
228+ interval = save_interval_steps
229+ )
230+ preservation_policy = preservation_policy_lib .LatestN (
231+ max_num_checkpoints_to_keep
232+ )
217233 manager = CheckpointManager (
218234 p ,
219235 item_names = item_names ,
220236 item_handlers = item_handlers ,
221237 options = CheckpointManagerOptions (
222238 create = True ,
223- save_interval_steps = save_interval_steps ,
224239 enable_async_checkpointing = use_async ,
225- ),
240+ save_decision_policy = save_decision_policy ,
241+ preservation_policy = preservation_policy ,
242+ ),
226243 logger = orbax_logger ,
227244 )
228245
@@ -259,8 +276,12 @@ def create_orbax_emergency_checkpoint_manager(
259276 global_mesh = global_mesh ,
260277 abstract_state = abstract_state ,
261278 options = emergency_checkpoint_manager .CheckpointManagerOptions (
262- local = LocalCheckpointOptions (save_interval_steps = local_save_interval_steps ),
263- persistent = PersistentCheckpointOptions (save_interval_steps = persistent_save_interval_steps ),
279+ local = LocalCheckpointOptions (
280+ save_interval_steps = local_save_interval_steps
281+ ),
282+ persistent = PersistentCheckpointOptions (
283+ save_interval_steps = persistent_save_interval_steps
284+ ),
264285 ),
265286 logger = orbax_logger ,
266287 )
0 commit comments