File tree Expand file tree Collapse file tree
src/maxdiffusion/pipelines/wan Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -419,6 +419,8 @@ def _get_t5_prompt_embeds(
419419 num_videos_per_prompt : int = 1 ,
420420 max_sequence_length : int = 226 ,
421421 ):
422+ jax .debug .print ("Prompt Shape" )
423+ jax .debug .print (prompt .shape )
422424 prompt = [prompt ] if isinstance (prompt , str ) else prompt
423425 prompt = [prompt_clean (u ) for u in prompt ]
424426 batch_size = len (prompt )
@@ -433,6 +435,12 @@ def _get_t5_prompt_embeds(
433435 return_tensors = "pt" ,
434436 )
435437 text_input_ids , mask = text_inputs .input_ids , text_inputs .attention_mask
438+ jax .debug .print ("Text Input IDS" )
439+ jax .debug .print (text_input_ids )
440+ jax .debug .print (text_input_ids .shape )
441+ jax .debug .print ("Mask" )
442+ jax .debug .print (mask )
443+ jax .debug .print (mask .shape )
436444 seq_lens = mask .gt (0 ).sum (dim = 1 ).long ()
437445 prompt_embeds = self .text_encoder (text_input_ids , mask ).last_hidden_state
438446 prompt_embeds = [u [:v ] for u , v in zip (prompt_embeds , seq_lens )]
You can’t perform that action at this time.
0 commit comments