@@ -113,7 +113,6 @@ def test_load_vae_weights(self):
113113 )
114114
115115 print ("Validating VAE Weights..." )
116- # Filter out dropout/rngs keys from eval_shapes as they are not expected in weights
117116 filtered_eval_shapes = {}
118117 flat_eval_shapes = flatten_dict (eval_shapes )
119118 for k , v in flat_eval_shapes .items ():
@@ -126,5 +125,79 @@ def test_load_vae_weights(self):
126125 validate_flax_state_dict (unflatten_dict (filtered_eval_shapes ), flatten_dict (loaded_weights ))
127126 print ("VAE Weights Validated Successfully!" )
128127
def test_load_vocoder_weights(self):
    """Check that `load_vocoder_weights` output validates against the vocoder state.

    An `LTX2Vocoder` is built on CPU and its pure-dict state is used as the
    reference tree. A stand-in PyTorch checkpoint is then synthesised from
    that tree (renamed keys, permuted kernel shapes, random values) and fed to
    the loader by patching `load_sharded_checkpoint`, so nothing is read from
    disk or downloaded. The loaded result is finally passed to
    `validate_flax_state_dict` together with the reference tree.
    """
    from unittest import mock

    import torch

    from maxdiffusion.models.ltx2.ltx2_utils import load_vocoder_weights
    from maxdiffusion.models.ltx2.vocoder_ltx2 import LTX2Vocoder

    # Flax path component -> name used in the synthetic PyTorch state dict.
    flax_to_pt_names = {
        "upsamplers": "ups",
        "resnets": "resblocks",
        "conv_out": "conv_post",
        "kernel": "weight",
    }

    vocoder_kwargs = dict(
        hidden_channels=1024,
        in_channels=128,
        leaky_relu_negative_slope=0.1,
        out_channels=2,
        resnet_dilations=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resnet_kernel_sizes=[3, 7, 11],
        upsample_factors=[6, 5, 2, 2, 2],
        upsample_kernel_sizes=[16, 15, 8, 4, 4],
        rngs=nnx.Rngs(0),
    )

    # Instantiate on the first CPU device and capture the reference tree.
    with jax.default_device(jax.devices("cpu")[0]):
        vocoder = LTX2Vocoder(**vocoder_kwargs)
        eval_shapes = nnx.state(vocoder).to_pure_dict()

    # Build the dummy PyTorch weights, one random tensor per flattened leaf.
    pt_weights = {}
    for path, leaf in flatten_dict(eval_shapes).items():
        # Path entries may be strings or ints, e.g. ('conv_in', 'kernel').
        names = [str(component) for component in path]
        pt_key = ".".join(flax_to_pt_names.get(name, name) for name in names)

        flax_shape = leaf.shape
        if "kernel" in names:
            # Upsampler kernels: Flax (K, I, O) -> PyTorch (I, O, K).
            # All other kernels:  Flax (K, I, O) -> PyTorch (O, I, K).
            axis_order = (1, 2, 0) if "upsamplers" in names else (2, 1, 0)
            torch_shape = tuple(flax_shape[axis] for axis in axis_order)
        else:
            # Non-kernel leaves keep their shape unchanged.
            torch_shape = flax_shape

        pt_weights[pt_key] = torch.randn(torch_shape)

    # Make the loader receive the synthetic checkpoint instead of reading one.
    checkpoint_patch = mock.patch(
        "maxdiffusion.models.ltx2.ltx2_utils.load_sharded_checkpoint",
        return_value=pt_weights,
    )
    with checkpoint_patch:
        loaded_weights = load_vocoder_weights(
            pretrained_model_name_or_path="dummy",
            eval_shapes=eval_shapes,
            device=self.device,
            hf_download=False,
        )

    # Validate the loaded tree against the reference tree.
    validate_flax_state_dict(eval_shapes, flatten_dict(loaded_weights))
    print("Vocoder Weights Validated Successfully!")

# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments