Skip to content

Commit be963a5

Browse files
bvandermoon and Google-ML-Automation
authored and committed
PR #3226: Create optimizers directory and move some files to utils
Imported from GitHub PR #3226 # Description * Create optimizers directory * Move some remaining files from `src/MaxText` to `src/maxtext/utils` # Tests ``` python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \ run_name=<run_name> \ base_output_directory=gs://<gcs_bucket> \ dataset_type=synthetic \ steps=10 \ enable_checkpointing=false ``` # Checklist Before submitting this PR, please make sure (put X in square brackets): - [x] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label. - [x] I have necessary comments in my code, particularly in hard-to-understand areas. - [x] I have run end-to-end tests and provided workload links above if applicable. - [x] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation. Copybara import of the project: -- 1e4bdcb by Branden Vandermoon <bvandermoon@google.com>: Create optimizers directory and move some files to utils Merging this change closes #3226 COPYBARA_INTEGRATE_REVIEW=#3226 from AI-Hypercomputer:bvandermoon-restructure-last-files 1e4bdcb PiperOrigin-RevId: 874891637
1 parent 596cf78 commit be963a5

38 files changed

Lines changed: 50 additions & 50 deletions

src/MaxText/generate_param_only_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@
2929
import jax
3030
from jax import random
3131
from jax.sharding import Mesh
32-
from MaxText import optimizers
3332
from MaxText import pyconfig
3433
from maxtext.common import checkpointing
3534
from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN
3635
from maxtext.layers import quantizations
3736
from maxtext.models import models
37+
from maxtext.optimizers import optimizers
3838
from maxtext.utils import gcs_utils
3939
from maxtext.utils import lora_utils
4040
from maxtext.utils import max_logging

src/MaxText/sft_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from flax.linen import partitioning as nn_partitioning
2929

3030
from MaxText import pyconfig
31-
from MaxText import sharding
3231
from maxtext.trainers.pre_train.train import (
3332
eval_step,
3433
get_first_step,
@@ -48,6 +47,7 @@
4847
from maxtext.utils import max_utils
4948
from maxtext.utils import max_logging
5049
from maxtext.utils import maxtext_utils
50+
from maxtext.utils import sharding
5151
from maxtext.utils import train_utils
5252

5353

src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@
4242
import jax
4343
from jax import random
4444
from jax.sharding import Mesh
45-
from MaxText import optimizers
4645
from MaxText import pyconfig
4746
from MaxText.globals import MAXTEXT_PKG_DIR
4847
from maxtext.common import checkpointing
4948
from maxtext.common.common_types import MODEL_MODE_TRAIN
5049
from maxtext.layers import quantizations
5150
from maxtext.models.models import transformer_as_linen
51+
from maxtext.optimizers import optimizers
5252
from maxtext.utils import max_logging
5353
from maxtext.utils import max_utils
5454
from maxtext.utils import maxtext_utils

src/maxtext/common/data_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919
import jax.numpy as jnp
2020
from jax.experimental import checkify
2121

22-
from MaxText.sharding import get_input_data_sharding
2322
from maxtext.common.goodput import (
2423
GoodputEvent,
2524
maybe_record_goodput,
2625
)
27-
from maxtext.utils import exceptions
2826
from maxtext.trainers.diloco import diloco
27+
from maxtext.utils import exceptions
28+
from maxtext.utils.sharding import get_input_data_sharding
2929

3030

3131
class DataLoader:

src/maxtext/experimental/rl/grpo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
from ml_goodput_measurement.src.goodput import GoodputRecorder
6868

6969
import MaxText as mt
70-
from MaxText import sharding
7170
from MaxText import pyconfig
7271
from MaxText.globals import EPS
7372
from maxtext.trainers.pre_train.train import get_first_step
@@ -89,6 +88,7 @@
8988
from maxtext.utils import max_logging
9089
from maxtext.utils import max_utils
9190
from maxtext.utils import maxtext_utils
91+
from maxtext.utils import sharding
9292
from maxtext.utils import train_utils
9393

9494
# pylint: disable=too-many-positional-arguments

src/maxtext/inference/paged_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from maxtext.inference import page_manager
3131
from maxtext.inference import paged_attention_kernel_v2
3232
from maxtext.layers.initializers import variable_to_logically_partitioned
33-
from MaxText.sharding import logical_to_mesh_axes
33+
from maxtext.utils.sharding import logical_to_mesh_axes
3434

3535
_use_kernel_v2 = False
3636

src/maxtext/layers/attention_mla.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262
DEFAULT_MASK_VALUE,
6363
)
6464

65-
from MaxText.sharding import create_sharding
6665
from maxtext.layers import nnx_wrappers
6766
from maxtext.layers.attentions import Attention
6867
from maxtext.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned
@@ -73,6 +72,7 @@
7372
from maxtext.inference import page_manager
7473
from maxtext.inference import paged_attention
7574
from maxtext.inference.kvcache import KVQuant
75+
from maxtext.utils.sharding import create_sharding
7676

7777

7878
class Indexer(nnx.Module):

src/maxtext/layers/attention_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@
7777
from maxtext.layers import nnx_wrappers
7878
from maxtext.layers.initializers import variable_to_logically_partitioned
7979
from maxtext.layers.quantizations import AqtQuantization as Quant
80-
from MaxText.sharding import logical_to_mesh_axes, maybe_shard_with_name
8180
from maxtext.utils import max_utils
81+
from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_name
8282
import numpy as np
8383
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel
8484
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask

src/maxtext/layers/attentions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
EP_AS_CONTEXT,
5454
AttentionType,
5555
)
56-
from MaxText.sharding import maybe_shard_with_logical, create_sharding
5756
from maxtext.layers import nnx_wrappers
5857
from maxtext.layers.attention_op import AttentionOp
5958
from maxtext.layers.embeddings import (
@@ -71,6 +70,7 @@
7170
from maxtext.layers.quantizations import AqtQuantization as Quant
7271
from maxtext.inference import kvcache, page_manager, paged_attention
7372
from maxtext.inference.kvcache import KVQuant
73+
from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding
7474

7575
# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
7676
# pytype: disable=attribute-error

src/maxtext/layers/decoders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from jax.ad_checkpoint import checkpoint_name
2828
import jax.numpy as jnp
2929
from jax.sharding import Mesh
30-
from MaxText import sharding
3130
from maxtext.common.common_types import Config, DecoderBlockType, EP_AS_CONTEXT, ShardMode
3231
from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN
3332
from maxtext.inference import page_manager
@@ -56,10 +55,11 @@
5655
simple_layer,
5756
)
5857
from maxtext.multimodal import utils as mm_utils
59-
from MaxText.sharding import create_sharding
58+
from maxtext.utils.sharding import create_sharding
6059
from maxtext.utils import max_logging
6160
from maxtext.utils import max_utils
6261
from maxtext.utils import maxtext_utils
62+
from maxtext.utils import sharding
6363

6464
# ------------------------------------------------------------------------------
6565
# The network: Decoder Definitions

0 commit comments

Comments
 (0)