Skip to content

Commit 305e4e0

Browse files
SurbhiJainUSCGoogle-ML-Automation
authored andcommitted
Update import paths for maxtext_utils and pyconfig in tests
PiperOrigin-RevId: 877446791
1 parent 5dfd6a4 commit 305e4e0

5 files changed

Lines changed: 6 additions & 7 deletions

File tree

src/install_maxtext_extra_deps/install_post_train_extra_deps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def main():
7171
"uv",
7272
"pip",
7373
"install",
74-
"src/maxtext/integration/vllm",
74+
"src/MaxText/integration/vllm",
7575
"--no-deps",
7676
]
7777

tests/unit/engram_vs_reference_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,12 @@
4646
from jax.sharding import Mesh
4747

4848
from maxtext.configs import pyconfig
49-
from maxtext import maxtext_utils
50-
5149
from maxtext.layers.engram import CompressedTokenizer as CompressedTokenizerJAX
5250
from maxtext.layers.engram import NgramHashMapping as NgramHashMappingJAX
5351
from maxtext.layers.engram import MultiHeadEmbedding as MultiHeadEmbeddingJAX
5452
from maxtext.layers.engram import ShortConv as ShortConvJAX
5553
from maxtext.layers.engram import Engram as EngramJAX
54+
from maxtext.utils import maxtext_utils
5655
from maxtext.utils.globals import MAXTEXT_PKG_DIR
5756

5857

tests/unit/moba_vs_reference_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@
3131
import torch
3232
from jax.sharding import Mesh
3333

34-
from maxtext import maxtext_utils, pyconfig
34+
from maxtext.configs import pyconfig
3535
from maxtext.layers.attention_op import AttentionOp
36+
from maxtext.utils import maxtext_utils
3637
from tests.utils.test_helpers import get_test_config_path
3738

3839
# pylint: disable=missing-function-docstring,protected-access

tests/unit/sharding_compare_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import os
2020
import jax
2121
import jax.numpy as jnp
22-
from maxtext import maxtext_utils
2322
from maxtext.configs import pyconfig
23+
from maxtext.utils import maxtext_utils
2424
# import optax
2525

2626
from maxtext.utils.globals import MAXTEXT_PKG_DIR

tests/utils/sharding_dump.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@
2626
import jax
2727
from jax.sharding import NamedSharding, PartitionSpec
2828
from jax.tree_util import tree_flatten_with_path
29-
from maxtext import maxtext_utils
3029
from maxtext.configs import pyconfig
31-
30+
from maxtext.utils import maxtext_utils
3231
from maxtext.utils.globals import MAXTEXT_REPO_ROOT
3332
from maxtext.utils.sharding import _ACTIVATION_SHARDINGS_DUMP
3433

0 commit comments

Comments
 (0)