Skip to content

Commit 9ae7e45

Browse files
committed
restructure the input pipeline folder
1 parent 3bc185f commit 9ae7e45

39 files changed

Lines changed: 1194 additions & 1275 deletions

src/MaxText/examples/demo_decoding.ipynb

Lines changed: 405 additions & 438 deletions
Large diffs are not rendered by default.

src/MaxText/examples/sft_train_and_evaluate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,10 @@
8686

8787
from flax import nnx
8888

89-
from MaxText.globals import MAXTEXT_REPO_ROOT
9089
from MaxText import pyconfig
91-
from MaxText.input_pipeline import instruction_data_processing
90+
from MaxText.globals import MAXTEXT_REPO_ROOT
9291
from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
92+
from maxtext.input_pipeline import instruction_data_processing
9393
from maxtext.trainers.post_train.sft import train_sft
9494
from maxtext.utils import max_logging
9595
from maxtext.utils import max_utils

src/MaxText/experimental/rl/grpo_input_pipeline.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434

3535
import grain.python as grain
3636

37-
from MaxText.input_pipeline import input_pipeline_interface
38-
from MaxText.input_pipeline import _input_pipeline_utils
37+
from maxtext.input_pipeline import input_pipeline_interface
38+
from maxtext.input_pipeline import input_pipeline_utils
3939

4040

4141
class SingleHostDataLoader:
@@ -143,7 +143,7 @@ def preprocessing_pipeline(
143143
)
144144

145145
dataset = dataset.map(
146-
_input_pipeline_utils.tokenization,
146+
input_pipeline_utils.tokenization,
147147
batched=True,
148148
fn_kwargs={
149149
"hf_tokenizer": tokenizer,
@@ -153,7 +153,7 @@ def preprocessing_pipeline(
153153
},
154154
)
155155
dataset = dataset.select_columns(data_column_names)
156-
dataset = _input_pipeline_utils.HFDataSource(
156+
dataset = input_pipeline_utils.HFDataSource(
157157
dataset,
158158
dataloading_host_index,
159159
dataloading_host_count,
@@ -168,7 +168,7 @@ def lists2array(x):
168168

169169
operations = [
170170
grain.MapOperation(lists2array),
171-
_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, add_true_length=True),
171+
input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, add_true_length=True),
172172
grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder),
173173
]
174174

src/MaxText/rl/train_rl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
7878
from MaxText.rl.evaluate_rl import evaluate
7979
from MaxText.rl import utils_rl
80-
from MaxText.input_pipeline.instruction_data_processing import load_template_from_file
80+
from maxtext.input_pipeline.instruction_data_processing import load_template_from_file
8181
from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils
8282

8383

src/maxtext/common/checkpointing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from flax.training import train_state
2323
import jax
2424
from MaxText.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
25-
from MaxText.multihost_dataloading import MultiHostDataLoadIterator, RemoteIterator
26-
from MaxText.input_pipeline.input_pipeline_interface import PlaceHolderDataIterator
25+
from maxtext.input_pipeline.multihost_dataloading import MultiHostDataLoadIterator, RemoteIterator
26+
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
2727
from maxtext.utils import exceptions
2828
from maxtext.utils import max_logging
2929
import numpy as np

src/maxtext/inference/inference_microbenchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
from collections.abc import MutableMapping
2424

2525
from MaxText import maxengine
26-
from MaxText import prefill_packing
2726
from MaxText import pyconfig
2827
from maxtext.common import profiler
28+
from maxtext.input_pipeline.packing import prefill_packing
2929
from maxtext.utils import gcs_utils
3030
from maxtext.utils import max_utils
3131
from maxtext.utils import maxtext_utils

src/maxtext/inference/mlperf/offline_inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
# pylint: disable=no-name-in-module
3636
from MaxText.maxengine import MaxEngine
3737
from MaxText.maxengine import set_engine_vars_from_base_engine
38-
from MaxText.prefill_packing import PrefillProcessor
39-
from MaxText.prefill_packing import BatchedPrefillProcessor
38+
from maxtext.input_pipeline.packing.prefill_packing import PrefillProcessor
39+
from maxtext.input_pipeline.packing.prefill_packing import BatchedPrefillProcessor
4040

4141
DecodeState = Any
4242
Params = Any

src/maxtext/inference/offline_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from jax.experimental import mesh_utils
5555

5656
from MaxText.maxengine import MaxEngine
57-
from MaxText.prefill_packing import PrefillProcessor, BatchedPrefillProcessor
57+
from maxtext.input_pipeline.packing.prefill_packing import PrefillProcessor, BatchedPrefillProcessor
5858
from maxtext.utils import max_logging
5959
from maxtext.utils import max_utils
6060

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023-2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

src/MaxText/input_pipeline/_distillation_data_processing.py renamed to src/maxtext/input_pipeline/distillation_data_processing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import datasets
2727

28-
from MaxText.input_pipeline import _input_pipeline_utils
28+
from maxtext.input_pipeline import input_pipeline_utils
2929
from maxtext.utils import max_logging
3030

3131

@@ -85,7 +85,7 @@ def process_dataset(config, dataset): # pylint: disable=redefined-outer-name
8585
assert any(
8686
set(data_column_names) == set(supported) for supported in supported_columns
8787
), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_column_names}"
88-
assert _input_pipeline_utils.is_conversational(
88+
assert input_pipeline_utils.is_conversational(
8989
dataset.features, data_column_names
9090
), "Dataset is not in conversational format."
9191

0 commit comments

Comments
 (0)