
Commit f70f5c8

Merge pull request #3149 from AI-Hypercomputer:aireen/syn_np
PiperOrigin-RevId: 872121553
2 parents a89eb2a + c353668

3 files changed: 18 additions & 27 deletions


src/maxtext/input_pipeline/grain_data_processing.py

Lines changed: 1 addition & 1 deletion
@@ -238,7 +238,7 @@ def pretrain_preprocessing_pipeline(
   # global_batch_size_to_load has been expanded in pyconfig.py when expansion_factor_real_data > 1.
   # But when using Grain, we want to keep the batch_size consistent with that in the checkpoint.
   # We revert the batch_size expansion here, but load multiple batches per step in multihost_dataloading.py.
-  batch_size = batch_size // config.expansion_factor_real_data
+  batch_size = int(batch_size // config.expansion_factor_real_data)
 
   if config.packing:
     length_struct = {col: config.max_target_length for col in data_columns}
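Why the cast: Python's floor division returns a float whenever either operand is a float, so if `expansion_factor_real_data` is parsed from the config as a float, `batch_size` would silently become a float and could break downstream code that expects an integer batch size. A minimal sketch of the failure mode, using a hypothetical config value of 2.0:

```python
# Floor division with a float operand yields a float, not an int.
batch_size = 1024
expansion_factor_real_data = 2.0  # hypothetical: numeric config fields often parse as floats

print(batch_size // expansion_factor_real_data)       # 512.0 -- a float
print(int(batch_size // expansion_factor_real_data))  # 512   -- the cast restores an int
```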

src/maxtext/input_pipeline/multihost_dataloading.py

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ def _get_next_batch_sharded(self) -> jax.Array:
     # expansion_loading_factor_for_grain times to get the
     # right batch_size for the host that is loading real data.
     local_data_list = [local_data]
-    for _ in range(1, self.expansion_loading_factor_for_grain):
+    for _ in range(1, int(self.expansion_loading_factor_for_grain)):
       next_batch = next(self.local_iterator)
       local_data_list.append(next_batch)
     local_data = jtu.tree_map(lambda *xs: np.concatenate(xs, axis=0), *local_data_list)
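The same class of fix: `range()` rejects float arguments outright (`TypeError: 'float' object cannot be interpreted as an integer`), so the cast is required whenever `expansion_loading_factor_for_grain` can be a float. The sketch below is a self-contained approximation of this loop's load-and-concatenate pattern; the stub `local_iterator` and the (4, 8) shapes are made up, standing in for `self.local_iterator` and real batches:

```python
import numpy as np
import jax.tree_util as jtu

expansion_loading_factor_for_grain = 2.0  # hypothetical float-valued config

def local_iterator():
  # Stub for self.local_iterator: yields dict batches of shape (4, 8).
  while True:
    yield {"inputs": np.zeros((4, 8), np.int32),
           "targets": np.zeros((4, 8), np.int32)}

it = local_iterator()
local_data_list = [next(it)]
# Without int(), range() raises TypeError on the float factor above.
for _ in range(1, int(expansion_loading_factor_for_grain)):
  local_data_list.append(next(it))

# Concatenate leaf-wise along the batch axis, as the diff's tree_map does.
local_data = jtu.tree_map(lambda *xs: np.concatenate(xs, axis=0), *local_data_list)
print(local_data["inputs"].shape)  # (8, 8): two (4, 8) batches stacked
```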

src/maxtext/input_pipeline/synthetic_data_processing.py

Lines changed: 16 additions & 25 deletions
@@ -19,8 +19,6 @@
 
 import numpy as np
 
-import tensorflow as tf
-
 import jax
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec as P
@@ -100,26 +98,19 @@ def reset(self):
   @staticmethod
   def get_place_holder_synthetic_data(config: pyconfig.HyperParameters):
     """fill negative value in synthetic data"""
-    output = {}
-    output["inputs"] = tf.data.Dataset.from_tensor_slices(
-        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32)
-    )
-    output["inputs_position"] = tf.data.Dataset.from_tensor_slices(
-        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32)
-    )
-    output["inputs_segmentation"] = tf.data.Dataset.from_tensor_slices(
-        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32)
-    )
-    output["targets"] = tf.data.Dataset.from_tensor_slices(
-        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32)
-    )
-    output["targets_position"] = tf.data.Dataset.from_tensor_slices(
-        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32)
-    )
-    output["targets_segmentation"] = tf.data.Dataset.from_tensor_slices(
-        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32)
-    )
-    dataset = tf.data.Dataset.zip((output))  # pytype: disable=wrong-arg-types
-    dataset = dataset.repeat()
-    dataset = dataset.batch(config.global_batch_size_to_load // jax.process_count())
-    return dataset
+    batch_size = config.global_batch_size_to_load // jax.process_count()
+    neg_ones = np.full((batch_size, config.max_target_length), -1, dtype=np.int32)
+    batch = {
+        "inputs": neg_ones,
+        "inputs_position": neg_ones,
+        "inputs_segmentation": neg_ones,
+        "targets": neg_ones,
+        "targets_position": neg_ones,
+        "targets_segmentation": neg_ones,
+    }
+
+    def infinite_iterator():
+      while True:
+        yield batch
+
+    return infinite_iterator()
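The net effect of this file's change: the synthetic placeholder no longer needs TensorFlow at all. Instead of six single-example `tf.data.Dataset`s being zipped, repeated, and batched, one pre-batched NumPy dict is yielded by an infinite Python generator, and the dtype moves from `jax.numpy.int32` to plain `np.int32`. A standalone sketch of the new pattern; `batch_size` and `max_target_length` here are stand-ins for `config.global_batch_size_to_load // jax.process_count()` and `config.max_target_length`:

```python
import numpy as np

batch_size, max_target_length = 8, 16  # stand-ins for the config-derived values

# One shared (batch, length) array of -1s; the placeholder is read-only,
# so every column can alias the same buffer.
neg_ones = np.full((batch_size, max_target_length), -1, dtype=np.int32)
batch = {col: neg_ones for col in (
    "inputs", "inputs_position", "inputs_segmentation",
    "targets", "targets_position", "targets_segmentation")}

def infinite_iterator():
  # Mirrors the old dataset.repeat(): next() never raises StopIteration.
  while True:
    yield batch

it = infinite_iterator()
print(next(it)["inputs"].shape)  # (8, 16)
```

One behavioral point worth noting: the generator yields the same dict object every step, so consumers must treat batches as read-only, which holds for fixed placeholder data.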
