1+ """
2+ Copyright 2025 Google LLC
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ https://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+ """
16+
17+ """
18+ Prepare tfrecords with latents and text embeddings preprocessed.
19+ 1. Download the dataset
20+ """
21+
22+ import os
23+ import functools
24+ from absl import app
25+ from typing import Sequence , Union , List
26+ from datasets import load_dataset
27+ import numpy as np
28+ import jax
29+ import jax .numpy as jnp
30+ from jax .sharding import Mesh
31+ from maxdiffusion import pyconfig , max_utils
32+ from maxdiffusion .pipelines .wan .wan_pipeline import WanPipeline
33+ from maxdiffusion .video_processor import VideoProcessor
34+
35+ import tensorflow as tf
36+
37+ def image_feature (value ):
38+ """Returns a bytes_list from a string / byte."""
39+ return tf .train .Feature (bytes_list = tf .train .BytesList (value = [tf .io .encode_jpeg (value ).numpy ()]))
40+
41+
42+ def bytes_feature (value ):
43+ """Returns a bytes_list from a string / byte."""
44+ return tf .train .Feature (bytes_list = tf .train .BytesList (value = [value .numpy ()]))
45+
46+
47+ def float_feature (value ):
48+ """Returns a float_list from a float / double."""
49+ return tf .train .Feature (float_list = tf .train .FloatList (value = [value ]))
50+
51+
52+ def int64_feature (value ):
53+ """Returns an int64_list from a bool / enum / int / uint."""
54+ return tf .train .Feature (int64_list = tf .train .Int64List (value = [value ]))
55+
56+
57+ def float_feature_list (value ):
58+ """Returns a list of float_list from a float / double."""
59+ return tf .train .Feature (float_list = tf .train .FloatList (value = value ))
60+
61+ def create_example (latent , hidden_states ):
62+ latent = tf .io .serialize_tensor (latent )
63+ hidden_states = tf .io .serialize_tensor (hidden_states )
64+ feature = {
65+ "latents" : bytes_feature (latent ),
66+ "encoder_hidden_states" : bytes_feature (hidden_states ),
67+ }
68+ example = tf .train .Example (features = tf .train .Features (feature = feature ))
69+ return example .SerializeToString ()
70+
71+
72+ def text_encode (pipeline , prompt : Union [str , List [str ]]):
73+ encoder_hidden_states = pipeline ._get_t5_prompt_embeds (prompt )
74+ encoder_hidden_states = encoder_hidden_states .detach ().numpy ()
75+ return encoder_hidden_states
76+
77+ def vae_encode (video , rng , vae , vae_cache ):
78+ latent = vae .encode (video , feat_cache = vae_cache )
79+ latent = latent .latent_dist .sample (rng )
80+ return latent
81+
82+ def generate_dataset (config , pipeline ):
83+
84+ tfrecords_dir = config .tfrecords_dir
85+ if not os .path .exists (tfrecords_dir ):
86+ os .makedirs (tfrecords_dir )
87+
88+ tf_rec_num = 0
89+ no_records_per_shard = config .no_records_per_shard
90+ global_record_count = 0
91+ writer = tf .io .TFRecordWriter (
92+ tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num , (global_record_count + no_records_per_shard ))
93+ )
94+ shard_record_count = 0
95+
96+ # create mesh
97+ devices_array = max_utils .create_device_mesh (config )
98+ mesh = Mesh (devices_array , config .mesh_axes )
99+ rng = jax .random .key (config .seed )
100+
101+ vae_scale_factor_spatial = 2 ** len (pipeline .vae .temperal_downsample )
102+ video_processor = VideoProcessor (vae_scale_factor = vae_scale_factor_spatial )
103+
104+ # jit vae fun.
105+ p_vae_encode = jax .jit (functools .partial (vae_encode , vae = pipeline .vae , vae_cache = pipeline .vae_cache ))
106+
107+ # Load dataset
108+ ds = load_dataset (config .dataset_name , split = 'train' )
109+ ds = ds .shuffle (seed = config .seed )
110+ ds = ds .select_columns ([config .caption_column , config .image_column ])
111+ batch_size = 10
112+ for i in range (0 , len (ds ), batch_size ):
113+ rng , new_rng = jax .random .split (rng )
114+ text = ds [i :i + batch_size ]['text' ]
115+ video = ds [i :i + batch_size ]['image' ]
116+
117+ video = [np .expand_dims (np .array (i ), axis = 0 ) for i in video ]
118+ video = video_processor .preprocess_video (video , height = config .height , width = config .width )
119+ video = jnp .array (np .array (video ), dtype = config .weights_dtype )
120+ with mesh :
121+ latents = p_vae_encode (video = video , rng = new_rng )
122+ encoder_hidden_states = text_encode (pipeline , text )
123+ example = create_example (latents , encoder_hidden_states )
124+ writer .write (example )
125+ shard_record_count += batch_size
126+ global_record_count += batch_size
127+ if shard_record_count >= no_records_per_shard :
128+ writer .close ()
129+ writer = tf .io .TFRecordWriter (
130+ tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num , (global_record_count + no_records_per_shard ))
131+ )
132+ shard_record_count = 0
133+ tf_rec_num += 1
134+
135+
136+
137+ def run (config ):
138+ pipeline = WanPipeline .from_pretrained (config , load_transformer = False )
139+ # Don't need the transformer for preprocessing.
140+ generate_dataset (config , pipeline )
141+
142+
143+
144+ def main (argv : Sequence [str ]) -> None :
145+ pyconfig .initialize (argv )
146+ run (pyconfig .config )
147+
148+ if __name__ == "__main__" :
149+ app .run (main )
0 commit comments