Skip to content

Commit 317698e

Browse files
committed
fix in transformer
1 parent c45d444 commit 317698e

1 file changed

Lines changed: 23 additions & 23 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,29 +1009,29 @@ def __call__(
10091009
with self.conditional_named_scope("transformer_block"):
10101010
def scan_fn(carry, block):
10111011
hidden_states, audio_hidden_states, rngs_carry = carry
1012-
hidden_states_out, audio_hidden_states_out = block(
1013-
hidden_states=hidden_states,
1014-
audio_hidden_states=audio_hidden_states,
1015-
encoder_hidden_states=encoder_hidden_states,
1016-
audio_encoder_hidden_states=audio_encoder_hidden_states,
1017-
temb=temb,
1018-
temb_audio=temb_audio,
1019-
temb_ca_scale_shift=video_cross_attn_scale_shift,
1020-
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
1021-
temb_ca_gate=video_cross_attn_a2v_gate,
1022-
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1023-
video_rotary_emb=video_rotary_emb,
1024-
audio_rotary_emb=audio_rotary_emb,
1025-
ca_video_rotary_emb=video_cross_attn_rotary_emb,
1026-
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
1027-
encoder_attention_mask=encoder_attention_mask,
1028-
audio_encoder_attention_mask=audio_encoder_attention_mask,
1029-
)
1030-
return (
1031-
hidden_states_out.astype(hidden_states.dtype),
1032-
audio_hidden_states_out.astype(audio_hidden_states.dtype),
1033-
rngs_carry,
1034-
), None
1012+
hidden_states_out, audio_hidden_states_out = block(
1013+
hidden_states=hidden_states,
1014+
audio_hidden_states=audio_hidden_states,
1015+
encoder_hidden_states=encoder_hidden_states,
1016+
audio_encoder_hidden_states=audio_encoder_hidden_states,
1017+
temb=temb,
1018+
temb_audio=temb_audio,
1019+
temb_ca_scale_shift=video_cross_attn_scale_shift,
1020+
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
1021+
temb_ca_gate=video_cross_attn_a2v_gate,
1022+
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1023+
video_rotary_emb=video_rotary_emb,
1024+
audio_rotary_emb=audio_rotary_emb,
1025+
ca_video_rotary_emb=video_cross_attn_rotary_emb,
1026+
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
1027+
encoder_attention_mask=encoder_attention_mask,
1028+
audio_encoder_attention_mask=audio_encoder_attention_mask,
1029+
)
1030+
return (
1031+
hidden_states_out.astype(hidden_states.dtype),
1032+
audio_hidden_states_out.astype(audio_hidden_states.dtype),
1033+
rngs_carry,
1034+
), None
10351035

10361036
if self.scan_layers:
10371037
rematted_scan_fn = self.gradient_checkpoint.apply(

0 commit comments

Comments
 (0)