
Commit 82e37c1

refactor tfrecord function and change eval loss name
1 parent 5ec6452 commit 82e37c1

2 files changed: 20 additions & 79 deletions


src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 17 additions & 76 deletions
@@ -78,105 +78,45 @@ def make_tf_iterator(
   train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
   return train_iter
 
-def make_cached_tfrecord_iterator(
-    dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training: bool
-):
-  """
-  New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
-  latents, input_ids, prompt_embeds, and text_embeds.
-  """
-
-  def _parse_tfrecord_fn(example):
-    return tf.io.parse_single_example(example, feature_description)
-
-  # This pipeline reads the sharded files and applies the parsing and preparation.
-  filenames = tf.io.gfile.glob(os.path.join(dataset_path, "*"))
-  ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
-
-  # --- PADDING LOGIC FOR EVALUATION ---
-  if not is_training:
-    num_eval_samples = 0
-    for _ in ds:
-      num_eval_samples += 1
-
-    remainder = num_eval_samples % global_batch_size
-    if remainder != 0:
-      num_to_pad = global_batch_size - remainder
-      # Create a dataset of padding samples from the beginning
-      padding_ds = ds.take(num_to_pad)
-      # Add the padding samples to the end
-      ds = ds.concatenate(padding_ds)
-      print(f"Padded evaluation dataset with {num_to_pad} samples.")
-
-  ds = (
-      ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
-      .map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
-  )
-  if is_training:
-    ds = (
-        ds.shuffle(global_batch_size * 10)
-        .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
-        .repeat(-1)
-        .prefetch(AUTOTUNE)
-    )
-  # For Evaluation
-  else:
-    ds = (
-        ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False)
-        .prefetch(AUTOTUNE)
-    )
-
-  # This wraps the tf.data.Dataset for use in the multi-host JAX environment.
-  iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
-  return iter
-
-
 # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
 def _make_tfrecord_iterator(
-    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training: bool
+    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description_fn, prepare_sample_fn, dataset_path, is_training: bool
 ):
   # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
   # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
   # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
+  # if is_training is True, loads the training dataset. If False, loads the evaluation dataset.
 
   # checks that the dataset path is valid. In case of gcs, the existance of the dir is not checked.
-  # if is_training is True, loads the training dataset. If False, loads the evaluation dataset.
   is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)
 
-  if (
-      config.cache_latents_text_encoder_outputs
-      and is_dataset_dir_valid
-      and "load_tfrecord_cached" in config.get_keys()
-      and config.load_tfrecord_cached
-  ):
-    return make_cached_tfrecord_iterator(
-        dataloading_host_index,
-        dataloading_host_count,
-        mesh,
-        global_batch_size,
-        feature_description,
-        prepare_sample_fn,
-        dataset_path,
-        is_training
-    )
+  # Determine whether to use the "cached" dataset, which requires externally
+  # provided parsing functions, or the default one with its internal parsing logic.
+  make_cached_tfrecord_iterator = (
+      config.cache_latents_text_encoder_outputs
+      and is_dataset_dir_valid
+      and "load_tfrecord_cached" in config.get_keys()
+      and config.load_tfrecord_cached
+  )
 
   feature_description = {
       "moments": tf.io.FixedLenFeature([], tf.string),
      "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
   }
 
+  used_feature_description = feature_description_fn if make_cached_tfrecord_iterator else feature_description
+
   def _parse_tfrecord_fn(example):
-    return tf.io.parse_single_example(example, feature_description)
+    return tf.io.parse_single_example(example, used_feature_description)
 
   def prepare_sample(features):
     moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
     clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
     return {"pixel_values": moments, "input_ids": clip_embeddings}
 
   filenames = tf.io.gfile.glob(os.path.join(dataset_path, "*"))
-
   ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
+
   # --- PADDING LOGIC FOR EVALUATION ---
   if not is_training:
     num_eval_samples = 0
@@ -191,11 +131,12 @@ def prepare_sample(features):
       # Add the padding samples to the end
       ds = ds.concatenate(padding_ds)
       print(f"Padded evaluation dataset with {num_to_pad} samples.")
-
+
+  used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
   ds = (
       ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
       .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
-      .map(prepare_sample, num_parallel_calls=AUTOTUNE)
+      .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
   )
   if is_training:
     ds = (
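
The most delicate piece of the pipeline above is the eval-padding step: batch(..., drop_remainder=False) would otherwise emit a short final batch, so the dataset is first counted and then topped up with samples recycled from its head until the total is an exact multiple of global_batch_size. A minimal standalone sketch of that idea (the toy dataset and batch size are illustrative, not taken from this repo):

import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE
global_batch_size = 4

# Toy stand-in for the parsed TFRecord dataset: 10 samples.
ds = tf.data.Dataset.range(10)

# Counting needs a full pass over the data, which is why the production
# code only does this for the (smaller) evaluation split.
num_eval_samples = 0
for _ in ds:
  num_eval_samples += 1

remainder = num_eval_samples % global_batch_size
if remainder != 0:
  num_to_pad = global_batch_size - remainder
  # Recycle samples from the beginning as padding at the end.
  ds = ds.concatenate(ds.take(num_to_pad))

ds = ds.batch(global_batch_size, drop_remainder=False).prefetch(AUTOTUNE)
for batch in ds:
  print(batch.numpy())  # three full batches; the last one repeats samples 0 and 1

Note that the recycled samples are counted twice by any metric averaged over these batches, so the padding trades a small bias in the eval loss for full, uniformly shaped batches across hosts.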

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 3 additions & 3 deletions
@@ -251,7 +251,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
           with mesh:
             eval_batch = load_next_batch(eval_data_iterator, None, self.config)
             metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
-            eval_metrics.append(metrics["scalar"]["eval/loss"])
+            eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
         except StopIteration:
           # This block is executed when the iterator has no more data
           break
@@ -260,7 +260,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       eval_loss = jnp.mean(jnp.array(eval_metrics))
       max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
       if writer:
-        writer.add_scalar("eval/loss", eval_loss, step)
+        writer.add_scalar("learning/eval_loss", eval_loss, step)
     else:
       max_logging.log(f"Step {step}, evaluation dataset was empty.")
     example_batch = next_batch_future.result()
@@ -373,7 +373,7 @@ def loss_fn(params):
     loss = loss_fn(state.params)
 
     # Structure the metrics for logging and aggregation
-    metrics = {"scalar": {"eval/loss": loss}}
+    metrics = {"scalar": {"learning/eval_loss": loss}}
 
     # Return the computed metrics and the new RNG key for the next eval step
     return metrics, new_rng
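
The trainer change is purely a rename of the evaluation metric tag, from eval/loss to learning/eval_loss, applied at the three places the tag appears: where the eval step produces the per-batch metric, where the loop collects it, and where the aggregate is written. TensorBoard groups scalars by the prefix before the first slash, so the new name presumably files the eval loss alongside the other learning/* metrics. A hedged sketch of the collect-aggregate-log pattern (the loss values, writer, and step counter are stand-ins):

import jax.numpy as jnp

step = 100  # illustrative step counter

# Per-batch metrics as an eval step would produce them, over three eval batches.
eval_metrics = []
for loss in (0.31, 0.27, 0.29):
  metrics = {"scalar": {"learning/eval_loss": jnp.asarray(loss)}}
  eval_metrics.append(metrics["scalar"]["learning/eval_loss"])

if eval_metrics:
  eval_loss = jnp.mean(jnp.array(eval_metrics))
  print(f"Step {step}, Eval loss: {eval_loss:.4f}")
  # With a real summary writer this would be:
  # writer.add_scalar("learning/eval_loss", eval_loss, step)

Because the producer and the consumers must agree on the exact tag string, the rename has to land in all three sites in one commit; a mismatch would raise a KeyError in the collection loop.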
