Skip to content

Commit 55278a2

Browse files
Merge pull request #3005 from AI-Hypercomputer:common_modules
PiperOrigin-RevId: 861378591
2 parents 96f1375 + f2afa86 commit 55278a2

34 files changed

Lines changed: 63 additions & 82 deletions
Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -220,16 +220,10 @@ def create_orbax_checkpoint_manager(
220220
p.mkdir(exist_ok=True, parents=True)
221221
if enable_continuous_checkpointing:
222222
save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
223-
preservation_policy = preservation_policy_lib.LatestN(
224-
max_num_checkpoints_to_keep
225-
)
223+
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
226224
else:
227-
save_decision_policy = save_decision_policy_lib.FixedIntervalPolicy(
228-
interval=save_interval_steps
229-
)
230-
preservation_policy = preservation_policy_lib.LatestN(
231-
max_num_checkpoints_to_keep
232-
)
225+
save_decision_policy = save_decision_policy_lib.FixedIntervalPolicy(interval=save_interval_steps)
226+
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
233227
manager = CheckpointManager(
234228
p,
235229
item_names=item_names,
@@ -239,7 +233,7 @@ def create_orbax_checkpoint_manager(
239233
enable_async_checkpointing=use_async,
240234
save_decision_policy=save_decision_policy,
241235
preservation_policy=preservation_policy,
242-
),
236+
),
243237
logger=orbax_logger,
244238
)
245239

@@ -276,12 +270,8 @@ def create_orbax_emergency_checkpoint_manager(
276270
global_mesh=global_mesh,
277271
abstract_state=abstract_state,
278272
options=emergency_checkpoint_manager.CheckpointManagerOptions(
279-
local=LocalCheckpointOptions(
280-
save_interval_steps=local_save_interval_steps
281-
),
282-
persistent=PersistentCheckpointOptions(
283-
save_interval_steps=persistent_save_interval_steps
284-
),
273+
local=LocalCheckpointOptions(save_interval_steps=local_save_interval_steps),
274+
persistent=PersistentCheckpointOptions(save_interval_steps=persistent_save_interval_steps),
285275
),
286276
logger=orbax_logger,
287277
)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from MaxText import exceptions
2323
from MaxText.sharding import get_input_data_sharding
24-
from MaxText.utils.goodput_utils import (
24+
from MaxText.common.goodput import (
2525
GoodputEvent,
2626
maybe_record_goodput,
2727
)
File renamed without changes.
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,8 @@
3232
from MaxText import maxtext_utils
3333
from MaxText.managed_mldiagnostics import ManagedMLDiagnostics
3434
from MaxText.utils import gcs_utils
35-
from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
3635
from MaxText.globals import EPS
37-
36+
from MaxText.common.gcp_workload_monitor import GCPWorkloadMonitor
3837
from collections import defaultdict
3938

4039
# Mapping MaxText metrics to managed profiler metrics

src/MaxText/decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
from MaxText import max_utils
2727
from MaxText import maxengine
2828
from MaxText import pyconfig
29-
from MaxText import profiler
3029
from MaxText import multimodal_utils
3130
from MaxText.multimodal import preprocessor
31+
from maxtext.common import profiler
3232
# Placeholder: internal
3333

3434
# Number of text sequences to process in a single batch.

src/MaxText/elastic_train.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,27 +60,26 @@
6060

6161
import tensorflow as tf
6262

63-
from MaxText import checkpointing
6463
from MaxText import exceptions
6564
from MaxText import max_utils
6665
from MaxText import maxtext_utils
6766
from MaxText import train_utils
6867
from MaxText import max_logging
69-
from MaxText import profiler
7068
from MaxText import pyconfig
71-
from MaxText.data_loader import DataLoader
72-
from MaxText.metric_logger import MetricLogger
7369
from MaxText.train import get_first_step
7470
from MaxText.train_utils import setup_train_loop
7571
from MaxText.train import train_step
7672
from MaxText.train_utils import validate_train_config
77-
from MaxText.utils.goodput_utils import (
73+
from maxtext.common import checkpointing, profiler
74+
from MaxText.common.data_loader import DataLoader
75+
from MaxText.common.goodput import (
7876
GoodputEvent,
7977
create_goodput_recorder,
8078
maybe_monitor_goodput,
8179
maybe_record_goodput,
8280
)
83-
from MaxText.vertex_tensorboard import VertexTensorboardManager
81+
from MaxText.common.metric_logger import MetricLogger
82+
from MaxText.common.vertex_tensorboard import VertexTensorboardManager
8483

8584
logging.basicConfig()
8685
logging.getLogger("pathwaysutils.elastic.manager").setLevel(logging.INFO)

src/MaxText/experimental/rl/grpo_trainer.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,33 +67,30 @@
6767
from ml_goodput_measurement.src.goodput import GoodputRecorder
6868

6969
import MaxText as mt
70-
from MaxText import checkpointing
7170
from MaxText import exceptions
7271
from MaxText import max_logging
7372
from MaxText import max_utils
7473
from MaxText import maxtext_utils
7574
from MaxText import sharding
7675
from MaxText import train_utils
77-
from MaxText import profiler
7876
from MaxText import pyconfig
79-
from MaxText.checkpointing import CheckpointManager
8077
from MaxText.utils import gcs_utils
8178
from MaxText.inference import offline_engine
82-
from MaxText.data_loader import DataLoader
8379
from MaxText.experimental.rl import grpo_input_pipeline
8480
from MaxText.experimental.rl import grpo_utils
8581
from MaxText.globals import EPS
86-
from MaxText.metric_logger import MetricLogger
8782
from MaxText.train import get_first_step
8883
from MaxText.train_utils import validate_train_config
89-
from MaxText.utils.goodput_utils import (
84+
from maxtext.common import checkpointing, profiler
85+
from MaxText.common.data_loader import DataLoader
86+
from MaxText.common.goodput import (
9087
GoodputEvent,
9188
create_goodput_recorder,
9289
maybe_monitor_goodput,
9390
maybe_record_goodput,
9491
)
95-
from MaxText.vertex_tensorboard import VertexTensorboardManager
96-
92+
from MaxText.common.metric_logger import MetricLogger
93+
from MaxText.common.vertex_tensorboard import VertexTensorboardManager
9794

9895
# pylint: disable=too-many-positional-arguments
9996

@@ -505,7 +502,7 @@ def setup_train_loop(
505502
recorder: GoodputRecorder,
506503
) -> tuple[
507504
jax.Array,
508-
CheckpointManager,
505+
checkpointing.CheckpointManager,
509506
TrainState,
510507
TrainState,
511508
mt.Transformer,

0 commit comments

Comments
 (0)