Skip to content

Commit fdef529

Browse files
Merge pull request #3054 from AI-Hypercomputer:move-maxtext-layers-2
PiperOrigin-RevId: 872641703
2 parents 5f1717b + ac6ba4a commit fdef529

80 files changed

Lines changed: 2285 additions & 1334 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/MaxText/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from jax.sharding import Mesh
3737

3838
from MaxText import pyconfig
39-
from MaxText.layers import models
39+
from maxtext.models import models
4040
from maxtext.trainers.post_train.dpo import dpo_utils
4141
from maxtext.utils import maxtext_utils
4242
from maxtext.utils import model_creation_utils

src/MaxText/generate_param_only_checkpoint.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,16 @@
2525
from typing import Sequence
2626

2727
from absl import app
28-
2928
from etils import epath
30-
3129
import jax
32-
from jax.sharding import Mesh
3330
from jax import random
34-
31+
from jax.sharding import Mesh
3532
from MaxText import optimizers
3633
from MaxText import pyconfig
37-
from MaxText.common_types import DecoderBlockType, MODEL_MODE_TRAIN
38-
from MaxText.layers import models, quantizations
3934
from maxtext.common import checkpointing
35+
from MaxText.common_types import DecoderBlockType, MODEL_MODE_TRAIN
36+
from maxtext.layers import quantizations
37+
from maxtext.models import models
4038
from maxtext.utils import gcs_utils
4139
from maxtext.utils import lora_utils
4240
from maxtext.utils import max_logging

src/MaxText/integration/tunix/tunix_adapter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121

2222
from __future__ import annotations
2323

24-
from typing import Optional, Tuple, Any
24+
from typing import Any, Optional, Tuple
2525

26-
from jax import Array
2726
from flax import nnx
28-
from MaxText.layers.models import Transformer
29-
from MaxText.integration.tunix.utils import VllmWeightMapping
27+
from jax import Array
3028
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS # pylint: disable=ungrouped-imports
29+
from MaxText.integration.tunix.utils import VllmWeightMapping
30+
from maxtext.models.models import Transformer
3131

3232

3333
class TunixMaxTextAdapter(nnx.Module):

src/MaxText/layerwise_quantization.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,24 +33,22 @@
3333
import os
3434
from typing import Any, Sequence
3535

36-
from tqdm import tqdm
37-
38-
import jax
39-
import jax.numpy as jnp
4036
from absl import app
4137
from aqt.jax.v2 import aqt_tensor
42-
43-
from flax.linen import partitioning as nn_partitioning
4438
from flax import nnx
45-
39+
from flax.linen import partitioning as nn_partitioning
40+
import jax
41+
import jax.numpy as jnp
4642
from MaxText import common_types
4743
from MaxText import pyconfig
48-
from MaxText.layers import models, quantizations, deepseek
4944
from maxtext.common import checkpointing
45+
from maxtext.layers import quantizations
46+
from maxtext.models import deepseek, models
5047
from maxtext.utils import max_logging
5148
from maxtext.utils import max_utils
5249
from maxtext.utils import maxtext_utils
5350
import orbax.checkpoint as ocp
51+
from tqdm import tqdm
5452

5553
IGNORE = ocp.PLACEHOLDER
5654
PRNGKeyType = Any

src/MaxText/maxengine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@
3939
from MaxText import pyconfig
4040
from MaxText.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE
4141
from MaxText.globals import MAXTEXT_PKG_DIR
42-
from MaxText.layers import models, quantizations
42+
from maxtext.models import models
43+
from maxtext.layers import quantizations
4344
from maxtext.inference import inference_utils
4445
from maxtext.inference.page_manager import PageManager, PageState
4546
from maxtext.multimodal import processor as mm_processor

src/MaxText/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,14 @@
4040

4141
from MaxText import pyconfig
4242
from MaxText import sharding
43-
from MaxText.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
4443
from MaxText.common_types import ShardMode
4544
from MaxText.globals import EPS
4645
# Placeholder: internal
4746

