Skip to content

Commit 0d559d5

Browse files
committed
some debug added to t5_prompt_embeds
1 parent dc4f451 commit 0d559d5

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -436,12 +436,12 @@ def _get_t5_prompt_embeds(
436436
return_tensors="pt",
437437
)
438438
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
439-
jax.debug.print("Text Input IDS")
440-
jax.debug.print(text_input_ids)
441-
jax.debug.print(text_input_ids.shape)
442-
jax.debug.print("Mask")
443-
jax.debug.print(mask)
444-
jax.debug.print(mask.shape)
439+
print("Text Input IDS")
440+
print(text_input_ids)
441+
print(text_input_ids.shape)
442+
print("Mask")
443+
print(mask)
444+
print(mask.shape)
445445
seq_lens = mask.gt(0).sum(dim=1).long()
446446
prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
447447
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]

0 commit comments

Comments
 (0)