Skip to content

Commit 4512951

Browse files
committed
fix
1 parent 9993393 commit 4512951

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/scripts/convert_ltx2_vae_weights.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from safetensors.torch import load_file
1212
from huggingface_hub import hf_hub_download
1313
from flax import nnx
14-
14+
from flax import traverse_util
1515
def convert_ltx2_vae(hf_repo, output_path):
1616
# Load weights directly from Safetensors
1717
print(f"Downloading/Loading weights from {hf_repo}...")
@@ -67,7 +67,7 @@ def convert_ltx2_vae(hf_repo, output_path):
6767
# but nnx.Module usually has them after init if shape is provided?
6868
# Wait, nnx modules need to be split to see params.
6969
graphdef, state = nnx.split(model); params = state.filter(nnx.Param)
70-
flat_params = nnx.traverse_util.flatten_dict(params)
70+
flat_params = traverse_util.flatten_dict(params)
7171
sorted_flat_keys = sorted(flat_params.keys())
7272
for k in sorted_flat_keys:
7373
v = flat_params[k]

0 commit comments

Comments
 (0)