Skip to content

Commit 844c956

Browse files
committed
loading from real checkpoints
1 parent 04ed46d commit 844c956

2 files changed

Lines changed: 10 additions & 51 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,6 @@ def load_vae_weights(
346346
continue
347347
filtered_eval_shapes[k] = v
348348

349-
print(f"Total VAE keys loaded: {len(flax_state_dict)}")
350-
351349
validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flax_state_dict)
352350
flax_state_dict = unflatten_dict(flax_state_dict)
353351
del tensors

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,10 @@ def test_load_vae_weights(self):
127127

128128
def test_load_vocoder_weights(self):
129129
from maxdiffusion.models.ltx2.vocoder_ltx2 import LTX2Vocoder
130-
from unittest import mock
131-
import torch
132130
from maxdiffusion.models.ltx2.ltx2_utils import load_vocoder_weights
133131

132+
pretrained_model_name_or_path = "Lightricks/LTX-2"
133+
134134
config = {
135135
"hidden_channels": 1024,
136136
"in_channels": 128,
@@ -147,55 +147,16 @@ def test_load_vocoder_weights(self):
147147
state = nnx.state(model)
148148
eval_shapes = state.to_pure_dict()
149149

150-
# Create dummy PyTorch weights
151-
pt_weights = {}
152-
flat_shapes = flatten_dict(eval_shapes)
153-
154-
for key, value in flat_shapes.items():
155-
# key is tuple of strings/ints, e.g. ('conv_in', 'kernel')
156-
# Map to PyTorch key
157-
pt_key_parts = []
158-
is_upsampler = "upsamplers" in [str(k) for k in key]
159-
is_kernel = False
160-
161-
for part in key:
162-
if str(part) == "upsamplers":
163-
pt_key_parts.append("ups")
164-
elif str(part) == "resnets":
165-
pt_key_parts.append("resblocks")
166-
elif str(part) == "conv_out":
167-
pt_key_parts.append("conv_post")
168-
elif str(part) == "kernel":
169-
pt_key_parts.append("weight")
170-
is_kernel = True
171-
else:
172-
pt_key_parts.append(str(part))
173-
174-
pt_key = ".".join(pt_key_parts)
175-
176-
# Create tensor with PyTorch shape
177-
shape = value.shape
178-
if is_kernel:
179-
if is_upsampler:
180-
# Flax (K, I, O) -> PyTorch (I, O, K)
181-
pt_shape = (shape[1], shape[2], shape[0])
182-
else:
183-
# Flax (K, I, O) -> PyTorch (O, I, K)
184-
pt_shape = (shape[2], shape[1], shape[0])
185-
else:
186-
pt_shape = shape
187-
188-
pt_weights[pt_key] = jnp.array(torch.randn(pt_shape).numpy())
189-
190-
with mock.patch("maxdiffusion.models.ltx2.ltx2_utils.load_sharded_checkpoint", return_value=pt_weights):
191-
loaded_weights = load_vocoder_weights(
192-
pretrained_model_name_or_path="dummy",
193-
eval_shapes=eval_shapes,
194-
device=self.device,
195-
hf_download=False
196-
)
150+
print("Loading Vocoder Weights...")
151+
loaded_weights = load_vocoder_weights(
152+
pretrained_model_name_or_path=pretrained_model_name_or_path,
153+
eval_shapes=eval_shapes,
154+
device=self.device,
155+
hf_download=True
156+
)
197157

198158
# Validate
159+
print("Validating Vocoder Weights...")
199160
validate_flax_state_dict(eval_shapes, flatten_dict(loaded_weights))
200161
print("Vocoder Weights Validated Successfully!")
201162

0 commit comments

Comments
 (0)