Skip to content

Commit 4e985c3

Browse files
committed
weights loading for vocoder
1 parent a257f4e commit 4e985c3

2 files changed

Lines changed: 124 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,53 @@ def load_vae_weights(
354354
jax.clear_caches()
355355
return flax_state_dict
356356

357+
def rename_for_ltx2_vocoder(key):
    """Translate a dot-separated vocoder checkpoint key to the Flax module naming.

    Only whole path components are renamed:

    * ``ups``       -> ``upsamplers``
    * ``resblocks`` -> ``resnets``
    * ``conv_post`` -> ``conv_out``

    Matching whole components (rather than raw substrings) matters because
    ``"ups"`` is itself a substring of ``"upsamplers"``: a substring replace
    turns an already-converted key into ``"upsamplersamplers..."`` and would
    also mangle unrelated components that merely contain ``"ups"``. With
    component matching the function is idempotent.

    Args:
        key: Checkpoint parameter name, e.g. ``"ups.0.weight"``.

    Returns:
        The key with the components listed above renamed; every other
        component is returned untouched.
    """
    component_renames = {
        "ups": "upsamplers",
        "resblocks": "resnets",
        "conv_post": "conv_out",
    }
    return ".".join(component_renames.get(part, part) for part in key.split("."))
362+
363+
364+
def load_vocoder_weights(
    pretrained_model_name_or_path: str,
    eval_shapes: dict,
    device: str,
    hf_download: bool = True,
    subfolder: str = "vocoder",
):
    """Load vocoder checkpoint tensors and convert them to a nested Flax-style dict.

    Every checkpoint entry is renamed with ``rename_for_ltx2_vocoder``, its
    dotted name is split into a tuple path (numeric components become ``int``),
    a trailing ``"weight"`` is renamed to ``"kernel"``, and kernels have their
    axes permuted before being placed on the first local CPU device.

    Args:
        pretrained_model_name_or_path: Forwarded to ``load_sharded_checkpoint``.
        eval_shapes: Reference structure handed to ``validate_flax_state_dict``
            together with the converted flat dict.
        device: Forwarded to ``load_sharded_checkpoint``. The converted arrays
            themselves are always put on the CPU device, not on ``device``.
        hf_download: Accepted but not referenced anywhere in this function.
        subfolder: Checkpoint subfolder forwarded to ``load_sharded_checkpoint``.

    Returns:
        The converted parameters, un-flattened with ``unflatten_dict``.
    """
    checkpoint = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device)
    host_device = jax.local_devices(backend="cpu")[0]

    converted = {}
    for torch_name, weight in checkpoint.items():
        # Dotted name -> path; purely numeric components become integer indices.
        path = [
            int(piece) if piece.isdigit() else piece
            for piece in rename_for_ltx2_vocoder(torch_name).split(".")
        ]
        if path[-1] == "weight":
            path[-1] = "kernel"
        path = tuple(path)

        if path[-1] == "kernel":
            # Layout notes carried over from the original code:
            #   upsamplers (ConvTranspose): (In, Out, K) -> (K, In, Out)
            #   everything else (Conv):     (Out, In, K) -> (K, In, Out)
            axes = (2, 0, 1) if "upsamplers" in path else (2, 1, 0)
            weight = weight.transpose(*axes)

        converted[path] = jax.device_put(weight, device=host_device)

    validate_flax_state_dict(eval_shapes, converted)
    return unflatten_dict(converted)
406+

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ def test_load_vae_weights(self):
113113
)
114114

115115
print("Validating VAE Weights...")
116-
# Filter out dropout/rngs keys from eval_shapes as they are not expected in weights
117116
filtered_eval_shapes = {}
118117
flat_eval_shapes = flatten_dict(eval_shapes)
119118
for k, v in flat_eval_shapes.items():
@@ -126,5 +125,79 @@ def test_load_vae_weights(self):
126125
validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flatten_dict(loaded_weights))
127126
print("VAE Weights Validated Successfully!")
128127

128+
def test_load_vocoder_weights(self):
    """Round-trip check: synthetic PyTorch-layout weights load into the vocoder's structure.

    Builds an ``LTX2Vocoder`` on CPU, derives a PyTorch-style state dict from
    its parameter shapes (inverse of the loader's renaming and axis
    permutation), feeds it through ``load_vocoder_weights`` with
    ``load_sharded_checkpoint`` patched out, and validates the result against
    the model's own shapes.
    """
    from maxdiffusion.models.ltx2.vocoder_ltx2 import LTX2Vocoder
    from unittest import mock
    import torch
    from maxdiffusion.models.ltx2.ltx2_utils import load_vocoder_weights

    config = {
        "hidden_channels": 1024,
        "in_channels": 128,
        "leaky_relu_negative_slope": 0.1,
        "out_channels": 2,
        "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "resnet_kernel_sizes": [3, 7, 11],
        "upsample_factors": [6, 5, 2, 2, 2],
        "upsample_kernel_sizes": [16, 15, 8, 4, 4],
        "rngs": nnx.Rngs(0),
    }
    with jax.default_device(jax.devices("cpu")[0]):
        model = LTX2Vocoder(**config)
        state = nnx.state(model)
        eval_shapes = state.to_pure_dict()

    # Flax path component -> name used in the PyTorch-style checkpoint.
    torch_names = {
        "upsamplers": "ups",
        "resnets": "resblocks",
        "conv_out": "conv_post",
        "kernel": "weight",
    }

    # Create dummy PyTorch weights, one per Flax leaf.
    pt_weights = {}
    for path, leaf in flatten_dict(eval_shapes).items():
        # path is a tuple of strings/ints, e.g. ('conv_in', 'kernel')
        labels = [str(piece) for piece in path]
        pt_key = ".".join(torch_names.get(label, label) for label in labels)

        flax_shape = leaf.shape
        if "kernel" not in labels:
            pt_shape = flax_shape
        elif "upsamplers" in labels:
            # Flax (K, I, O) -> PyTorch (I, O, K)
            pt_shape = (flax_shape[1], flax_shape[2], flax_shape[0])
        else:
            # Flax (K, I, O) -> PyTorch (O, I, K)
            pt_shape = (flax_shape[2], flax_shape[1], flax_shape[0])

        pt_weights[pt_key] = torch.randn(pt_shape)

    patch_target = "maxdiffusion.models.ltx2.ltx2_utils.load_sharded_checkpoint"
    with mock.patch(patch_target, return_value=pt_weights):
        loaded_weights = load_vocoder_weights(
            pretrained_model_name_or_path="dummy",
            eval_shapes=eval_shapes,
            device=self.device,
            hf_download=False,
        )

    # Validate
    validate_flax_state_dict(eval_shapes, flatten_dict(loaded_weights))
    print("Vocoder Weights Validated Successfully!")
201+
129202
if __name__ == "__main__":
130203
unittest.main()

0 commit comments

Comments
 (0)