@@ -398,3 +398,109 @@ def load_vocoder_weights(
     validate_flax_state_dict(eval_shapes, flax_state_dict)
     return unflatten_dict(flax_state_dict)
 
+
+def rename_for_ltx2_connector(key):
+    key = key.replace("video_connector", "video_embeddings_connector")
+    key = key.replace("audio_connector", "audio_embeddings_connector")
+    key = key.replace("text_proj_in", "feature_extractor.linear")
+
+    # Transformer blocks mapping
+    if "transformer_blocks" in key:
+        key = key.replace("transformer_blocks", "stacked_blocks")
+        # Handle the feed-forward sublayers: ff.net.0.proj / ff.net.2 -> proj1 / proj2
+        key = key.replace("ff.net.0.proj", "ff.proj1")
+        key = key.replace("ff.net.2", "ff.proj2")
+
+    # Map PyTorch ".weight" suffixes to the Flax parameter names
+    if key.endswith(".weight"):
+        # Norms with use_scale=True (the attention q/k norms) store the weight as `scale`
+        if "norm_q" in key or "norm_k" in key:
+            key = key.replace(".weight", ".scale")
+        # Norms with use_scale=False (block norms) carry no weights in the checkpoint,
+        # so anything else ending in ".weight" is a linear kernel
+        else:
+            key = key.replace(".weight", ".kernel")
+
+    return key
+
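+# Illustration of the mapping above (hypothetical checkpoint key, for documentation only):
+#   rename_for_ltx2_connector("video_connector.transformer_blocks.0.ff.net.0.proj.weight")
+#   -> "video_embeddings_connector.stacked_blocks.0.ff.proj1.kernel"
+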
+def load_connector_weights(
+    pretrained_model_name_or_path: str,
+    eval_shapes: dict,
+    device: str,
+    hf_download: bool = True,
+    subfolder: str = "connectors",
+):
+    tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device)
+    flax_state_dict = {}
+    cpu = jax.local_devices(backend="cpu")[0]
+
+    # Store stacked weights: grouped_weights[connector][param_name] = {layer_idx: tensor}
+    grouped_weights = {
+        "video_embeddings_connector": {},
+        "audio_embeddings_connector": {},
+    }
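+    # Per-layer block weights are grouped here and stacked on a leading axis below;
+    # presumably the stacked_blocks module scans over that axis (an assumption based
+    # on the naming, not verified against the model definition).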
+
+    for pt_key, tensor in tensors.items():
+        key = rename_for_ltx2_connector(pt_key)
+
+        # Transpose 2-D linear kernels: PyTorch Linear weights are laid out as
+        # (out_features, in_features), flax.linen.Dense kernels as (in_features, out_features)
+        if key.endswith(".kernel"):
+            if tensor.ndim == 2:
+                tensor = tensor.transpose(1, 0)
+
+        if "stacked_blocks" in key:
+            # Key format: {connector}.stacked_blocks.{layer_idx}.{rest}
+            parts = key.split(".")
+            try:
+                sb_index = parts.index("stacked_blocks")
+                layer_idx = int(parts[sb_index + 1])
+                connector = parts[0]
+
+                # Reconstruct the param name without the layer index,
+                # e.g. video_embeddings_connector.stacked_blocks.attn1...
+                param_parts = parts[:sb_index + 1] + parts[sb_index + 2:]
+                param_name = tuple(param_parts)
+
+                if connector in grouped_weights:
+                    if param_name not in grouped_weights[connector]:
+                        grouped_weights[connector][param_name] = {}
+                    grouped_weights[connector][param_name][layer_idx] = tensor
+                    continue
+            except (ValueError, IndexError):
+                # Malformed layer index: fall through and treat as a non-stacked key
+                pass
+
+        # Non-stacked keys: convert digit parts to ints for the nested dict path
+        final_key_tuple = []
+        for p in key.split("."):
+            if p.isdigit():
+                final_key_tuple.append(int(p))
+            else:
+                final_key_tuple.append(p)
+        final_key_tuple = tuple(final_key_tuple)
+
+        flax_state_dict[final_key_tuple] = jax.device_put(tensor, device=cpu)
+
+    # Stack each group of per-layer tensors along a new leading axis
+    for connector, params in grouped_weights.items():
+        for param_name, layers in params.items():
+            # Sort by layer index and stack; assumes contiguous layers 0..N-1
+            sorted_layers = sorted(layers.keys())
+            stacked_tensor = jnp.stack([layers[i] for i in sorted_layers], axis=0)
+
+            # Convert digit parts of the param name to ints, as for non-stacked keys
+            final_param_name = []
+            for p in param_name:
+                if isinstance(p, str) and p.isdigit():
+                    final_param_name.append(int(p))
+                else:
+                    final_param_name.append(p)
+            final_param_name = tuple(final_param_name)
+
+            flax_state_dict[final_param_name] = jax.device_put(stacked_tensor, device=cpu)
+
+    # Clean up and return
+    del tensors
+    jax.clear_caches()
+    validate_flax_state_dict(eval_shapes, flax_state_dict)
+    return unflatten_dict(flax_state_dict)
+
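To make the grouping and stacking concrete, here is a small self-contained sketch. Only `rename_for_ltx2_connector` comes from this commit; the keys, shapes, and values are made up for illustration:

```python
import numpy as np
import jax.numpy as jnp

# Two synthetic per-layer linear weights in PyTorch (out_features, in_features) layout.
pt_tensors = {
    "video_connector.transformer_blocks.0.ff.net.0.proj.weight": np.ones((4, 3), np.float32),
    "video_connector.transformer_blocks.1.ff.net.0.proj.weight": np.ones((4, 3), np.float32),
}

layers = {}
for pt_key, t in pt_tensors.items():
    key = rename_for_ltx2_connector(pt_key)  # ...stacked_blocks.{i}.ff.proj1.kernel
    parts = key.split(".")
    layer_idx = int(parts[parts.index("stacked_blocks") + 1])
    layers[layer_idx] = t.transpose(1, 0)    # (in_features, out_features) for Flax

stacked = jnp.stack([layers[i] for i in sorted(layers)], axis=0)
print(stacked.shape)  # (2, 3, 4): (num_layers, in_features, out_features)
```

The stacked array is what would end up under `video_embeddings_connector.stacked_blocks.ff.proj1.kernel` in the tree returned by `load_connector_weights`.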