Skip to content

Commit 2c3c1a2

Browse files
Merge pull request #310 from jianhan-amd:jianhan/synthetic_data_iterator
PiperOrigin-RevId: 877546452
2 parents 85ba65e + 0c92bfb commit 2c3c1a2

8 files changed

Lines changed: 712 additions & 6 deletions

File tree

README.md

100644100755
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,17 @@ After installation completes, run the training script.
262262
- For use on GPU it is recommended to enable the cudnn_te_flash attention kernel for optimal performance.
263263
- Best performance is achieved with the use of batch parallelism, which can be enabled by using the ici_fsdp_batch_parallelism axis. Note that this parallelism strategy does not support fractional batch sizes.
264264
- ici_fsdp_batch_parallelism and ici_fsdp_parallelism can be combined to allow for fractional batch sizes. However, padding is not currently supported for the cudnn_te_flash attention kernel and it is therefore required that the sequence length is divisible by the number of devices in the ici_fsdp_parallelism axis.
265+
- For benchmarking training performance on inputs with different data dimensions, without downloading or re-processing the dataset, a synthetic data iterator is supported.
266+
- Set dataset_type='synthetic' to enable the synthetic data iterator; set synthetic_num_samples=null for an infinite stream, or a number to limit the sample count.
267+
- The following overrides on data dimensions are supported:
268+
- synthetic_override_height: 720
269+
- synthetic_override_width: 1280
270+
- synthetic_override_num_frames: 85
271+
- synthetic_override_max_sequence_length: 512
272+
- synthetic_override_text_embed_dim: 4096
273+
- synthetic_override_num_channels_latents: 16
274+
- synthetic_override_vae_scale_factor_spatial: 8
275+
- synthetic_override_vae_scale_factor_temporal: 4
265276

266277
You should eventually see training output such as:
267278

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,20 @@ allow_split_physical_axes: False
179179
# Replace with dataset path or train_data_dir. One has to be set.
180180
dataset_name: 'diffusers/pokemon-gpt4-captions'
181181
train_split: 'train'
182-
dataset_type: 'tf'
182+
dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic'
183+
# ==============================================================================
184+
# Synthetic Data Configuration (only used when dataset_type='synthetic')
185+
# ==============================================================================
186+
# To use synthetic data for testing/debugging without real datasets:
187+
# 1. Set dataset_type: 'synthetic' above
188+
# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000)
189+
# 3. Optionally override dimensions
190+
#
191+
# synthetic_num_samples: null # null for infinite, or set a number
192+
#
193+
# Optional dimension overrides:
194+
# resolution: 512
195+
# ==============================================================================
183196
cache_latents_text_encoder_outputs: True
184197
# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
185198
# it only applies to small datasets that fit in memory

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,28 @@ allow_split_physical_axes: False
201201
# Replace with dataset path or train_data_dir. One has to be set.
202202
dataset_name: 'diffusers/pokemon-gpt4-captions'
203203
train_split: 'train'
204-
dataset_type: 'tfrecord'
204+
dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic'
205+
# ==============================================================================
206+
# Synthetic Data Configuration (only used when dataset_type='synthetic')
207+
# ==============================================================================
208+
# To use synthetic data for testing/debugging without real datasets:
209+
# 1. Set dataset_type: 'synthetic' above
210+
# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000)
211+
# 3. Optionally override dimensions with synthetic_override_* flags below
212+
#
213+
# synthetic_num_samples: null # null for infinite, or set a number
214+
#
215+
# Optional dimension overrides (comment out to use pipeline/config values):
216+
# synthetic_override_height: 720
217+
# synthetic_override_width: 1280
218+
# synthetic_override_num_frames: 121
219+
# synthetic_override_max_sequence_length: 512
220+
# synthetic_override_text_embed_dim: 4096
221+
# synthetic_override_num_channels_latents: 16
222+
# synthetic_override_vae_scale_factor_spatial: 8
223+
# synthetic_override_vae_scale_factor_temporal: 4
224+
# ==============================================================================
225+
205226
cache_latents_text_encoder_outputs: True
206227
# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
207228
# it only applies to small datasets that fit in memory

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from maxdiffusion.input_pipeline import _hf_data_processing
2424
from maxdiffusion.input_pipeline import _grain_data_processing
2525
from maxdiffusion.input_pipeline import _tfds_data_processing
26+
from maxdiffusion.input_pipeline import synthetic_data_iterator
2627
from maxdiffusion import multihost_dataloading
2728
from maxdiffusion.maxdiffusion_utils import tokenize_captions, transform_images, vae_apply
2829
from maxdiffusion.dreambooth.dreambooth_constants import (
@@ -54,8 +55,9 @@ def make_data_iterator(
5455
feature_description=None,
5556
prepare_sample_fn=None,
5657
is_training=True,
58+
pipeline=None,
5759
):
58-
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""
60+
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord, grain, synthetic)"""
5961

6062
if config.dataset_type == "hf" or config.dataset_type == "tf":
6163
if tokenize_fn is None or image_transforms_fn is None:
@@ -110,8 +112,16 @@ def make_data_iterator(
110112
prepare_sample_fn,
111113
is_training,
112114
)
115+
elif config.dataset_type == "synthetic":
116+
return synthetic_data_iterator.make_synthetic_iterator(
117+
config=config,
118+
mesh=mesh,
119+
global_batch_size=global_batch_size,
120+
pipeline=pipeline,
121+
is_training=is_training,
122+
)
113123
else:
114-
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"
124+
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain, synthetic)"
115125

116126

117127
def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, vae, vae_params):

0 commit comments

Comments
 (0)