Skip to content

Commit dfd0452

Browse files
committed
fix
1 parent 4b2d267 commit dfd0452

2 files changed

Lines changed: 6 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,8 @@ def rename_for_ltx2_connector(key):
305305
key = key.replace("video_connector", "video_embeddings_connector")
306306
key = key.replace("audio_connector", "audio_embeddings_connector")
307307
key = key.replace("text_proj_in", "feature_extractor.linear")
308+
key = key.replace("audio_feature_extractor.linear", "audio_text_proj_in")
309+
key = key.replace("video_feature_extractor.linear", "video_text_proj_in")
308310

309311
if "transformer_blocks" in key:
310312
key = key.replace("transformer_blocks", "stacked_blocks")

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,15 @@ def __init__(
6363
proj_bias: bool = False,
6464
video_gated_attn: bool = False,
6565
audio_gated_attn: bool = False,
66+
audio_hidden_dim: Optional[int] = None,
67+
video_hidden_dim: Optional[int] = None,
6668
**kwargs,
6769
):
6870
gemma_dim = 3840 if video_caption_channels is not None else caption_channels
6971
input_dim = gemma_dim * text_proj_in_factor
7072

71-
v_dim = video_caption_channels if video_caption_channels is not None else caption_channels
72-
a_dim = audio_caption_channels if audio_caption_channels is not None else caption_channels
73+
v_dim = video_hidden_dim if video_hidden_dim is not None else (video_caption_channels if video_caption_channels is not None else caption_channels)
74+
a_dim = audio_hidden_dim if audio_hidden_dim is not None else (audio_caption_channels if audio_caption_channels is not None else caption_channels)
7375

7476
self.per_modality_projections = per_modality_projections
7577

0 commit comments

Comments
 (0)