4847
from MaxText.gradient_accumulation import gradient_accumulation_loss_and_grad
4948
from MaxText.vocabulary_tiling import vocab_tiling_linen_loss
5049
# pylint: disable=too-many-positional-arguments
51-
50+
from maxtext.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
5251
from maxtext.common import checkpointing, profiler
5352
from maxtext.common.goodput import (
5453
GoodputEvent,

src/MaxText/train_compile.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,32 +21,29 @@
2121
as you would on the target hardware.
2222
"""
2323

24-
from typing import Sequence
24+
import functools
2525
import os
2626
import pickle
27-
import functools
27+
from typing import Sequence
2828

2929
from absl import app
30-
30+
from flax.linen import partitioning as nn_partitioning
3131
import jax
32-
from jax.experimental.topologies import get_topology_desc
33-
from jax.sharding import Mesh, AxisType
3432
from jax.experimental.serialize_executable import serialize
35-
36-
from flax.linen import partitioning as nn_partitioning
37-
33+
from jax.experimental.topologies import get_topology_desc
34+
from jax.sharding import AxisType, Mesh
3835
from MaxText import accelerator_to_spec_map
39-
from MaxText import train
4036
from MaxText import optimizers
4137
from MaxText import pyconfig
4238
from MaxText import sharding
39+
from MaxText import train
4340
from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode
44-
from MaxText.layers import models
45-
from MaxText.layers import quantizations
41+
from maxtext.layers import quantizations
42+
from maxtext.models import models
43+
from maxtext.trainers.diloco import diloco
4644
from maxtext.utils import gcs_utils
4745
from maxtext.utils import max_utils
4846
from maxtext.utils import maxtext_utils
49-
from maxtext.trainers.diloco import diloco
5047

5148
# pylint: disable=too-many-positional-arguments
5249

src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,26 +39,22 @@
3939
import os
4040
import sys
4141

42-
from psutil import Process
43-
44-
import numpy as np
45-
4642
import jax
4743
from jax import random
4844
from jax.sharding import Mesh
49-
50-
import tensorstore as ts
51-
5245
from MaxText import optimizers
5346
from MaxText import pyconfig
47+
from maxtext.common import checkpointing
5448
from MaxText.common_types import MODEL_MODE_TRAIN
5549
from MaxText.globals import MAXTEXT_PKG_DIR
56-
from MaxText.layers import quantizations
57-
from MaxText.layers.models import transformer_as_linen
58-
from maxtext.common import checkpointing
50+
from maxtext.layers import quantizations
51+
from maxtext.models.models import transformer_as_linen
5952
from maxtext.utils import max_logging
60-
from maxtext.utils import maxtext_utils
6153
from maxtext.utils import max_utils
54+
from maxtext.utils import maxtext_utils
55+
import numpy as np
56+
from psutil import Process
57+
import tensorstore as ts
6258

6359

6460
def fmt_size(num_bytes: int) -> str:

src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,32 +59,32 @@
5959
"""
6060

6161
import argparse
62+
from functools import partial
63+
import json
6264
import os
63-
import time
6465
import sys
65-
import json
6666
import threading
67-
from functools import partial
68-
from typing import Sequence, List, Any, Callable
69-
import numpy as np
67+
import time
68+
from typing import Any, Callable, List, Sequence
7069
import absl
71-
72-
from transformers import AutoConfig
70+
import flax.linen as nn
7371
from huggingface_hub import hf_hub_download, list_repo_files
74-
from safetensors import safe_open
7572
import jax
76-
import flax.linen as nn
77-
from orbax.checkpoint import type_handlers
7873
from MaxText import pyconfig
79-
from MaxText.common_types import MODEL_MODE_TRAIN
80-
from MaxText.layers import models, quantizations
8174
from maxtext.checkpoint_conversion.standalone_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint
8275
from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
83-
from maxtext.checkpoint_conversion.utils.utils import apply_hook_fns, HF_IDS, print_ram_usage, get_hf_model, MemoryMonitorTqdm, print_peak_memory, validate_and_filter_param_map_keys
76+
from maxtext.checkpoint_conversion.utils.utils import HF_IDS, MemoryMonitorTqdm, apply_hook_fns, get_hf_model, print_peak_memory, print_ram_usage, validate_and_filter_param_map_keys
77+
from MaxText.common_types import MODEL_MODE_TRAIN
8478
from maxtext.inference.inference_utils import str2bool
79+
from maxtext.layers import quantizations
80+
from maxtext.models import models
8581
from maxtext.utils import max_logging
8682
from maxtext.utils import max_utils
8783
from maxtext.utils import maxtext_utils
84+
import numpy as np
85+
from orbax.checkpoint import type_handlers
86+
from safetensors import safe_open
87+
from transformers import AutoConfig
8888

8989

9090
absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log

src/maxtext/inference/kvcache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
from aqt.jax.v2.aqt_tensor import QTensor as KVTensor
2727
from aqt.jax.v2.flax import aqt_flax
2828

29-
from MaxText.layers import nnx_wrappers
30-
from MaxText.layers.initializers import variable_to_logically_partitioned
29+
from maxtext.layers import nnx_wrappers
30+
from maxtext.layers.initializers import variable_to_logically_partitioned
3131

3232
from MaxText.common_types import Array, AxisNames, AxisIdxes, Config, CACHE_BATCH_PREFILL, DType, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, CACHE_HEADS_NONE, DECODING_ACTIVE_SEQUENCE_INDICATOR
3333
from MaxText.common_types import CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV, CACHE_SCALE_BATCH, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV

0 commit comments

Comments
 (0)