We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e6beb89 commit da0a909Copy full SHA for da0a909
1 file changed
src/maxdiffusion/models/ltx2/ltx2_3_utils.py
@@ -369,13 +369,8 @@ def load_vocoder_weights_2_3(
369
flax_state_dict = {}
370
cpu = jax.local_devices(backend="cpu")[0]
371
372
- from flax.traverse_util import flatten_dict
373
- flat_eval = flatten_dict(eval_shapes)
374
- print("Expected vocoder keys:", [k for k in flat_eval.keys() if "mel_stft" in str(k)])
375
-
376
for pt_key, tensor in tensors.items():
377
# Keys are already filtered and stripped of "vocoder." by load_and_segregate
378
- print("Processing pt_key:", pt_key)
379
key = rename_for_ltx2_3_vocoder(pt_key)
380
381
# Always apply LTX-2.3 specific replacement
0 commit comments