Skip to content

Commit 0a3c106

Browse files
lyglstGoogle-ML-Automation
authored andcommitted
Fix two issues that blocks training loop with continuous checkpoint enabled.
PiperOrigin-RevId: 866204905
1 parent a54b374 commit 0a3c106

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Any, Optional
1919

2020
from absl import flags
21+
import datetime
2122
from etils import epath
2223
from flax.training import train_state
2324
import jax
@@ -248,6 +249,11 @@ def create_orbax_checkpoint_manager(
248249
else:
249250
save_decision_policy = save_decision_policy_lib.FixedIntervalPolicy(interval=save_interval_steps)
250251
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
252+
async_options = None
253+
if enable_continuous_checkpointing:
254+
async_options = ocp.AsyncOptions(
255+
timeout_secs=int(datetime.timedelta(minutes=60).total_seconds()),
256+
)
251257
manager = CheckpointManager(
252258
p,
253259
item_names=item_names,
@@ -257,6 +263,7 @@ def create_orbax_checkpoint_manager(
257263
enable_async_checkpointing=use_async,
258264
save_decision_policy=save_decision_policy,
259265
preservation_policy=preservation_policy,
266+
async_options=async_options,
260267
),
261268
logger=orbax_logger,
262269
)
@@ -728,6 +735,7 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
728735
if config and config.enable_checkpointing:
729736
if (
730737
force
738+
or (step % config.checkpoint_period == 0 and not config.enable_continuous_checkpointing)
731739
or (step % config.checkpoint_period == 0)
732740
or (config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0)
733741
):

0 commit comments

Comments
 (0)