We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 866c3d0 commit 91f1d8fCopy full SHA for 91f1d8f
1 file changed
src/maxdiffusion/models/ltx2/transformer_ltx2.py
@@ -1062,6 +1062,11 @@ def __call__(
1062
audio_sigma.flatten(),
1063
hidden_dtype=audio_hidden_states.dtype,
1064
)
1065
+ if temb_prompt.shape[0] < batch_size:
1066
+ temb_prompt = jnp.repeat(temb_prompt, batch_size // temb_prompt.shape[0], axis=0)
1067
+ if temb_prompt_audio.shape[0] < batch_size:
1068
+ temb_prompt_audio = jnp.repeat(temb_prompt_audio, batch_size // temb_prompt_audio.shape[0], axis=0)
1069
+
1070
temb_prompt = temb_prompt.reshape(batch_size, -1, temb_prompt.shape[-1])
1071
temb_prompt_audio = temb_prompt_audio.reshape(batch_size, -1, temb_prompt_audio.shape[-1])
1072
else:
0 commit comments