Skip to content

Commit f2afa86

Browse files
committed
Move common modules to src/maxtext/common
1 parent b646a53 commit f2afa86

34 files changed

Lines changed: 62 additions & 81 deletions

src/MaxText/decode.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,9 +26,9 @@
2626
from MaxText import max_utils
2727
from MaxText import maxengine
2828
from MaxText import pyconfig
29-
from MaxText import profiler
3029
from MaxText import multimodal_utils
3130
from MaxText.multimodal import preprocessor
31+
from maxtext.common import profiler
3232
# Placeholder: internal
3333

3434
# Number of text sequences to process in a single batch.

src/MaxText/elastic_train.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -60,27 +60,26 @@
6060

6161
import tensorflow as tf
6262

63-
from MaxText import checkpointing
6463
from MaxText import exceptions
6564
from MaxText import max_utils
6665
from MaxText import maxtext_utils
6766
from MaxText import train_utils
6867
from MaxText import max_logging
69-
from MaxText import profiler
7068
from MaxText import pyconfig
71-
from MaxText.data_loader import DataLoader
72-
from MaxText.metric_logger import MetricLogger
7369
from MaxText.train import get_first_step
7470
from MaxText.train_utils import setup_train_loop
7571
from MaxText.train import train_step
7672
from MaxText.train_utils import validate_train_config
77-
from MaxText.utils.goodput_utils import (
73+
from maxtext.common import checkpointing, profiler
74+
from maxtext.common.data_loader import DataLoader
75+
from maxtext.common.goodput import (
7876
GoodputEvent,
7977
create_goodput_recorder,
8078
maybe_monitor_goodput,
8179
maybe_record_goodput,
8280
)
83-
from MaxText.vertex_tensorboard import VertexTensorboardManager
81+
from maxtext.common.metric_logger import MetricLogger
82+
from maxtext.common.vertex_tensorboard import VertexTensorboardManager
8483

8584
logging.basicConfig()
8685
logging.getLogger("pathwaysutils.elastic.manager").setLevel(logging.INFO)

src/MaxText/experimental/rl/grpo_trainer.py

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -67,33 +67,30 @@
6767
from ml_goodput_measurement.src.goodput import GoodputRecorder
6868

6969
import MaxText as mt
70-
from MaxText import checkpointing
7170
from MaxText import exceptions
7271
from MaxText import max_logging
7372
from MaxText import max_utils
7473
from MaxText import maxtext_utils
7574
from MaxText import sharding
7675
from MaxText import train_utils
77-
from MaxText import profiler
7876
from MaxText import pyconfig
79-
from MaxText.checkpointing import CheckpointManager
8077
from MaxText.utils import gcs_utils
8178
from MaxText.inference import offline_engine
82-
from MaxText.data_loader import DataLoader
8379
from MaxText.experimental.rl import grpo_input_pipeline
8480
from MaxText.experimental.rl import grpo_utils
8581
from MaxText.globals import EPS
86-
from MaxText.metric_logger import MetricLogger
8782
from MaxText.train import get_first_step
8883
from MaxText.train_utils import validate_train_config
89-
from MaxText.utils.goodput_utils import (
84+
from maxtext.common import checkpointing, profiler
85+
from maxtext.common.data_loader import DataLoader
86+
from maxtext.common.goodput import (
9087
GoodputEvent,
9188
create_goodput_recorder,
9289
maybe_monitor_goodput,
9390
maybe_record_goodput,
9491
)
95-
from MaxText.vertex_tensorboard import VertexTensorboardManager
96-
92+
from maxtext.common.metric_logger import MetricLogger
93+
from maxtext.common.vertex_tensorboard import VertexTensorboardManager
9794

9895
# pylint: disable=too-many-positional-arguments
9996

@@ -505,7 +502,7 @@ def setup_train_loop(
505502
recorder: GoodputRecorder,
506503
) -> tuple[
507504
jax.Array,
508-
CheckpointManager,
505+
checkpointing.CheckpointManager,
509506
TrainState,
510507
TrainState,
511508
mt.Transformer,

src/MaxText/gcloud_stub.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -445,7 +445,7 @@ def workload_monitor():
445445
return _workload_monitor_stub()
446446

447447
try:
448-
from MaxText.gcp_workload_monitor import GCPWorkloadMonitor # type: ignore # pylint: disable=import-outside-toplevel
448+
from maxtext.common.gcp_workload_monitor import GCPWorkloadMonitor # type: ignore # pylint: disable=import-outside-toplevel
449449

450450
return GCPWorkloadMonitor, False
451451
except Exception: # ModuleNotFoundError / ImportError # pylint: disable=broad-exception-caught
@@ -484,7 +484,7 @@ def vertex_tensorboard_components():
484484
return _vertex_tb_stub()
485485

486486
try:
487-
from MaxText.vertex_tensorboard import VertexTensorboardManager # type: ignore # pylint: disable=import-outside-toplevel
487+
from maxtext.common.vertex_tensorboard import VertexTensorboardManager # type: ignore # pylint: disable=import-outside-toplevel
488488

489489
return VertexTensorboardManager, False
490490
except Exception: # pylint: disable=broad-exception-caught

src/MaxText/generate_param_only_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,6 @@
3232
from jax.sharding import Mesh
3333
from jax import random
3434

35-
from MaxText import checkpointing
3635
from MaxText import max_logging
3736
from MaxText import max_utils
3837
from MaxText import maxtext_utils
@@ -42,6 +41,7 @@
4241
from MaxText.layers import models, quantizations
4342
from MaxText.utils import gcs_utils
4443
from MaxText.utils import lora_utils
44+
from maxtext.common import checkpointing
4545

4646
Transformer = models.transformer_as_linen
4747

src/MaxText/inference_microbenchmark.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,9 +26,9 @@
2626
from MaxText import maxengine
2727
from MaxText import maxtext_utils
2828
from MaxText import prefill_packing
29-
from MaxText import profiler
3029
from MaxText import pyconfig
3130
from MaxText.utils import gcs_utils
31+
from maxtext.common import profiler
3232

3333
import warnings
3434

src/MaxText/layerwise_quantization.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,13 +43,13 @@
4343
from flax.linen import partitioning as nn_partitioning
4444
from flax import nnx
4545

46-
from MaxText import checkpointing
4746
from MaxText import common_types
4847
from MaxText import max_logging
4948
from MaxText import max_utils
5049
from MaxText import maxtext_utils
5150
from MaxText import pyconfig
5251
from MaxText.layers import models, quantizations, deepseek
52+
from maxtext.common import checkpointing
5353
import orbax.checkpoint as ocp
5454

5555
IGNORE = ocp.PLACEHOLDER

src/MaxText/maxtext_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,14 +35,14 @@
3535
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
3636
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
3737

38-
from MaxText import checkpointing
3938
from MaxText import max_logging
4039
from MaxText import max_utils
4140
from MaxText import multimodal_utils
4241
from MaxText import sharding
4342
from MaxText.configs import types
4443
from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
4544
from MaxText.inference.page_manager import PageState
45+
from maxtext.common import checkpointing
4646

4747
OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
4848

src/MaxText/sft_trainer.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -27,30 +27,29 @@
2727

2828
from flax.linen import partitioning as nn_partitioning
2929

30-
from MaxText import checkpointing
3130
from MaxText import exceptions
3231
from MaxText import max_utils
3332
from MaxText import max_logging
3433
from MaxText import maxtext_utils
35-
from MaxText import profiler
3634
from MaxText import pyconfig
3735
from MaxText import train_utils
3836
from MaxText import sharding
39-
from MaxText.data_loader import DataLoader
40-
from MaxText.metric_logger import MetricLogger
4137
from MaxText.train import (
4238
eval_step,
4339
get_first_step,
4440
train_step,
4541
)
4642
from MaxText.train_utils import setup_train_loop, validate_train_config
4743
from MaxText.utils import gcs_utils
48-
from MaxText.utils.goodput_utils import (
44+
from maxtext.common import checkpointing, profiler
45+
from maxtext.common.data_loader import DataLoader
46+
from maxtext.common.goodput import (
4947
GoodputEvent,
5048
create_goodput_recorder,
5149
maybe_monitor_goodput,
5250
maybe_record_goodput,
5351
)
52+
from maxtext.common.metric_logger import MetricLogger
5453

5554

5655
def train_loop(config, recorder, state=None):

src/MaxText/train.py

Lines changed: 10 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -42,36 +42,35 @@
4242
from cloud_tpu_diagnostics.configuration import diagnostic_configuration
4343
from cloud_tpu_diagnostics.configuration import stack_trace_configuration
4444

45-
from MaxText import checkpointing
4645
from MaxText import exceptions
4746
from MaxText import max_logging
4847
from MaxText import max_utils
4948
from MaxText import maxtext_utils
5049
from MaxText import train_utils
51-
from MaxText import profiler
5250
from MaxText import pyconfig
5351
from MaxText import sharding
5452
from MaxText.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
5553
from MaxText.common_types import ShardMode
5654
from MaxText.globals import EPS
57-
from MaxText.metric_logger import MetricLogger
5855
from MaxText.utils import gcs_utils
59-
from MaxText.utils.goodput_utils import (
60-
GoodputEvent,
61-
create_goodput_recorder,
62-
maybe_monitor_goodput,
63-
maybe_record_goodput,
64-
)
65-
from MaxText.vertex_tensorboard import VertexTensorboardManager
6656
# Placeholder: internal
6757

6858
from MaxText.gradient_accumulation import gradient_accumulation_loss_and_grad
6959
from MaxText.vocabulary_tiling import vocab_tiling_linen_loss
7060
from MaxText.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn
7161
from MaxText.train_utils import validate_train_config
72-
from MaxText.metric_logger import record_activation_metrics
7362
# pylint: disable=too-many-positional-arguments
7463

64+
from maxtext.common import checkpointing, profiler
65+
from maxtext.common.goodput import (
66+
GoodputEvent,
67+
create_goodput_recorder,
68+
maybe_monitor_goodput,
69+
maybe_record_goodput,
70+
)
71+
from maxtext.common.metric_logger import MetricLogger, record_activation_metrics
72+
from maxtext.common.vertex_tensorboard import VertexTensorboardManager
73+
7574

7675
def get_first_step(state):
7776
return int(state.step)

0 commit comments

Comments
 (0)