Skip to content

Commit 9b1e9df

Browse files
committed
feat: add general synthetic data iterator and examples for WAN and FLUX.
1 parent e8bdd82 commit 9b1e9df

6 files changed

Lines changed: 579 additions & 6 deletions

File tree

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,20 @@ allow_split_physical_axes: False
177177
# Replace with dataset path or train_data_dir. One has to be set.
178178
dataset_name: 'diffusers/pokemon-gpt4-captions'
179179
train_split: 'train'
180-
dataset_type: 'tf'
180+
dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic'
181+
# ==============================================================================
182+
# Synthetic Data Configuration (only used when dataset_type='synthetic')
183+
# ==============================================================================
184+
# To use synthetic data for testing/debugging without real datasets:
185+
# 1. Set dataset_type: 'synthetic' above
186+
# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000)
187+
# 3. Optionally override dimensions
188+
#
189+
# synthetic_num_samples: null # null for infinite, or set a number
190+
#
191+
# Optional dimension overrides:
192+
# resolution: 512
193+
# ==============================================================================
181194
cache_latents_text_encoder_outputs: True
182195
# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
183196
# applies only to small datasets that fit in memory

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,28 @@ allow_split_physical_axes: False
199199
# Replace with dataset path or train_data_dir. One has to be set.
200200
dataset_name: 'diffusers/pokemon-gpt4-captions'
201201
train_split: 'train'
202-
dataset_type: 'tfrecord'
202+
dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic'
203+
# ==============================================================================
204+
# Synthetic Data Configuration (only used when dataset_type='synthetic')
205+
# ==============================================================================
206+
# To use synthetic data for testing/debugging without real datasets:
207+
# 1. Set dataset_type: 'synthetic' above
208+
# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000)
209+
# 3. Optionally override dimensions with synthetic_override_* flags below
210+
#
211+
# synthetic_num_samples: null # null for infinite, or set a number
212+
#
213+
# Optional dimension overrides (comment out to use pipeline/config values):
214+
# synthetic_override_height: 720
215+
# synthetic_override_width: 1280
216+
# synthetic_override_num_frames: 121
217+
# synthetic_override_max_sequence_length: 512
218+
# synthetic_override_text_embed_dim: 4096
219+
# synthetic_override_num_channels_latents: 16
220+
# synthetic_override_vae_scale_factor_spatial: 8
221+
# synthetic_override_vae_scale_factor_temporal: 4
222+
# ==============================================================================
223+
203224
cache_latents_text_encoder_outputs: True
204225
# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
205226
# applies only to small datasets that fit in memory

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from maxdiffusion.input_pipeline import _hf_data_processing
2424
from maxdiffusion.input_pipeline import _grain_data_processing
2525
from maxdiffusion.input_pipeline import _tfds_data_processing
26+
from maxdiffusion.input_pipeline import synthetic_data_iterator
2627
from maxdiffusion import multihost_dataloading
2728
from maxdiffusion.maxdiffusion_utils import tokenize_captions, transform_images, vae_apply
2829
from maxdiffusion.dreambooth.dreambooth_constants import (
@@ -54,8 +55,9 @@ def make_data_iterator(
5455
feature_description=None,
5556
prepare_sample_fn=None,
5657
is_training=True,
58+
pipeline=None,
5759
):
58-
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""
60+
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord, grain, synthetic)"""
5961

6062
if config.dataset_type == "hf" or config.dataset_type == "tf":
6163
if tokenize_fn is None or image_transforms_fn is None:
@@ -110,8 +112,16 @@ def make_data_iterator(
110112
prepare_sample_fn,
111113
is_training,
112114
)
115+
elif config.dataset_type == "synthetic":
116+
return synthetic_data_iterator.make_synthetic_iterator(
117+
config=config,
118+
mesh=mesh,
119+
global_batch_size=global_batch_size,
120+
pipeline=pipeline,
121+
is_training=is_training,
122+
)
113123
else:
114-
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"
124+
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain, synthetic)"
115125

116126

117127
def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, vae, vae_params):

0 commit comments

Comments
 (0)