|
23 | 23 | from maxdiffusion.input_pipeline import _hf_data_processing |
24 | 24 | from maxdiffusion.input_pipeline import _grain_data_processing |
25 | 25 | from maxdiffusion.input_pipeline import _tfds_data_processing |
| 26 | +from maxdiffusion.input_pipeline import synthetic_data_iterator |
26 | 27 | from maxdiffusion import multihost_dataloading |
27 | 28 | from maxdiffusion.maxdiffusion_utils import tokenize_captions, transform_images, vae_apply |
28 | 29 | from maxdiffusion.dreambooth.dreambooth_constants import ( |
@@ -54,8 +55,9 @@ def make_data_iterator( |
54 | 55 | feature_description=None, |
55 | 56 | prepare_sample_fn=None, |
56 | 57 | is_training=True, |
| 58 | + pipeline=None, |
57 | 59 | ): |
58 | | - """Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)""" |
| 60 | + """Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord, grain, synthetic)""" |
59 | 61 |
|
60 | 62 | if config.dataset_type == "hf" or config.dataset_type == "tf": |
61 | 63 | if tokenize_fn is None or image_transforms_fn is None: |
@@ -110,8 +112,16 @@ def make_data_iterator( |
110 | 112 | prepare_sample_fn, |
111 | 113 | is_training, |
112 | 114 | ) |
| 115 | + elif config.dataset_type == "synthetic": |
| 116 | + return synthetic_data_iterator.make_synthetic_iterator( |
| 117 | + config=config, |
| 118 | + mesh=mesh, |
| 119 | + global_batch_size=global_batch_size, |
| 120 | + pipeline=pipeline, |
| 121 | + is_training=is_training, |
| 122 | + ) |
113 | 123 | else: |
114 | | - assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)" |
| 124 | + assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain, synthetic)" |
115 | 125 |
|
116 | 126 |
|
117 | 127 | def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, vae, vae_params): |
|
0 commit comments