We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent dc4f451 commit 0d559d5Copy full SHA for 0d559d5
1 file changed
src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -436,12 +436,12 @@ def _get_t5_prompt_embeds(
436
return_tensors="pt",
437
)
438
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
439
- jax.debug.print("Text Input IDS")
440
- jax.debug.print(text_input_ids)
441
- jax.debug.print(text_input_ids.shape)
442
- jax.debug.print("Mask")
443
- jax.debug.print(mask)
444
- jax.debug.print(mask.shape)
+ print("Text Input IDS")
+ print(text_input_ids)
+ print(text_input_ids.shape)
+ print("Mask")
+ print(mask)
+ print(mask.shape)
445
seq_lens = mask.gt(0).sum(dim=1).long()
446
prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
447
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
0 commit comments