Skip to content

Commit 5ed0991

Browse files
committed
pipeline fix
1 parent 0cdb17b commit 5ed0991

1 file changed

Lines changed: 12 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,12 @@ def __call__(
10891089
video_ca_timestep * timestep_cross_attn_gate_scale_factor,
10901090
hidden_dtype=hidden_states.dtype,
10911091
)
1092+
1093+
if video_cross_attn_scale_shift.shape[0] < batch_size:
1094+
video_cross_attn_scale_shift = jnp.repeat(video_cross_attn_scale_shift, batch_size // video_cross_attn_scale_shift.shape[0], axis=0)
1095+
if video_cross_attn_a2v_gate.shape[0] < batch_size:
1096+
video_cross_attn_a2v_gate = jnp.repeat(video_cross_attn_a2v_gate, batch_size // video_cross_attn_a2v_gate.shape[0], axis=0)
1097+
10921098
video_cross_attn_scale_shift = video_cross_attn_scale_shift.reshape(
10931099
batch_size, -1, video_cross_attn_scale_shift.shape[-1]
10941100
)
@@ -1102,6 +1108,12 @@ def __call__(
11021108
audio_ca_timestep * timestep_cross_attn_gate_scale_factor,
11031109
hidden_dtype=audio_hidden_states.dtype,
11041110
)
1111+
1112+
if audio_cross_attn_scale_shift.shape[0] < batch_size:
1113+
audio_cross_attn_scale_shift = jnp.repeat(audio_cross_attn_scale_shift, batch_size // audio_cross_attn_scale_shift.shape[0], axis=0)
1114+
if audio_cross_attn_v2a_gate.shape[0] < batch_size:
1115+
audio_cross_attn_v2a_gate = jnp.repeat(audio_cross_attn_v2a_gate, batch_size // audio_cross_attn_v2a_gate.shape[0], axis=0)
1116+
11051117
audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.reshape(
11061118
batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
11071119
)

0 commit comments

Comments
 (0)