Skip to content

Commit ac6ba4a

Browse files
committed
Move layers and models into new folders
1 parent f70f5c8 commit ac6ba4a

80 files changed

Lines changed: 349 additions & 344 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

src/MaxText/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@
3636
from jax.sharding import Mesh
3737

3838
from MaxText import pyconfig
39-
from MaxText.layers import models
39+
from maxtext.models import models
4040
from maxtext.trainers.post_train.dpo import dpo_utils
4141
from maxtext.utils import maxtext_utils
4242
from maxtext.utils import model_creation_utils

src/MaxText/generate_param_only_checkpoint.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,8 @@
3535
from MaxText import optimizers
3636
from MaxText import pyconfig
3737
from MaxText.common_types import DecoderBlockType, MODEL_MODE_TRAIN
38-
from MaxText.layers import models, quantizations
38+
from maxtext.models import models
39+
from maxtext.layers import quantizations
3940
from maxtext.common import checkpointing
4041
from maxtext.utils import gcs_utils
4142
from maxtext.utils import lora_utils

src/MaxText/integration/tunix/tunix_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525

2626
from jax import Array
2727
from flax import nnx
28-
from MaxText.layers.models import Transformer
28+
from maxtext.models.models import Transformer
2929
from MaxText.integration.tunix.utils import VllmWeightMapping
3030
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS # pylint: disable=ungrouped-imports
3131

src/MaxText/layerwise_quantization.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,8 @@
4545

4646
from MaxText import common_types
4747
from MaxText import pyconfig
48-
from MaxText.layers import models, quantizations, deepseek
48+
from maxtext.models import models, deepseek
49+
from maxtext.layers import quantizations
4950
from maxtext.common import checkpointing
5051
from maxtext.utils import max_logging
5152
from maxtext.utils import max_utils

src/MaxText/maxengine.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,8 @@
3939
from MaxText import pyconfig
4040
from MaxText.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE
4141
from MaxText.globals import MAXTEXT_PKG_DIR
42-
from MaxText.layers import models, quantizations
42+
from maxtext.models import models
43+
from maxtext.layers import quantizations
4344
from maxtext.inference import inference_utils
4445
from maxtext.inference.page_manager import PageManager, PageState
4546
from maxtext.multimodal import processor as mm_processor

src/MaxText/train.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,15 +40,14 @@
4040

4141
from MaxText import pyconfig
4242
from MaxText import sharding
43-
from MaxText.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
4443
from MaxText.common_types import ShardMode
4544
from MaxText.globals import EPS
4645
# Placeholder: internal
4746

4847
from MaxText.gradient_accumulation import gradient_accumulation_loss_and_grad
4948
from MaxText.vocabulary_tiling import vocab_tiling_linen_loss
5049
# pylint: disable=too-many-positional-arguments
51-
50+
from maxtext.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
5251
from maxtext.common import checkpointing, profiler
5352
from maxtext.common.goodput import (
5453
GoodputEvent,

src/MaxText/train_compile.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -41,8 +41,8 @@
4141
from MaxText import pyconfig
4242
from MaxText import sharding
4343
from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode
44-
from MaxText.layers import models
45-
from MaxText.layers import quantizations
44+
from maxtext.models import models
45+
from maxtext.layers import quantizations
4646
from maxtext.utils import gcs_utils
4747
from maxtext.utils import max_utils
4848
from maxtext.utils import maxtext_utils

src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -53,8 +53,8 @@
5353
from MaxText import pyconfig
5454
from MaxText.common_types import MODEL_MODE_TRAIN
5555
from MaxText.globals import MAXTEXT_PKG_DIR
56-
from MaxText.layers import quantizations
57-
from MaxText.layers.models import transformer_as_linen
56+
from maxtext.layers import quantizations
57+
from maxtext.models.models import transformer_as_linen
5858
from maxtext.common import checkpointing
5959
from maxtext.utils import max_logging
6060
from maxtext.utils import maxtext_utils

src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -77,11 +77,12 @@
7777
from orbax.checkpoint import type_handlers
7878
from MaxText import pyconfig
7979
from MaxText.common_types import MODEL_MODE_TRAIN
80-
from MaxText.layers import models, quantizations
8180
from maxtext.checkpoint_conversion.standalone_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint
8281
from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
8382
from maxtext.checkpoint_conversion.utils.utils import apply_hook_fns, HF_IDS, print_ram_usage, get_hf_model, MemoryMonitorTqdm, print_peak_memory, validate_and_filter_param_map_keys
8483
from maxtext.inference.inference_utils import str2bool
84+
from maxtext.models import models
85+
from maxtext.layers import quantizations
8586
from maxtext.utils import max_logging
8687
from maxtext.utils import max_utils
8788
from maxtext.utils import maxtext_utils

src/maxtext/inference/kvcache.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,8 +26,8 @@
2626
from aqt.jax.v2.aqt_tensor import QTensor as KVTensor
2727
from aqt.jax.v2.flax import aqt_flax
2828

29-
from MaxText.layers import nnx_wrappers
30-
from MaxText.layers.initializers import variable_to_logically_partitioned
29+
from maxtext.layers import nnx_wrappers
30+
from maxtext.layers.initializers import variable_to_logically_partitioned
3131

3232
from MaxText.common_types import Array, AxisNames, AxisIdxes, Config, CACHE_BATCH_PREFILL, DType, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, CACHE_HEADS_NONE, DECODING_ACTIVE_SEQUENCE_INDICATOR
3333
from MaxText.common_types import CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV, CACHE_SCALE_BATCH, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV

0 commit comments

Comments (0)