Skip to content

Commit a6bc42b

Browse files
author
Juan Acevedo
committed
lint
1 parent afc8882 commit a6bc42b

9 files changed

Lines changed: 336 additions & 334 deletions

File tree

src/maxdiffusion/data_preprocessing/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
15-
"""
15+
"""

src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
import tensorflow as tf
3636

37+
3738
def image_feature(value):
3839
"""Returns a bytes_list from a string / byte."""
3940
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()]))
@@ -58,6 +59,7 @@ def float_feature_list(value):
5859
"""Returns a list of float_list from a float / double."""
5960
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
6061

62+
6163
def create_example(latent, hidden_states):
6264
latent = tf.io.serialize_tensor(latent)
6365
hidden_states = tf.io.serialize_tensor(hidden_states)
@@ -74,11 +76,13 @@ def text_encode(pipeline, prompt: Union[str, List[str]]):
7476
encoder_hidden_states = encoder_hidden_states.detach().numpy()
7577
return encoder_hidden_states
7678

79+
7780
def vae_encode(video, rng, vae, vae_cache):
7881
latent = vae.encode(video, feat_cache=vae_cache)
7982
latent = latent.latent_dist.sample(rng)
8083
return latent
81-
84+
85+
8286
def generate_dataset(config, pipeline):
8387

8488
tfrecords_dir = config.tfrecords_dir
@@ -99,21 +103,21 @@ def generate_dataset(config, pipeline):
99103
rng = jax.random.key(config.seed)
100104

101105
vae_scale_factor_spatial = 2 ** len(pipeline.vae.temperal_downsample)
102-
video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)
103-
106+
video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)
107+
104108
# jit vae fun.
105109
p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))
106-
110+
107111
# Load dataset
108-
ds = load_dataset(config.dataset_name, split='train')
112+
ds = load_dataset(config.dataset_name, split="train")
109113
ds = ds.shuffle(seed=config.seed)
110114
ds = ds.select_columns([config.caption_column, config.image_column])
111115
batch_size = 10
112116
for i in range(0, len(ds), batch_size):
113117
rng, new_rng = jax.random.split(rng)
114-
text = ds[i:i+batch_size]['text']
115-
videos = ds[i:i+batch_size]['image']
116-
118+
text = ds[i : i + batch_size]["text"]
119+
videos = ds[i : i + batch_size]["image"]
120+
117121
videos = [video_processor.preprocess_video([video], height=config.height, width=config.width) for video in videos]
118122
video = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
119123
with mesh:
@@ -127,24 +131,23 @@ def generate_dataset(config, pipeline):
127131

128132
if shard_record_count >= no_records_per_shard:
129133
writer.close()
130-
tf_rec_num +=1
134+
tf_rec_num += 1
131135
writer = tf.io.TFRecordWriter(
132136
tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
133137
)
134138
shard_record_count = 0
135139

136140

137-
138141
def run(config):
139142
pipeline = WanPipeline.from_pretrained(config, load_transformer=False)
140143
# Don't need the transformer for preprocessing.
141144
generate_dataset(config, pipeline)
142145

143146

144-
145147
def main(argv: Sequence[str]) -> None:
146148
pyconfig.initialize(argv)
147149
run(pyconfig.config)
148150

151+
149152
if __name__ == "__main__":
150-
app.run(main)
153+
app.run(main)

src/maxdiffusion/generate_wan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from maxdiffusion.utils import export_to_video
2222

2323

