Skip to content

Commit 91f1d8f

Browse files
committed
transformer fix
1 parent 866c3d0 commit 91f1d8f

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,11 @@ def __call__(
10621062
audio_sigma.flatten(),
10631063
hidden_dtype=audio_hidden_states.dtype,
10641064
)
1065+
if temb_prompt.shape[0] < batch_size:
1066+
temb_prompt = jnp.repeat(temb_prompt, batch_size // temb_prompt.shape[0], axis=0)
1067+
if temb_prompt_audio.shape[0] < batch_size:
1068+
temb_prompt_audio = jnp.repeat(temb_prompt_audio, batch_size // temb_prompt_audio.shape[0], axis=0)
1069+
10651070
temb_prompt = temb_prompt.reshape(batch_size, -1, temb_prompt.shape[-1])
10661071
temb_prompt_audio = temb_prompt_audio.reshape(batch_size, -1, temb_prompt_audio.shape[-1])
10671072
else:

0 commit comments

Comments
 (0)