Commit a918bea

cleanup

1 parent 81aafe1 commit a918bea

1 file changed: src/maxdiffusion/models/ltx2/ltx2_utils.py
Lines changed: 4 additions & 24 deletions
@@ -404,24 +404,17 @@ def rename_for_ltx2_connector(key):
   key = key.replace("audio_connector", "audio_embeddings_connector")
   key = key.replace("text_proj_in", "feature_extractor.linear")
 
-  # Transformer blocks mapping
   if "transformer_blocks" in key:
     key = key.replace("transformer_blocks", "stacked_blocks")
-    # Handle FF
     key = key.replace("ff.net.0.proj", "ff.proj1")
     key = key.replace("ff.net.2", "ff.proj2")
-    # Handle to_out
     key = key.replace("to_out.0", "to_out")
 
-  # Validation/Weight suffix
   if key.endswith(".weight"):
-    # Check if it's a norm with usage_scale=True (attn norms)
     if "norm_q" in key or "norm_k" in key:
       key = key.replace(".weight", ".scale")
-    # Check if it's a norm with usage_scale=False (block norms) -> No, these don't exist in checkpoint!
     else:
       key = key.replace(".weight", ".kernel")
-
   return key
 
 def load_connector_weights(
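
For context, this function maps PyTorch checkpoint keys onto the Flax module tree. A minimal sketch of the intended behavior on representative keys (the sample keys are illustrative, not taken from an actual LTX2 checkpoint):

  # Illustrative sample keys, assumed to resemble the PyTorch layout this function targets.
  samples = {
      "transformer_blocks.3.ff.net.0.proj.weight": "stacked_blocks.3.ff.proj1.kernel",
      "transformer_blocks.3.attn1.norm_q.weight": "stacked_blocks.3.attn1.norm_q.scale",
      "transformer_blocks.3.attn1.to_out.0.weight": "stacked_blocks.3.attn1.to_out.kernel",
  }
  for pt_key, expected in samples.items():
    assert rename_for_ltx2_connector(pt_key) == expected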
@@ -434,8 +427,7 @@ def load_connector_weights(
   tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device)
   flax_state_dict = {}
   cpu = jax.local_devices(backend="cpu")[0]
-
-  # Store stacked weights: grouped_weights[connector][param_name] = {layer_idx: tensor}
+
   grouped_weights = {
       "video_embeddings_connector": {},
       "audio_embeddings_connector": {}
@@ -444,22 +436,17 @@ def load_connector_weights(
   for pt_key, tensor in tensors.items():
     key = rename_for_ltx2_connector(pt_key)
 
-    # Check for transpose (Linear layers)
     if key.endswith(".kernel"):
-      if tensor.ndim == 2:
-        tensor = tensor.transpose(1, 0)
-
+      if tensor.ndim == 2:
+        tensor = tensor.transpose(1, 0)
+
     if "stacked_blocks" in key:
-      # key format: {connector}.stacked_blocks.{layer_idx}.{rest}
       parts = key.split(".")
-      # Find stacked_blocks index
       try:
         sb_index = parts.index("stacked_blocks")
         layer_idx = int(parts[sb_index + 1])
         connector = parts[0]
 
-        # Reconstruct param name without layer index
-        # e.g. video_embeddings_connector.stacked_blocks.attn1...
         param_parts = parts[:sb_index+1] + parts[sb_index+2:]
         param_name = tuple(param_parts)
 
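
The 2-D transpose is needed because PyTorch nn.Linear stores its weight as (out_features, in_features), while a Flax Dense kernel is laid out as (in_features, out_features). A minimal sketch with an illustrative shape:

  import numpy as np

  pt_weight = np.zeros((128, 64))          # PyTorch Linear: (out_features, in_features)
  flax_kernel = pt_weight.transpose(1, 0)  # Flax Dense kernel: (in_features, out_features)
  assert flax_kernel.shape == (64, 128)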

@@ -471,10 +458,8 @@ def load_connector_weights(
     except (ValueError, IndexError):
       pass
 
-    # Non-stacked keys
     key_tuple = tuple(key.split("."))
 
-    # Handle int conversion for parts
     final_key_tuple = []
     for p in key_tuple:
       if p.isdigit(): final_key_tuple.append(int(p))
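
Digit segments are converted to ints so the tuple keys index list-like modules numerically rather than by string. A sketch of the conversion (the sample key is hypothetical):

  key = "video_embeddings_connector.some_module.0.kernel"  # hypothetical key for illustration
  final_key_tuple = tuple(int(p) if p.isdigit() else p for p in key.split("."))
  assert final_key_tuple == ("video_embeddings_connector", "some_module", 0, "kernel")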
@@ -483,15 +468,11 @@ def load_connector_weights(
 
     flax_state_dict[final_key_tuple] = jax.device_put(tensor, device=cpu)
 
-  # Process grouped weights
   for connector, params in grouped_weights.items():
     for param_name, layers in params.items():
-      # Sort by layer index and stack
       sorted_layers = sorted(layers.keys())
-      # Assuming contiguous layers 0..N-1
       stacked_tensor = jnp.stack([layers[i] for i in sorted_layers], axis=0)
 
-      # Param name tuple
       final_param_name = []
       for p in param_name:
         if isinstance(p, str) and p.isdigit(): final_param_name.append(int(p))
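
The deleted comments called out an assumption worth remembering: layer indices are expected to be contiguous 0..N-1, and the stack adds a leading layer axis, the layout used by stacked/scanned Flax blocks. A sketch:

  import jax.numpy as jnp

  layers = {i: jnp.full((64, 128), float(i)) for i in range(3)}  # illustrative per-layer kernels
  stacked = jnp.stack([layers[i] for i in sorted(layers.keys())], axis=0)
  assert stacked.shape == (3, 64, 128)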
@@ -500,7 +481,6 @@ def load_connector_weights(
 
       flax_state_dict[final_param_name] = jax.device_put(stacked_tensor, device=cpu)
 
-  # Clean up and return
   del tensors
   jax.clear_caches()
   validate_flax_state_dict(eval_shapes, flax_state_dict)
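
Placing tensors on the CPU device keeps the converted weights in host memory rather than on accelerators during conversion, and jax.clear_caches() releases JAX's compilation and staging caches once the source tensors are dropped. A minimal sketch of the placement (the array is illustrative):

  import jax
  import jax.numpy as jnp

  cpu = jax.local_devices(backend="cpu")[0]
  x = jax.device_put(jnp.ones((2, 2)), device=cpu)
  assert x.devices() == {cpu}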
