Skip to content

Commit bf6bea9

Browse files
committed
fix
1 parent 3f4d3d4 commit bf6bea9

1 file changed

Lines changed: 5 additions & 8 deletions

File tree

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -177,21 +177,18 @@ def load_dataset(self, mesh, is_training=True):
177177
"Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True"
178178
)
179179

180-
feature_description_train = {
180+
feature_description = {
181181
"latents": tf.io.FixedLenFeature([], tf.string),
182182
"encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
183183
}
184184

185+
if not is_training:
186+
feature_description["timesteps"] = tf.io.FixedLenFeature([], tf.int64)
187+
185188
def prepare_sample_train(features):
186189
latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
187190
encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
188191
return {"latents": latents, "encoder_hidden_states": encoder_hidden_states}
189-
190-
feature_description_eval = {
191-
"latents": tf.io.FixedLenFeature([], tf.string),
192-
"encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
193-
"timesteps": tf.io.FixedLenFeature([], tf.int64),
194-
}
195192

196193
def prepare_sample_eval(features):
197194
latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
@@ -206,7 +203,7 @@ def prepare_sample_eval(features):
206203
jax.process_count(),
207204
mesh,
208205
config.global_batch_size_to_load,
209-
feature_description=feature_description_train if is_training else feature_description_eval,
206+
feature_description=feature_description,
210207
prepare_sample_fn=prepare_sample_train if is_training else prepare_sample_eval,
211208
is_training=is_training,
212209
)

0 commit comments

Comments
 (0)