
Commit 570ee04

Merge pull request #3117 from AI-Hypercomputer:aireen/local_import
PiperOrigin-RevId: 867841433
2 parents 843f1f3 + 87038ef

7 files changed: 33 additions & 14 deletions


src/MaxText/experimental/rl/grpo_input_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -28,8 +28,6 @@
 import jax
 from jax.sharding import Mesh
 
-import datasets
-
 import transformers
 
 import grain.python as grain
@@ -217,6 +215,8 @@ def make_hf_train_iterator(
   Returns:
     A local data iterator for the training set.
   """
+  import datasets  # pylint: disable=import-outside-toplevel
+
   train_ds = datasets.load_dataset(
       config.hf_path,
       data_dir=config.hf_data_dir,
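
The change above is the pattern this commit applies throughout: the heavyweight datasets import moves from module scope into the function that actually needs it, so importing the module itself no longer triggers loading the Hugging Face datasets library. A minimal sketch of the idea, assuming datasets is installed; the helper name load_hf_split and its arguments are illustrative, not part of MaxText:

# Deferred import: the datasets library is only loaded on first call.
def load_hf_split(path, data_dir=None, split="train"):
  """Loads one split of a Hugging Face dataset on demand."""
  import datasets  # pylint: disable=import-outside-toplevel

  return datasets.load_dataset(path, data_dir=data_dir, split=split)

Python caches imported modules in sys.modules, so repeated calls pay the import cost only once.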

src/MaxText/input_pipeline/_distillation_data_processing.py

Lines changed: 2 additions & 2 deletions
@@ -23,8 +23,6 @@
 
 from dataclasses import dataclass, field
 
-import datasets
-
 from MaxText.input_pipeline import _input_pipeline_utils
 from maxtext.utils import max_logging
 
@@ -101,6 +99,8 @@ def process_dataset(config, dataset):  # pylint: disable=redefined-outer-name
 
 def load_dataset(config):  # pylint: disable=redefined-outer-name
   """Loads dataset from Hugging Face."""
+  import datasets  # pylint: disable=import-outside-toplevel
+
   assert config.dataset_type == "huggingface", "Only dataset from Hugging Face is supported."
 
   return datasets.load_dataset(

src/MaxText/input_pipeline/_hf_data_processing.py

Lines changed: 7 additions & 2 deletions
@@ -18,8 +18,6 @@
 
 import jax
 
-import datasets
-
 import transformers
 
 import grain.python as grain
@@ -62,6 +60,8 @@ def vision_sft_preprocessing_pipeline(
   batch_size = global_batch_size // jax.process_count()
 
   # for multi-epoch with shuffle, shuffle each epoch with different seeds then concat
+  import datasets  # pylint: disable=import-outside-toplevel
+
   if config.enable_data_shuffling and config.num_epoch > 1:
     epoch_datasets = [dataset.shuffle(seed=config.data_shuffle_seed + i) for i in range(config.num_epoch)]
     dataset = datasets.concatenate_datasets(epoch_datasets)
@@ -215,6 +215,7 @@ def preprocessing_pipeline(
     num_epoch=1,
 ):
   """pipeline for preprocessing HF dataset"""
+  import datasets  # pylint: disable=import-outside-toplevel
 
   assert global_batch_size % global_mesh.size == 0, "Batch size should be divisible by number of global devices."
   # Tunix GA requires per-micro-batch slicing at the data level,
@@ -377,6 +378,8 @@ def make_hf_train_iterator(
     process_indices_train,
 ):
   """Load, preprocess dataset and return iterators"""
+  import datasets  # pylint: disable=import-outside-toplevel
+
   train_ds = datasets.load_dataset(
       config.hf_path,
       name=config.hf_name,
@@ -433,6 +436,8 @@ def make_hf_eval_iterator(
     process_indices_eval,
 ):
   """Make Hugging Face evaluation iterator. Load and preprocess eval dataset: and return iterator."""
+  import datasets  # pylint: disable=import-outside-toplevel
+
   eval_ds = datasets.load_dataset(
       config.hf_path,
       name=config.hf_name,

src/MaxText/input_pipeline/_input_pipeline_utils.py

Lines changed: 15 additions & 6 deletions
@@ -17,9 +17,11 @@
 import dataclasses
 import warnings
 from threading import current_thread
-from typing import Any
-import datasets
-from datasets.distributed import split_dataset_by_node
+from typing import Any, TYPE_CHECKING
+
+if TYPE_CHECKING:
+  import datasets
+
 import grain.python as grain
 import numpy as np
 import tensorflow as tf
@@ -145,6 +147,8 @@ def is_conversational(features, data_columns):
     data_columns = ["prompt", "completion"]
     is_conversational(features, data_columns) returns False.
   """
+  import datasets  # pylint: disable=import-outside-toplevel
+
   for column in data_columns:
     messages = features[column]
     if isinstance(messages, datasets.Sequence):
@@ -293,13 +297,16 @@ class HFDataSource(grain.RandomAccessDataSource):
 
   def __init__(
       self,
-      dataset: datasets.IterableDataset,
+      dataset: "datasets.IterableDataset",
       dataloading_host_index: int,
       dataloading_host_count: int,
       num_threads: int,
       max_target_length: int,
       data_column_names: list[str],
   ):
+    from datasets.distributed import split_dataset_by_node  # pylint: disable=import-outside-toplevel
+
+    self._split_dataset_by_node = split_dataset_by_node
     self.dataset = dataset
     self.num_threads = num_threads
     self.dataloading_host_count = dataloading_host_count
@@ -312,7 +319,7 @@ def __init__(
     self.n_shards = 1
     self._check_shard_count()
     self.dataset_shards = [dataloading_host_index * self.num_threads + i for i in range(self.num_threads)]
-    self.datasets = [split_dataset_by_node(dataset, world_size=self.n_shards, rank=x) for x in self.dataset_shards]
+    self.datasets = [self._split_dataset_by_node(dataset, world_size=self.n_shards, rank=x) for x in self.dataset_shards]
     self.data_iters = []
 
   def _check_shard_count(self):
@@ -333,7 +340,9 @@ def _update_shard(self, idx):
       )
       max_logging.log(f"New shard is {new_shard}")
       self.dataset_shards[idx] = new_shard
-      self.datasets[idx] = split_dataset_by_node(self.dataset, world_size=self.n_shards, rank=self.dataset_shards[idx])
+      self.datasets[idx] = self._split_dataset_by_node(
+          self.dataset, world_size=self.n_shards, rank=self.dataset_shards[idx]
+      )
       self.data_iters[idx] = iter(self.datasets[idx])
     else:
       raise StopIteration(f"Run out of shards on host {self.dataloading_host_index}, shard {new_shard} is not available")
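
Two techniques appear in this file. First, the datasets import is kept only for static type checkers: it sits under a TYPE_CHECKING guard and the annotation becomes the string "datasets.IterableDataset", so the name never has to exist at runtime. Second, HFDataSource imports split_dataset_by_node once in __init__ and stores it on the instance, so later methods such as _update_shard can call it without importing again. A minimal sketch of both ideas under those assumptions; ShardedSource and its fields are hypothetical, not the real HFDataSource:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
  import datasets  # needed only by type checkers, never imported at runtime


class ShardedSource:
  """Toy example: lazy import stored on the instance for reuse."""

  def __init__(self, dataset: "datasets.IterableDataset", n_shards: int):
    # Import at construction time, then keep a reference for later calls.
    from datasets.distributed import split_dataset_by_node  # pylint: disable=import-outside-toplevel

    self._split_dataset_by_node = split_dataset_by_node
    self.shards = [self._split_dataset_by_node(dataset, world_size=n_shards, rank=r) for r in range(n_shards)]

The string annotation is what makes the guard safe: with a bare datasets.IterableDataset annotation, defining the method would raise NameError once the top-level import is gone.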

src/MaxText/input_pipeline/instruction_data_processing.py

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,6 @@
 
 """Preprocessing for instruction dataset."""
 
-import datasets
 import json
 import os
 import re
@@ -117,6 +116,8 @@ def convert_to_conversational_format(
     chat_template_path,
 ):
   """Converts instruction dataset to conversational format."""
+  import datasets  # pylint: disable=import-outside-toplevel
+
   template_config = None
   if chat_template_path:
     template_config = load_template_from_file(chat_template_path)

src/MaxText/pyconfig.py

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,9 @@
 from typing import Any
 import copy
 
+# Disable dill to avoid conflict with gfile (dill requires buffering=0, which gfile forbids)
+os.environ["HF_DATASETS_DISABLE_DILL"] = "1"
+
 import jax
 import jax.numpy as jnp
 
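
The placement matters as much as the value: pyconfig.py exports the variable near the top of the module, before the heavyweight imports, so the flag is already in the process environment when anything later imports datasets. A minimal sketch of that ordering constraint, assuming (as the comment in the diff suggests) the flag is read when the library is first loaded:

import os

# Must run before any module, directly or transitively, imports datasets.
os.environ["HF_DATASETS_DISABLE_DILL"] = "1"

import datasets  # pylint: disable=wrong-import-position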
src/maxtext/examples/sft_train_and_evaluate.py

Lines changed: 2 additions & 1 deletion
@@ -78,7 +78,6 @@
 from tqdm.auto import tqdm
 from typing import Sequence
 
-import datasets
 import grain
 import os
 import re
@@ -140,6 +139,8 @@ def get_test_dataset(config, tokenizer):
     A grain.MapDataset instance for the test split, with prompts and target
     answers.
   """
+  import datasets  # pylint: disable=import-outside-toplevel
+
   template_config = instruction_data_processing.load_template_from_file(config.chat_template_path)
   dataset = datasets.load_dataset(
       DATASET_NAME,
