@@ -177,21 +177,18 @@ def load_dataset(self, mesh, is_training=True):
177177 "Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True"
178178 )
179179
180- feature_description_train = {
180+ feature_description = {
181181 "latents" : tf .io .FixedLenFeature ([], tf .string ),
182182 "encoder_hidden_states" : tf .io .FixedLenFeature ([], tf .string ),
183183 }
184184
185+ if not is_training :
186+ feature_description ["timesteps" ] = tf .io .FixedLenFeature ([], tf .int64 )
187+
185188 def prepare_sample_train (features ):
186189 latents = tf .io .parse_tensor (features ["latents" ], out_type = tf .float32 )
187190 encoder_hidden_states = tf .io .parse_tensor (features ["encoder_hidden_states" ], out_type = tf .float32 )
188191 return {"latents" : latents , "encoder_hidden_states" : encoder_hidden_states }
189-
190- feature_description_eval = {
191- "latents" : tf .io .FixedLenFeature ([], tf .string ),
192- "encoder_hidden_states" : tf .io .FixedLenFeature ([], tf .string ),
193- "timesteps" : tf .io .FixedLenFeature ([], tf .int64 ),
194- }
195192
196193 def prepare_sample_eval (features ):
197194 latents = tf .io .parse_tensor (features ["latents" ], out_type = tf .float32 )
@@ -206,7 +203,7 @@ def prepare_sample_eval(features):
206203 jax .process_count (),
207204 mesh ,
208205 config .global_batch_size_to_load ,
209- feature_description = feature_description_train if is_training else feature_description_eval ,
206+ feature_description = feature_description ,
210207 prepare_sample_fn = prepare_sample_train if is_training else prepare_sample_eval ,
211208 is_training = is_training ,
212209 )
0 commit comments