Skip to content

Commit 1b0c210

Browse files
Merge pull request #3124 from AI-Hypercomputer:aireen/input_restructure2
PiperOrigin-RevId: 869967557
2 parents 3c56dd3 + 9823136 commit 1b0c210

40 files changed

Lines changed: 179 additions & 155 deletions

src/MaxText/layers/engram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from jax.sharding import Mesh
3030
from flax import nnx
3131

32-
from MaxText.tokenizer import HFTokenizer
32+
from maxtext.input_pipeline.tokenizer import HFTokenizer
3333
from MaxText.common_types import MODEL_MODE_TRAIN, Array, Config
3434
from MaxText.layers.embeddings import Embed
3535
from MaxText.layers.initializers import nd_dense_init, NdInitializer

src/MaxText/rl/train_rl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
7777
from MaxText.rl.evaluate_rl import evaluate
7878
from MaxText.rl import utils_rl
79-
from MaxText.input_pipeline.instruction_data_processing import load_template_from_file
79+
from maxtext.input_pipeline.instruction_data_processing import load_template_from_file
8080
from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils
8181

8282

src/maxtext/common/checkpointing.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
from flax.training import train_state
2424
import jax
2525
from MaxText.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
26-
from MaxText.multihost_dataloading import MultiHostDataLoadIterator, RemoteIterator
27-
from MaxText.input_pipeline.input_pipeline_interface import PlaceHolderDataIterator
26+
from maxtext.input_pipeline.multihost_dataloading import MultiHostDataLoadIterator
27+
from maxtext.input_pipeline.multihost_dataloading import RemoteIterator
28+
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
2829
from maxtext.utils import exceptions
2930
from maxtext.utils import max_logging
3031
import numpy as np

src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ add_eos: True
559559
# If False, use chunking for long sequences instead of truncation.
560560
# Note: use_truncation=False is only available in grain's pretrain preprocessing pipeline.
561561
# See the TokenizeAndTrim and TokenizeAndChunk classes in
562-
# `src/MaxText/input_pipeline/_grain_tokenizer.py` for implementation details.
562+
# `src/maxtext/input_pipeline/_grain_tokenizer.py` for implementation details.
563563
use_truncation: True
564564

565565
# Dataset

src/maxtext/examples/sft_train_and_evaluate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,10 @@
8585

8686
from flax import nnx
8787

88-
from MaxText.globals import MAXTEXT_REPO_ROOT
8988
from MaxText import pyconfig
90-
from MaxText.input_pipeline import instruction_data_processing
89+
from MaxText.globals import MAXTEXT_REPO_ROOT
9190
from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
91+
from maxtext.input_pipeline import instruction_data_processing
9292
from maxtext.trainers.post_train.sft import train_sft
9393
from maxtext.utils import max_logging
9494
from maxtext.utils import max_utils

src/maxtext/experimental/rl/grpo_input_pipeline.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232

3333
import grain.python as grain
3434

35-
from MaxText.input_pipeline import input_pipeline_interface
36-
from MaxText.input_pipeline import _input_pipeline_utils
35+
from maxtext.input_pipeline import input_pipeline_interface
36+
from maxtext.input_pipeline import input_pipeline_utils
3737

3838

3939
class SingleHostDataLoader:
@@ -141,7 +141,7 @@ def preprocessing_pipeline(
141141
)
142142

143143
dataset = dataset.map(
144-
_input_pipeline_utils.tokenization,
144+
input_pipeline_utils.tokenization,
145145
batched=True,
146146
fn_kwargs={
147147
"hf_tokenizer": tokenizer,
@@ -151,7 +151,7 @@ def preprocessing_pipeline(
151151
},
152152
)
153153
dataset = dataset.select_columns(data_column_names)
154-
dataset = _input_pipeline_utils.HFDataSource(
154+
dataset = input_pipeline_utils.HFDataSource(
155155
dataset,
156156
dataloading_host_index,
157157
dataloading_host_count,
@@ -166,7 +166,7 @@ def lists2array(x):
166166

167167
operations = [
168168
grain.MapOperation(lists2array),
169-
_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, add_true_length=True),
169+
input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, add_true_length=True),
170170
grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder),
171171
]
172172

src/maxtext/inference/inference_microbenchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
from collections.abc import MutableMapping
2424

2525
from MaxText import maxengine
26-
from MaxText import prefill_packing
2726
from MaxText import pyconfig
2827
from maxtext.common import profiler
28+
from maxtext.input_pipeline.packing import prefill_packing
2929
from maxtext.utils import gcs_utils
3030
from maxtext.utils import max_utils
3131
from maxtext.utils import maxtext_utils

src/maxtext/inference/mlperf/offline_inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
# pylint: disable=no-name-in-module
3636
from MaxText.maxengine import MaxEngine
3737
from MaxText.maxengine import set_engine_vars_from_base_engine
38-
from MaxText.prefill_packing import PrefillProcessor
39-
from MaxText.prefill_packing import BatchedPrefillProcessor
38+
from maxtext.input_pipeline.packing.prefill_packing import PrefillProcessor
39+
from maxtext.input_pipeline.packing.prefill_packing import BatchedPrefillProcessor
4040

4141
DecodeState = Any
4242
Params = Any

src/maxtext/inference/offline_engine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@
5454
from jax.experimental import mesh_utils
5555

5656
from MaxText.maxengine import MaxEngine
57-
from MaxText.prefill_packing import PrefillProcessor, BatchedPrefillProcessor
57+
from maxtext.input_pipeline.packing.prefill_packing import PrefillProcessor
58+
from maxtext.input_pipeline.packing.prefill_packing import BatchedPrefillProcessor
5859
from maxtext.utils import max_logging
5960
from maxtext.utils import max_utils
6061

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

0 commit comments

Comments (0)