@@ -406,3 +406,95 @@ def load_ltx2_vae(
406406 validate_flax_state_dict (eval_shapes , flax_state_dict )
407407 flax_state_dict = unflatten_dict (flax_state_dict )
408408 return flax_state_dict
409+
410+
def load_ltx2_vocoder(
    pretrained_model_name_or_path: str,
    eval_shapes: dict,
    device: str,
    hf_download: bool = True,
    subfolder: str = "vocoder",
):
  """Load LTX2 vocoder weights from a PyTorch safetensors checkpoint into a Flax state dict.

  Args:
    pretrained_model_name_or_path: a local model directory, a direct path to a
      ``.safetensors`` file, or a Hugging Face Hub repo id.
    eval_shapes: nested dict of expected Flax parameter shapes, used by
      ``rename_key_and_reshape_tensor`` / ``validate_flax_state_dict``.
    device: JAX backend name (e.g. ``"cpu"``) used to resolve a device handle.
    hf_download: when True, fall back to downloading from the Hub if the
      checkpoint is not found locally.
    subfolder: subfolder (in the repo or local directory) holding the
      ``diffusion_pytorch_model.safetensors`` file.

  Returns:
    A nested (unflattened) Flax state dict with all tensors placed on CPU.

  Raises:
    FileNotFoundError: if no checkpoint can be located and downloading is
      disabled.
  """
  device = jax.local_devices(backend=device)[0]
  # Vocoder weights live in diffusion_pytorch_model.safetensors under `subfolder`.
  filename = "diffusion_pytorch_model.safetensors"

  # Resolve the checkpoint path: local directory > direct file > Hub download.
  ckpt_path = None
  if os.path.isdir(pretrained_model_name_or_path):
    candidate = os.path.join(pretrained_model_name_or_path, subfolder, filename)
    if os.path.isfile(candidate):
      ckpt_path = candidate
  elif os.path.isfile(pretrained_model_name_or_path):
    # A direct path to a safetensors file; prefer it over a Hub download.
    ckpt_path = pretrained_model_name_or_path

  if ckpt_path is None:
    if hf_download:
      ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
    else:
      # Previously ckpt_path could be left unbound here, raising NameError
      # below; fail loudly with an actionable message instead.
      raise FileNotFoundError(
          f"Could not locate vocoder checkpoint '{filename}' for "
          f"'{pretrained_model_name_or_path}' (subfolder='{subfolder}') and hf_download is disabled."
      )

  max_logging.log(f"Load and port {pretrained_model_name_or_path} Vocoder from {ckpt_path}")

  tensors = {}
  with safe_open(ckpt_path, framework="pt") as f:
    for k in f.keys():
      tensors[k] = torch2jax(f.get_tensor(k))

  flax_state_dict = {}
  cpu = jax.local_devices(backend="cpu")[0]

  # Flatten eval_shapes and normalize keys to string tuples so they can be
  # matched against the renamed PyTorch keys.
  random_flax_state_dict = {
      tuple(str(item) for item in key): value for key, value in flatten_dict(eval_shapes).items()
  }

  for pt_key, tensor in tensors.items():
    # Map Diffusers vocoder parameter names onto the Flax module tree:
    #   ups.X.*       -> upsamplers.layers.X.*
    #   resblocks.X.* -> resnets.layers.X.*
    #   convsN.Y.*    -> convsN.layers.Y.*
    # The `.weight` -> `.kernel` rename (and any transpose/reshape) is handled
    # by rename_key_and_reshape_tensor; conv_in/conv_out map through unchanged.
    renamed_pt_key = pt_key
    renamed_pt_key = renamed_pt_key.replace("ups.", "upsamplers.layers.")
    renamed_pt_key = renamed_pt_key.replace("resblocks.", "resnets.layers.")
    renamed_pt_key = renamed_pt_key.replace("convs1.", "convs1.layers.")
    renamed_pt_key = renamed_pt_key.replace("convs2.", "convs2.layers.")

    pt_tuple_key = tuple(renamed_pt_key.split("."))
    flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers=False)
    flax_key = _tuple_str_to_int(flax_key)
    flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)

  validate_flax_state_dict(eval_shapes, flax_state_dict)
  flax_state_dict = unflatten_dict(flax_state_dict)
  return flax_state_dict
500+
0 commit comments