@@ -127,10 +127,10 @@ def test_load_vae_weights(self):
127127
128128 def test_load_vocoder_weights (self ):
129129 from maxdiffusion .models .ltx2 .vocoder_ltx2 import LTX2Vocoder
130- from unittest import mock
131- import torch
132130 from maxdiffusion .models .ltx2 .ltx2_utils import load_vocoder_weights
133131
132+ pretrained_model_name_or_path = "Lightricks/LTX-2"
133+
134134 config = {
135135 "hidden_channels" : 1024 ,
136136 "in_channels" : 128 ,
@@ -147,55 +147,16 @@ def test_load_vocoder_weights(self):
147147 state = nnx .state (model )
148148 eval_shapes = state .to_pure_dict ()
149149
150- # Create dummy PyTorch weights
151- pt_weights = {}
152- flat_shapes = flatten_dict (eval_shapes )
153-
154- for key , value in flat_shapes .items ():
155- # key is tuple of strings/ints, e.g. ('conv_in', 'kernel')
156- # Map to PyTorch key
157- pt_key_parts = []
158- is_upsampler = "upsamplers" in [str (k ) for k in key ]
159- is_kernel = False
160-
161- for part in key :
162- if str (part ) == "upsamplers" :
163- pt_key_parts .append ("ups" )
164- elif str (part ) == "resnets" :
165- pt_key_parts .append ("resblocks" )
166- elif str (part ) == "conv_out" :
167- pt_key_parts .append ("conv_post" )
168- elif str (part ) == "kernel" :
169- pt_key_parts .append ("weight" )
170- is_kernel = True
171- else :
172- pt_key_parts .append (str (part ))
173-
174- pt_key = "." .join (pt_key_parts )
175-
176- # Create tensor with PyTorch shape
177- shape = value .shape
178- if is_kernel :
179- if is_upsampler :
180- # Flax (K, I, O) -> PyTorch (I, O, K)
181- pt_shape = (shape [1 ], shape [2 ], shape [0 ])
182- else :
183- # Flax (K, I, O) -> PyTorch (O, I, K)
184- pt_shape = (shape [2 ], shape [1 ], shape [0 ])
185- else :
186- pt_shape = shape
187-
188- pt_weights [pt_key ] = jnp .array (torch .randn (pt_shape ).numpy ())
189-
190- with mock .patch ("maxdiffusion.models.ltx2.ltx2_utils.load_sharded_checkpoint" , return_value = pt_weights ):
191- loaded_weights = load_vocoder_weights (
192- pretrained_model_name_or_path = "dummy" ,
193- eval_shapes = eval_shapes ,
194- device = self .device ,
195- hf_download = False
196- )
150+ print ("Loading Vocoder Weights..." )
151+ loaded_weights = load_vocoder_weights (
152+ pretrained_model_name_or_path = pretrained_model_name_or_path ,
153+ eval_shapes = eval_shapes ,
154+ device = self .device ,
155+ hf_download = True
156+ )
197157
198158 # Validate
159+ print ("Validating Vocoder Weights..." )
199160 validate_flax_state_dict (eval_shapes , flatten_dict (loaded_weights ))
200161 print ("Vocoder Weights Validated Successfully!" )
201162
0 commit comments