@@ -40,6 +40,8 @@ def load_connectors_weights(
4040 cpu = jax .local_devices (backend = "cpu" )[0 ]
4141 flattened_eval = flatten_dict (eval_shapes )
4242
43+ accumulated_stacked = {}
44+
4345 for pt_key , tensor in tensors .items ():
4446 if not any (x in pt_key for x in ["connectors." , "video_embeddings_connector" , "audio_embeddings_connector" ]):
4547 continue
@@ -48,8 +50,35 @@ def load_connectors_weights(
4850 for replace_key , rename_to in LTX_2_3_CONNECTORS_KEYS_RENAME_DICT .items ():
4951 flax_key_str = flax_key_str .replace (replace_key , rename_to )
5052
51- flax_key = _tuple_str_to_int (flax_key_str .split ("." ))
52- flax_state_dict [flax_key ] = jax .device_put (tensor , device = cpu )
53+ segments = flax_key_str .split ("." )
54+
55+ # Find if there is a layer index (digit)
56+ layer_idx = None
57+ base_segments = []
58+ for seg in segments :
59+ if seg .isdigit ():
60+ layer_idx = int (seg )
61+ else :
62+ base_segments .append (seg )
63+
64+ if layer_idx is not None :
65+ base_key = _tuple_str_to_int (base_segments )
66+ if base_key not in accumulated_stacked :
67+ accumulated_stacked [base_key ] = {}
68+ accumulated_stacked [base_key ][layer_idx ] = tensor
69+ else :
70+ flax_key = _tuple_str_to_int (segments )
71+ flax_state_dict [flax_key ] = jax .device_put (tensor , device = cpu )
72+
73+ # Now stack the accumulated ones
74+ for base_key , layers in accumulated_stacked .items ():
75+ num_layers = max (layers .keys ()) + 1
76+ if len (layers ) != num_layers :
77+ raise ValueError (f"Missing layers for { base_key } , got { layers .keys ()} " )
78+
79+ sorted_tensors = [layers [i ] for i in range (num_layers )]
80+ stacked_tensor = jnp .stack (sorted_tensors , axis = 0 )
81+ flax_state_dict [base_key ] = jax .device_put (stacked_tensor , device = cpu )
5382
5483 filtered_eval_shapes = {
5584 k : v for k , v in flattened_eval .items () if not any ("dropout" in str (x ) or "rngs" in str (x ) for x in k )
0 commit comments