Skip to content

Commit 6422e93

Browse files
committed
fix
1 parent 478030e commit 6422e93

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def load_vae_weights(
238238
pt_list.append("resnets")
239239
elif name in ["down_blocks", "up_blocks", "downsamplers", "upsamplers"]:
240240
pt_list.append(name)
241-
pt_list.append(idx)
241+
pt_list.append(str(idx))
242242
else:
243243
pt_list.append(part)
244244
else:

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def test_load_transformer_weights(self):
8383
# But wait, validate_flax_state_dict expects a dict of params, usually just the params subtree.
8484

8585
# We can extract params from state
86-
state = nnx.state(model)
86+
state = nnx.state(self.transformer)
8787
# Filter for params?
8888
# Usually validate_flax_state_dict expects the full PyTree or a specific dict.
8989
# state is a State object, acts like a Mapping

0 commit comments

Comments
 (0)