Skip to content

Commit e53ee2b

Browse files
author
Juan Acevedo
committed
fix input pipeline tests
1 parent a6bc42b commit e53ee2b

2 files changed

Lines changed: 18 additions & 2 deletions

File tree

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def make_data_iterator(
6363
config.dataset_type == "tfrecord"
6464
and config.cache_latents_text_encoder_outputs
6565
and feature_description is None
66-
or prepare_sample_fn is None
66+
and prepare_sample_fn is None
6767
):
6868
raise ValueError(
6969
f"dataset type {config.dataset_type} needs to pass a feature_description dictionary and prepare_sample_fn function when cache_latents_text_encoder_outputs is True."

src/maxdiffusion/tests/input_pipeline_interface_test.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,23 @@ def test_make_laion_tfrecord_iterator(self):
506506
from_pt=config.from_pt,
507507
)
508508

509-
train_iterator = make_data_iterator(config, jax.process_index(), jax.process_count(), mesh, global_batch_size)
509+
feature_description = {
510+
"moments": tf.io.FixedLenFeature([], tf.string),
511+
"clip_embeddings": tf.io.FixedLenFeature([], tf.string),
512+
}
513+
514+
def _parse_tfrecord_fn(example):
515+
return tf.io.parse_single_example(example, feature_description)
516+
517+
train_iterator = make_data_iterator(
518+
config,
519+
jax.process_index(),
520+
jax.process_count(),
521+
mesh,
522+
global_batch_size,
523+
feature_description=feature_description,
524+
prepare_sample_fn=_parse_tfrecord_fn,
525+
)
510526
data = next(train_iterator)
511527
device_count = jax.device_count()
512528

0 commit comments

Comments (0)