Skip to content

Commit 9823136

Browse files
committed
restructure input_pipeline
1 parent b4fd8ac commit 9823136

39 files changed

Lines changed: 171 additions & 163 deletions

src/MaxText/layers/engram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from jax.sharding import Mesh
3030
from flax import nnx
3131

32-
from MaxText.tokenizer import HFTokenizer
32+
from maxtext.input_pipeline.tokenizer import HFTokenizer
3333
from MaxText.common_types import MODEL_MODE_TRAIN, Array, Config
3434
from MaxText.layers.embeddings import Embed
3535
from MaxText.layers.initializers import nd_dense_init, NdInitializer

src/MaxText/rl/train_rl.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
7878
from MaxText.rl.evaluate_rl import evaluate
7979
from MaxText.rl import utils_rl
80-
from MaxText.input_pipeline.instruction_data_processing import load_template_from_file
80+
from maxtext.input_pipeline.instruction_data_processing import load_template_from_file
8181
from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils
8282

8383

@@ -370,7 +370,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
370370
max_logging.log("Creating policy model with same config as reference model on trainer mesh")
371371
actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
372372

373-
374373
if trainer_config.debug.rl:
375374
max_logging.log("Policy Model initialized successfully")
376375
nnx.display(actor_model)

src/maxtext/common/checkpointing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from flax.training import train_state
2424
import jax
2525
from MaxText.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
26-
from MaxText.multihost_dataloading import MultiHostDataLoadIterator, RemoteIterator
27-
from MaxText.input_pipeline.input_pipeline_interface import PlaceHolderDataIterator
26+
from maxtext.input_pipeline.multihost_dataloading import MultiHostDataLoadIterator, RemoteIterator
27+
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
2828
from maxtext.utils import exceptions
2929
from maxtext.utils import max_logging
3030
import numpy as np

src/maxtext/examples/sft_train_and_evaluate.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,12 @@
8585

8686
from flax import nnx
8787

88-
from MaxText.globals import MAXTEXT_REPO_ROOT
8988
from MaxText import pyconfig
90-
from MaxText.input_pipeline import instruction_data_processing
89+
from MaxText.globals import MAXTEXT_REPO_ROOT
9190
from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
91+
from maxtext.input_pipeline import instruction_data_processing
9292
from maxtext.trainers.post_train.sft import train_sft
93-
from maxtext.utils import max_logging
94-
from maxtext.utils import max_utils
93+
from maxtext.utils import max_logging, max_utils
9594

9695
# Suppress vLLM logging with a severity level below ERROR
9796
os.environ["VLLM_LOGGING_LEVEL"] = "ERROR"

src/maxtext/experimental/rl/grpo_input_pipeline.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232

3333
import grain.python as grain
3434

35-
from MaxText.input_pipeline import input_pipeline_interface
36-
from MaxText.input_pipeline import _input_pipeline_utils
35+
from maxtext.input_pipeline import input_pipeline_interface
36+
from maxtext.input_pipeline import input_pipeline_utils
3737

3838

3939
class SingleHostDataLoader:
@@ -141,7 +141,7 @@ def preprocessing_pipeline(
141141
)
142142

143143
dataset = dataset.map(
144-
_input_pipeline_utils.tokenization,
144+
input_pipeline_utils.tokenization,
145145
batched=True,
146146
fn_kwargs={
147147
"hf_tokenizer": tokenizer,
@@ -151,7 +151,7 @@ def preprocessing_pipeline(
151151
},
152152
)
153153
dataset = dataset.select_columns(data_column_names)
154-
dataset = _input_pipeline_utils.HFDataSource(
154+
dataset = input_pipeline_utils.HFDataSource(
155155
dataset,
156156
dataloading_host_index,
157157
dataloading_host_count,
@@ -166,7 +166,7 @@ def lists2array(x):
166166

167167
operations = [
168168
grain.MapOperation(lists2array),
169-
_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, add_true_length=True),
169+
input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, add_true_length=True),
170170
grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder),
171171
]
172172

src/maxtext/inference/inference_microbenchmark.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,10 @@
2222
from absl import app
2323
from collections.abc import MutableMapping
2424

25-
from MaxText import maxengine
26-
from MaxText import prefill_packing
27-
from MaxText import pyconfig
25+
from MaxText import maxengine, pyconfig
2826
from maxtext.common import profiler
29-
from maxtext.utils import gcs_utils
30-
from maxtext.utils import max_utils
31-
from maxtext.utils import maxtext_utils
27+
from maxtext.input_pipeline.packing import prefill_packing
28+
from maxtext.utils import gcs_utils, max_utils, maxtext_utils
3229

3330
import warnings
3431

src/maxtext/inference/mlperf/offline_inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
# pylint: disable=no-name-in-module
3636
from MaxText.maxengine import MaxEngine
3737
from MaxText.maxengine import set_engine_vars_from_base_engine
38-
from MaxText.prefill_packing import PrefillProcessor
39-
from MaxText.prefill_packing import BatchedPrefillProcessor
38+
from maxtext.input_pipeline.packing.prefill_packing import PrefillProcessor
39+
from maxtext.input_pipeline.packing.prefill_packing import BatchedPrefillProcessor
4040

4141
DecodeState = Any
4242
Params = Any

src/maxtext/inference/offline_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from jax.experimental import mesh_utils
5555

5656
from MaxText.maxengine import MaxEngine
57-
from MaxText.prefill_packing import PrefillProcessor, BatchedPrefillProcessor
57+
from maxtext.input_pipeline.packing.prefill_packing import PrefillProcessor, BatchedPrefillProcessor
5858
from maxtext.utils import max_logging
5959
from maxtext.utils import max_utils
6060

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

src/MaxText/input_pipeline/_distillation_data_processing.py renamed to src/maxtext/input_pipeline/distillation_data_processing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from dataclasses import dataclass, field
2525

26-
from MaxText.input_pipeline import _input_pipeline_utils
26+
from maxtext.input_pipeline import input_pipeline_utils
2727
from maxtext.utils import max_logging
2828

2929

@@ -83,7 +83,7 @@ def process_dataset(config, dataset): # pylint: disable=redefined-outer-name
8383
assert any(
8484
set(data_column_names) == set(supported) for supported in supported_columns
8585
), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_column_names}"
86-
assert _input_pipeline_utils.is_conversational(
86+
assert input_pipeline_utils.is_conversational(
8787
dataset.features, data_column_names
8888
), "Dataset is not in conversational format."
8989

0 commit comments

Comments (0)