@@ -404,24 +404,17 @@ def rename_for_ltx2_connector(key):
     key = key.replace("audio_connector", "audio_embeddings_connector")
     key = key.replace("text_proj_in", "feature_extractor.linear")
 
-    # Transformer blocks mapping
     if "transformer_blocks" in key:
         key = key.replace("transformer_blocks", "stacked_blocks")
-        # Handle FF
         key = key.replace("ff.net.0.proj", "ff.proj1")
         key = key.replace("ff.net.2", "ff.proj2")
-        # Handle to_out
         key = key.replace("to_out.0", "to_out")
 
-    # Validation/Weight suffix
     if key.endswith(".weight"):
-        # Check if it's a norm with usage_scale=True (attn norms)
         if "norm_q" in key or "norm_k" in key:
             key = key.replace(".weight", ".scale")
-        # Check if it's a norm with usage_scale=False (block norms) -> No, these don't exist in checkpoint!
         else:
             key = key.replace(".weight", ".kernel")
-
     return key
 
 def load_connector_weights(
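For reference, a minimal sketch (not part of the change) of how one key flows through `rename_for_ltx2_connector`; the input key is hypothetical and assumes the PyTorch checkpoint uses the `audio_connector.transformer_blocks.*` layout that the replacements above target:

```python
# Hypothetical PyTorch key, assuming the layout handled by the replacements above:
pt_key = "audio_connector.transformer_blocks.2.attn1.norm_q.weight"
key = rename_for_ltx2_connector(pt_key)
# "audio_connector"           -> "audio_embeddings_connector"
# "transformer_blocks"        -> "stacked_blocks"
# ".weight" on norm_q/norm_k  -> ".scale"
assert key == "audio_embeddings_connector.stacked_blocks.2.attn1.norm_q.scale"
```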
@@ -434,8 +427,7 @@ def load_connector_weights(
     tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device)
     flax_state_dict = {}
     cpu = jax.local_devices(backend="cpu")[0]
-
-    # Store stacked weights: grouped_weights[connector][param_name] = {layer_idx: tensor}
+
     grouped_weights = {
         "video_embeddings_connector": {},
         "audio_embeddings_connector": {}
@@ -444,22 +436,17 @@ def load_connector_weights(
     for pt_key, tensor in tensors.items():
         key = rename_for_ltx2_connector(pt_key)
 
-        # Check for transpose (Linear layers)
         if key.endswith(".kernel"):
-            if tensor.ndim == 2:
-                tensor = tensor.transpose(1, 0)
-
+            if tensor.ndim == 2:
+                tensor = tensor.transpose(1, 0)
+
         if "stacked_blocks" in key:
-            # key format: {connector}.stacked_blocks.{layer_idx}.{rest}
             parts = key.split(".")
-            # Find stacked_blocks index
             try:
                 sb_index = parts.index("stacked_blocks")
                 layer_idx = int(parts[sb_index + 1])
                 connector = parts[0]
 
-                # Reconstruct param name without layer index
-                # e.g. video_embeddings_connector.stacked_blocks.attn1...
                 param_parts = parts[:sb_index + 1] + parts[sb_index + 2:]
                 param_name = tuple(param_parts)
 
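The 2-D transpose above follows the usual framework convention (a general note, not something the diff states): PyTorch `nn.Linear` stores its weight as `(out_features, in_features)`, while a Flax `Dense` kernel is `(in_features, out_features)`. A minimal sketch:

```python
import numpy as np

# PyTorch nn.Linear weight layout: (out_features, in_features)
pt_weight = np.zeros((64, 128))
# Flax Dense kernel layout: (in_features, out_features), hence transpose(1, 0)
flax_kernel = pt_weight.transpose(1, 0)
assert flax_kernel.shape == (128, 64)
```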
@@ -471,10 +458,8 @@ def load_connector_weights(
             except (ValueError, IndexError):
                 pass
 
-        # Non-stacked keys
         key_tuple = tuple(key.split("."))
 
-        # Handle int conversion for parts
         final_key_tuple = []
         for p in key_tuple:
             if p.isdigit(): final_key_tuple.append(int(p))
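A quick illustration of the digit-to-int conversion for non-stacked keys (the key below is made up for the example): numeric path components become integer indices in the state-dict key tuple.

```python
# Hypothetical renamed key; the "0" component becomes the integer 0
key = "feature_extractor.linear.0.kernel"
final_key_tuple = tuple(int(p) if p.isdigit() else p for p in key.split("."))
assert final_key_tuple == ("feature_extractor", "linear", 0, "kernel")
```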
@@ -483,15 +468,11 @@ def load_connector_weights(
 
         flax_state_dict[final_key_tuple] = jax.device_put(tensor, device=cpu)
 
-    # Process grouped weights
     for connector, params in grouped_weights.items():
         for param_name, layers in params.items():
-            # Sort by layer index and stack
             sorted_layers = sorted(layers.keys())
-            # Assuming contiguous layers 0..N-1
             stacked_tensor = jnp.stack([layers[i] for i in sorted_layers], axis=0)
 
-            # Param name tuple
             final_param_name = []
             for p in param_name:
                 if isinstance(p, str) and p.isdigit(): final_param_name.append(int(p))
@@ -500,7 +481,6 @@ def load_connector_weights(
 
             flax_state_dict[final_param_name] = jax.device_put(stacked_tensor, device=cpu)
 
-    # Clean up and return
     del tensors
     jax.clear_caches()
     validate_flax_state_dict(eval_shapes, flax_state_dict)
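Finally, a small sketch of the stacking step for the grouped per-layer weights, with illustrative shapes and assuming contiguous layer indices 0..N-1 (which the loop above also assumes):

```python
import jax.numpy as jnp

# Per-layer tensors keyed by layer index, as collected in grouped_weights
layers = {0: jnp.ones((128, 64)), 1: jnp.zeros((128, 64))}
# Stack along a new leading axis so the stacked_blocks param gets shape
# (num_layers, ...)
stacked_tensor = jnp.stack([layers[i] for i in sorted(layers.keys())], axis=0)
assert stacked_tensor.shape == (2, 128, 64)
```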