Skip to content

Commit 7b8efca

Browse files
committed
some debug added to t5_prompt_embeds
1 parent 0fb882d commit 7b8efca

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,8 @@ def _get_t5_prompt_embeds(
419419
num_videos_per_prompt: int = 1,
420420
max_sequence_length: int = 226,
421421
):
422+
jax.debug.print("Prompt Shape")
423+
jax.debug.print(prompt.shape)
422424
prompt = [prompt] if isinstance(prompt, str) else prompt
423425
prompt = [prompt_clean(u) for u in prompt]
424426
batch_size = len(prompt)
@@ -433,6 +435,12 @@ def _get_t5_prompt_embeds(
433435
return_tensors="pt",
434436
)
435437
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
438+
jax.debug.print("Text Input IDS")
439+
jax.debug.print(text_input_ids)
440+
jax.debug.print(text_input_ids.shape)
441+
jax.debug.print("Mask")
442+
jax.debug.print(mask)
443+
jax.debug.print(mask.shape)
436444
seq_lens = mask.gt(0).sum(dim=1).long()
437445
prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
438446
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]

0 commit comments

Comments
 (0)