24-
def run(config, pipeline=None, filename_prefix=''):
24+
def run(config, pipeline=None, filename_prefix=""):
2525
print("seed: ", config.seed)
2626
if pipeline is None:
2727
pipeline = WanPipeline.from_pretrained(config)

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,9 @@ def make_tf_iterator(
7373
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
7474
return train_iter
7575

76+
7677
def make_cached_tfrecord_iterator(
77-
config,
78-
dataloading_host_index,
79-
dataloading_host_count,
80-
mesh,
81-
global_batch_size,
82-
feature_description,
83-
prepare_sample_fn
78+
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
8479
):
8580
"""
8681
New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
@@ -111,13 +106,7 @@ def _parse_tfrecord_fn(example):
111106

112107
# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
113108
def make_tfrecord_iterator(
114-
config,
115-
dataloading_host_index,
116-
dataloading_host_count,
117-
mesh,
118-
global_batch_size,
119-
feature_description,
120-
prepare_sample_fn
109+
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
121110
):
122111
"""Iterator for TFRecord format. For Laion dataset,
123112
check out preparation script
@@ -127,18 +116,20 @@ def make_tfrecord_iterator(
127116
# set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
128117
# pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
129118
# Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
130-
if (config.cache_latents_text_encoder_outputs
119+
if (
120+
config.cache_latents_text_encoder_outputs
131121
and os.path.isdir(config.dataset_save_location)
132-
and 'load_tfrecord_cached'in config.get_keys()
133-
and config.load_tfrecord_cached):
122+
and "load_tfrecord_cached" in config.get_keys()
123+
and config.load_tfrecord_cached
124+
):
134125
return make_cached_tfrecord_iterator(
135-
config,
136-
dataloading_host_index,
137-
dataloading_host_count,
138-
mesh,
139-
global_batch_size,
140-
feature_description,
141-
prepare_sample_fn
126+
config,
127+
dataloading_host_index,
128+
dataloading_host_count,
129+
mesh,
130+
global_batch_size,
131+
feature_description,
132+
prepare_sample_fn,
142133
)
143134

144135
feature_description = {

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,23 @@ def make_data_iterator(
5151
tokenize_fn=None,
5252
image_transforms_fn=None,
5353
feature_description=None,
54-
prepare_sample_fn=None
54+
prepare_sample_fn=None,
5555
):
5656
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""
57-
57+
5858
if config.dataset_type == "hf" or config.dataset_type == "tf":
5959
if tokenize_fn is None or image_transforms_fn is None:
6060
raise ValueError(f"dataset type {config.dataset_type} needs to pass a tokenize_fn and image_transforms_fn")
61-
62-
if config.dataset_type == "tfrecord" and config.cache_latents_text_encoder_outputs and feature_description is None or prepare_sample_fn is None:
63-
raise ValueError(f"dataset type {config.dataset_type} needs to pass a feature_description dictionary and prepare_sample_fn function when cache_latents_text_encoder_outputs is True.")
61+
62+
if (
63+
config.dataset_type == "tfrecord"
64+
and config.cache_latents_text_encoder_outputs
65+
and feature_description is None
66+
or prepare_sample_fn is None
67+
):
68+
raise ValueError(
69+
f"dataset type {config.dataset_type} needs to pass a feature_description dictionary and prepare_sample_fn function when cache_latents_text_encoder_outputs is True."
70+
)
6471

6572
if config.dataset_type == "hf":
6673
return _hf_data_processing.make_hf_streaming_iterator(
@@ -98,7 +105,7 @@ def make_data_iterator(
98105
mesh,
99106
global_batch_size,
100107
feature_description,
101-
prepare_sample_fn
108+
prepare_sample_fn,
102109
)
103110
else:
104111
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
def basic_clean(text):
4242
if is_ftfy_available():
4343
import ftfy
44+
4445
text = ftfy.fix_text(text)
4546
text = html.unescape(html.unescape(text))
4647
return text.strip()
@@ -398,7 +399,7 @@ def __call__(
398399
num_channels_latents=num_channel_latents,
399400
)
400401

401-
data_sharding = NamedSharding(self.devices_array, P())
402+
data_sharding = NamedSharding(self.mesh, P())
402403
if len(prompt) % jax.device_count() == 0:
403404
data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))
404405

0 commit comments

Comments (0)