Skip to content

Commit 80f3efc

Browse files
Merge pull request #3026 from AI-Hypercomputer:utils_refactor
PiperOrigin-RevId: 862910344
2 parents 04c8b12 + 3455318 commit 80f3efc

140 files changed

Lines changed: 2670 additions & 2590 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

benchmarks/api_server/maxtext_generator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434

3535
from dataclasses import dataclass, field
3636

37-
from MaxText import max_utils, maxengine, pyconfig, multimodal_utils, max_logging
37+
from MaxText import maxengine, pyconfig, multimodal_utils
38+
from maxtext.utils import max_logging, max_utils
3839

3940
# Set TF log level to avoid verbose startup messages.
4041
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

benchmarks/mmlu/mmlu_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@
5757
from tqdm import tqdm
5858

5959
from MaxText import pyconfig
60-
from MaxText import max_logging
61-
from MaxText import max_utils
6260
from MaxText import maxengine
61+
from maxtext.utils import max_logging
62+
from maxtext.utils import max_utils
6363

6464
ASCII_UPPERCASE_A = ord("A") # ASCII value for uppercase 'A'
6565

src/MaxText/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@
2929

3030
from jax.sharding import Mesh
3131

32-
from MaxText import maxtext_utils
33-
from MaxText import model_creation_utils
3432
from MaxText import pyconfig
3533
from MaxText.layers import models
3634
from MaxText import dpo_utils
37-
from MaxText.model_creation_utils import from_config
35+
from maxtext.utils import maxtext_utils
36+
from maxtext.utils import model_creation_utils
37+
from maxtext.utils.model_creation_utils import from_config
3838

3939
Transformer = models.Transformer
4040
transformer_as_linen = models.transformer_as_linen

src/MaxText/benchmark_chunked_prefill.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@
4747

4848
from absl import app
4949

50-
from MaxText import max_utils
5150
from MaxText import maxengine
5251
from MaxText import pyconfig
52+
from maxtext.utils import max_utils
5353

5454
_WARMUP_ITERS = 2
5555
_BENCHMARK_ITERS = 5

src/MaxText/configs/types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@
3333
from pydantic.main import BaseModel
3434
from pydantic.types import PositiveInt, NonNegativeFloat, NonNegativeInt
3535

36-
from MaxText import accelerator_to_spec_map, max_utils
36+
from MaxText import accelerator_to_spec_map
3737
from MaxText.common_types import AttentionType, DecoderBlockType, ShardMode
3838
from MaxText.globals import MAXTEXT_ASSETS_ROOT
39-
from MaxText.utils import gcs_utils
39+
from maxtext.utils import gcs_utils
40+
from maxtext.utils import max_utils
4041

4142
logger = logging.getLogger(__name__)
4243

src/MaxText/decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323

2424
from jetstream.engine import engine_api
2525

26-
from MaxText import max_utils
2726
from MaxText import maxengine
2827
from MaxText import pyconfig
2928
from MaxText import multimodal_utils
3029
from MaxText.multimodal import preprocessor
3130
from maxtext.common import profiler
31+
from maxtext.utils import max_utils
3232
# Placeholder: internal
3333

3434
# Number of text sequences to process in a single batch.

src/MaxText/distillation/train_distill.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import sys
1818
import importlib
1919

20-
from MaxText import max_logging
20+
from maxtext.utils import max_logging
2121

2222
OLD_MODULE_PATH = "MaxText.distillation.train_distill"
2323
NEW_MODULE_PATH = "maxtext.trainers.post_train.distillation.train_distill"

src/MaxText/dpo_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import jax
2020
import jax.numpy as jnp
2121

22-
from MaxText import maxtext_utils
22+
from maxtext.utils import maxtext_utils
2323

2424

2525
def _split_dpo_state(state):

src/MaxText/examples/demo_decoding.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,19 +135,19 @@
135135
"\n",
136136
"import MaxText as mt\n",
137137
"from MaxText import common_types\n",
138-
"from MaxText import maxtext_utils\n",
139-
"from MaxText import max_logging\n",
140138
"from MaxText import pyconfig\n",
141139
"from MaxText.input_pipeline import _input_pipeline_utils\n",
142140
"from MaxText.utils.ckpt_conversion import to_maxtext\n",
143-
"from maxtext.src.maxtext.inference import inference_utils\n",
141+
"from maxtext.inference import inference_utils\n",
142+
"from maxtext.utils import maxtext_utils\n",
143+
"from maxtext.utils import max_logging\n",
144144
"\n",
145145
"from google.colab import userdata\n",
146146
"from huggingface_hub import login\n",
147147
"\n",
148148
"MAXTEXT_PKG_DIR = os.path.dirname(mt.__file__)\n",
149-
"MAXTEXT_REPO_ROOT = os.path.dirname(os.path.dirname(MAXTEXT_PKG_DIR))\n",
150-
"MAXTEXT_ASSETS_ROOT = os.path.join(MAXTEXT_REPO_ROOT, \"src\", \"maxtext\", \"assets\")\n",
149+
"MAXTEXT_REPO_ROOT = os.path.dirname(os.path.dirname(MAXTEXT_PKG_DIR))\n",
150+
"MAXTEXT_ASSETS_ROOT = os.path.join(MAXTEXT_REPO_ROOT, \"src\", \"maxtext\", \"assets\")\n",
151151
"\n",
152152
"nest_asyncio.apply()"
153153
]

0 commit comments

Comments
 (0)