Skip to content

Commit bce5842

Browse files
committed
ltx2.3 connectors loading
1 parent cc33249 commit bce5842

1 file changed

Lines changed: 31 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def load_connectors_weights(
4040
cpu = jax.local_devices(backend="cpu")[0]
4141
flattened_eval = flatten_dict(eval_shapes)
4242

43+
accumulated_stacked = {}
44+
4345
for pt_key, tensor in tensors.items():
4446
if not any(x in pt_key for x in ["connectors.", "video_embeddings_connector", "audio_embeddings_connector"]):
4547
continue
@@ -48,8 +50,35 @@ def load_connectors_weights(
4850
for replace_key, rename_to in LTX_2_3_CONNECTORS_KEYS_RENAME_DICT.items():
4951
flax_key_str = flax_key_str.replace(replace_key, rename_to)
5052

51-
flax_key = _tuple_str_to_int(flax_key_str.split("."))
52-
flax_state_dict[flax_key] = jax.device_put(tensor, device=cpu)
53+
segments = flax_key_str.split(".")
54+
55+
# Look for a numeric segment in the key: it is the layer index and is left out of the base key
56+
layer_idx = None
57+
base_segments = []
58+
for seg in segments:
59+
if seg.isdigit():
60+
layer_idx = int(seg)
61+
else:
62+
base_segments.append(seg)
63+
64+
if layer_idx is not None:
65+
base_key = _tuple_str_to_int(base_segments)
66+
if base_key not in accumulated_stacked:
67+
accumulated_stacked[base_key] = {}
68+
accumulated_stacked[base_key][layer_idx] = tensor
69+
else:
70+
flax_key = _tuple_str_to_int(segments)
71+
flax_state_dict[flax_key] = jax.device_put(tensor, device=cpu)
72+
73+
# Stack each base key's per-layer tensors in layer-index order along a new leading axis
74+
for base_key, layers in accumulated_stacked.items():
75+
num_layers = max(layers.keys()) + 1
76+
if len(layers) != num_layers:
77+
raise ValueError(f"Missing layers for {base_key}, got {layers.keys()}")
78+
79+
sorted_tensors = [layers[i] for i in range(num_layers)]
80+
stacked_tensor = jnp.stack(sorted_tensors, axis=0)
81+
flax_state_dict[base_key] = jax.device_put(stacked_tensor, device=cpu)
5382

5483
filtered_eval_shapes = {
5584
k: v for k, v in flattened_eval.items() if not any("dropout" in str(x) or "rngs" in str(x) for x in k)

0 commit comments

Comments
 (0)