Skip to content

Commit 0a4549b

Browse files
committed
Fix transformer weight loading when scan is false, and fix the timestep projection issue
1 parent b02383b commit 0a4549b

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -698,15 +698,15 @@ def __init__(
698698
self.caption_projection = NNXCombinedTimestepTextProjEmbeddings(
699699
rngs=rngs,
700700
in_features=self.caption_channels,
701-
hidden_size=inner_dim,
701+
hidden_size=self.cross_attention_dim,
702702
embedding_dim=inner_dim,
703703
dtype=self.dtype,
704704
weights_dtype=self.weights_dtype,
705705
)
706706
self.audio_caption_projection = NNXCombinedTimestepTextProjEmbeddings(
707707
rngs=rngs,
708708
in_features=self.audio_caption_channels,
709-
hidden_size=audio_inner_dim,
709+
hidden_size=self.audio_cross_attention_dim,
710710
embedding_dim=audio_inner_dim,
711711
dtype=self.dtype,
712712
weights_dtype=self.weights_dtype,
@@ -1050,10 +1050,10 @@ def __call__(
10501050
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
10511051

10521052
# 4. Prepare prompt embeddings
1053-
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
1053+
encoder_hidden_states = self.caption_projection(encoder_hidden_states, timestep)
10541054
encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
10551055

1056-
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
1056+
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states, audio_timestep if audio_timestep is not None else timestep)
10571057
audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])
10581058

10591059
# 5. Run transformer blocks

0 commit comments

Comments
 (0)