
Commit d73e6b7

connectors
1 parent e9afcc4 commit d73e6b7

4 files changed

Lines changed: 135 additions & 5 deletions


src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 106 additions & 0 deletions
@@ -398,3 +398,109 @@ def load_vocoder_weights(
   validate_flax_state_dict(eval_shapes, flax_state_dict)
   return unflatten_dict(flax_state_dict)
+
+
+def rename_for_ltx2_connector(key):
+  key = key.replace("video_connector", "video_embeddings_connector")
+  key = key.replace("audio_connector", "audio_embeddings_connector")
+  key = key.replace("text_proj_in", "feature_extractor.linear")
+
+  # Transformer blocks mapping
+  if "transformer_blocks" in key:
+    key = key.replace("transformer_blocks", "stacked_blocks")
+    # Handle feed-forward projections
+    key = key.replace("ff.net.0.proj", "ff.proj1")
+    key = key.replace("ff.net.2", "ff.proj2")
+
+  # Map the .weight suffix onto the Flax param name
+  if key.endswith(".weight"):
+    # Norms with use_scale=True (attention q/k norms) keep a scale param
+    if "norm_q" in key or "norm_k" in key:
+      key = key.replace(".weight", ".scale")
+    # Block norms use use_scale=False, so they have no weights in the checkpoint
+    else:
+      key = key.replace(".weight", ".kernel")
+
+  return key
+
+
+def load_connector_weights(
+    pretrained_model_name_or_path: str,
+    eval_shapes: dict,
+    device: str,
+    hf_download: bool = True,
+    subfolder: str = "connectors",
+):
+  tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device)
+  flax_state_dict = {}
+  cpu = jax.local_devices(backend="cpu")[0]
+
+  # Store stacked weights: grouped_weights[connector][param_name] = {layer_idx: tensor}
+  grouped_weights = {
+      "video_embeddings_connector": {},
+      "audio_embeddings_connector": {},
+  }
+
+  for pt_key, tensor in tensors.items():
+    key = rename_for_ltx2_connector(pt_key)
+
+    # Transpose linear kernels: PyTorch stores (out, in), Flax expects (in, out)
+    if key.endswith(".kernel"):
+      if tensor.ndim == 2:
+        tensor = tensor.transpose(1, 0)
+
+    if "stacked_blocks" in key:
+      # key format: {connector}.stacked_blocks.{layer_idx}.{rest}
+      parts = key.split(".")
+      try:
+        sb_index = parts.index("stacked_blocks")
+        layer_idx = int(parts[sb_index + 1])
+        connector = parts[0]
+
+        # Reconstruct the param name without the layer index,
+        # e.g. video_embeddings_connector.stacked_blocks.attn1...
+        param_parts = parts[:sb_index + 1] + parts[sb_index + 2:]
+        param_name = tuple(param_parts)
+
+        if connector in grouped_weights:
+          if param_name not in grouped_weights[connector]:
+            grouped_weights[connector][param_name] = {}
+          grouped_weights[connector][param_name][layer_idx] = tensor
+          continue
+      except (ValueError, IndexError):
+        pass
+
+    # Non-stacked keys
+    key_tuple = tuple(key.split("."))
+
+    # Convert numeric path components to ints
+    final_key_tuple = []
+    for p in key_tuple:
+      if p.isdigit():
+        final_key_tuple.append(int(p))
+      else:
+        final_key_tuple.append(p)
+    final_key_tuple = tuple(final_key_tuple)
+
+    flax_state_dict[final_key_tuple] = jax.device_put(tensor, device=cpu)
+
+  # Stack the grouped per-layer weights
+  for connector, params in grouped_weights.items():
+    for param_name, layers in params.items():
+      # Sort by layer index and stack, assuming contiguous layers 0..N-1
+      sorted_layers = sorted(layers.keys())
+      stacked_tensor = jnp.stack([layers[i] for i in sorted_layers], axis=0)
+
+      # Convert numeric path components to ints
+      final_param_name = []
+      for p in param_name:
+        if isinstance(p, str) and p.isdigit():
+          final_param_name.append(int(p))
+        else:
+          final_param_name.append(p)
+      final_param_name = tuple(final_param_name)
+
+      flax_state_dict[final_param_name] = jax.device_put(stacked_tensor, device=cpu)
+
+  # Clean up and return
+  del tensors
+  jax.clear_caches()
+  validate_flax_state_dict(eval_shapes, flax_state_dict)
+  return unflatten_dict(flax_state_dict)
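For reference, rename_for_ltx2_connector rewrites PyTorch checkpoint keys onto the Flax module tree. A few illustrative mappings, tracing the replacement rules above (the input keys are hypothetical, patterned on the rules rather than taken from the actual checkpoint):

# Hypothetical keys, shown only to trace the replacement rules.
assert rename_for_ltx2_connector("text_proj_in.weight") == "feature_extractor.linear.kernel"
assert (
    rename_for_ltx2_connector("video_connector.transformer_blocks.0.attn1.norm_q.weight")
    == "video_embeddings_connector.stacked_blocks.0.attn1.norm_q.scale"
)
assert (
    rename_for_ltx2_connector("audio_connector.transformer_blocks.3.ff.net.0.proj.weight")
    == "audio_embeddings_connector.stacked_blocks.3.ff.proj1.kernel"
)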

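The grouping pass in load_connector_weights then collapses per-layer entries {connector}.stacked_blocks.{i}.{param} into a single array stacked along a new leading layer axis, as a scanned/stacked block layout expects. A toy sketch of just that step, with made-up shapes:

import jax.numpy as jnp

# Three per-layer (4, 4) kernels become one (3, 4, 4) array keyed without
# the layer index; load_connector_weights does this per connector/param.
layers = {i: jnp.full((4, 4), i, dtype=jnp.float32) for i in range(3)}
stacked = jnp.stack([layers[i] for i in sorted(layers)], axis=0)
print(stacked.shape)  # (3, 4, 4)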
src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 3 additions & 3 deletions
@@ -63,8 +63,8 @@ def __init__(
         rngs=rngs,
     )
     self.ff = FeedForward(dim, dim_out=dim, rngs=rngs)
-    self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)
-    self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)
+    self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, use_scale=False, rngs=rngs)
+    self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, use_scale=False, rngs=rngs)

   def __call__(
       self,
@@ -129,7 +129,7 @@ def create_block(rngs):
     )

     self.final_norm = nnx.RMSNorm(
-        self.dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs
+        self.dim, epsilon=1e-6, dtype=jnp.float32, use_scale=False, rngs=rngs
     )

   def _replace_padded_with_learnable_registers(self, hidden_states: Array, attention_mask: Array) -> Tuple[Array, Array]:
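With use_scale=False these norms carry no learnable parameters, which is consistent with the converter above never having to map a block-norm .weight. A minimal standalone check (assuming a recent flax.nnx):

import jax.numpy as jnp
from flax import nnx

# An RMSNorm without a scale has an empty parameter state, so there is
# nothing for the checkpoint converter to load for it.
norm = nnx.RMSNorm(8, epsilon=1e-6, dtype=jnp.float32, use_scale=False, rngs=nnx.Rngs(0))
print(nnx.state(norm, nnx.Param))  # empty: no scale parameter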

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ def __init__(
       # Feature Extractor Config
       gemma_dim: int = 3840,  # Gemma-3-12b
       gemma_layers: int = 49,  # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states
-      projection_dim: int = 4096,  # LTX-2 conditioning dim
+      projection_dim: int = 3840,  # LTX-2 conditioning dim
       # Connector Config
       connector_heads: int = 32,
       connector_head_dim: int = 128,
@@ -98,7 +98,7 @@ def __init__(
       # Feature Extractor Config (Shared)
       gemma_dim: int = 3840,  # Gemma-3-12b
       gemma_layers: int = 49,  # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states
-      projection_dim: int = 4096,
+      projection_dim: int = 3840,
       # Connector Config
       connector_heads: int = 32,
       connector_head_dim: int = 128,
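After this change projection_dim matches gemma_dim, so the conditioning projection maps 3840-wide Gemma hidden states to a 3840-wide space. A shape sketch (the nnx.Linear below is an illustrative stand-in, not the actual feature extractor):

import jax.numpy as jnp
from flax import nnx

gemma_dim, projection_dim = 3840, 3840
proj = nnx.Linear(gemma_dim, projection_dim, rngs=nnx.Rngs(0))
tokens = jnp.ones((2, 16, gemma_dim))  # (batch, seq, gemma_dim)
print(proj(tokens).shape)              # (2, 16, 3840)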

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 24 additions & 0 deletions
@@ -160,5 +160,29 @@ def test_load_vocoder_weights(self):
     validate_flax_state_dict(eval_shapes, flatten_dict(loaded_weights))
     print("Vocoder Weights Validated Successfully!")

+  def test_load_connector_weights(self):
+    from maxdiffusion.models.ltx2.text_encoders.text_encoders_ltx2 import LTX2AudioVideoGemmaTextEncoder
+    from maxdiffusion.models.ltx2.ltx2_utils import load_connector_weights
+
+    pretrained_model_name_or_path = "Lightricks/LTX-2"
+
+    with jax.default_device(jax.devices("cpu")[0]):
+      model = LTX2AudioVideoGemmaTextEncoder(rngs=self.rngs)
+
+    state = nnx.state(model)
+    eval_shapes = state.to_pure_dict()
+
+    print("Loading Connector Weights...")
+    loaded_weights = load_connector_weights(
+        pretrained_model_name_or_path=pretrained_model_name_or_path,
+        eval_shapes=eval_shapes,
+        device=self.device,
+        hf_download=True,
+    )
+
+    print("Validating Connector Weights...")
+    validate_flax_state_dict(eval_shapes, flatten_dict(loaded_weights))
+    print("Connector Weights Validated Successfully!")
+
 if __name__ == "__main__":
   unittest.main()
