Commit 7066e2f

added debug for values
1 parent 4836217 commit 7066e2f

2 files changed: 9 additions & 1 deletion

src/maxdiffusion/models/attention_flax.py

Lines changed: 5 additions & 0 deletions
@@ -283,6 +283,7 @@ def wrap_flash_attention(query, key, value):

  block_kv = max(*block_kv_sizes)
  key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+ print("Key seq len after padding:", key_seq_len)
  value, _, _ = _pad_data_for_flash(value, heads, block_kv)

  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
@@ -293,8 +294,10 @@ def wrap_flash_attention(query, key, value):
  q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)

  kv_padded_len = key.shape[2]
+ print("KV padded len:", kv_padded_len)
  kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
  kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+ print("KV segment ids:", kv_segment_ids)
  segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)

  # make_splash_mha is wrapped around shardmap and seq and head is already
@@ -1008,8 +1011,10 @@ def __call__(
    query_proj = self.query(hidden_states)
    with jax.named_scope("key_proj"):
      key_proj = self.key(encoder_hidden_states)
+     print("key_proj shape:", key_proj.shape)
    with jax.named_scope("value_proj"):
      value_proj = self.value(encoder_hidden_states)
+     print("value_proj shape:", value_proj.shape)

    if self.qk_norm:
      with self.conditional_named_scope("attn_q_norm"):
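For context, the prints in wrap_flash_attention report how much of the padded KV sequence is real. Below is a minimal standalone sketch of that segment-id construction; the two lengths are assumed example values standing in for _pad_data_for_flash, not taken from the commit:

import jax
import jax.numpy as jnp

key_seq_len = 512    # real (unpadded) KV length -- assumed example value
kv_padded_len = 640  # length after padding up to a block_kv multiple -- assumed

# Same construction as in wrap_flash_attention: index every padded position,
# then mark only the first key_seq_len positions as valid (segment id 1).
kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)

print("KV padded len:", kv_padded_len)    # 640
print("KV segment ids:", kv_segment_ids)  # [1 1 ... 1 0 0 ... 0] -- 512 ones, 128 zeros

Positions with segment id 0 are padding: their segment never matches a valid query position's segment, so valid queries never attend to them.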

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 4 additions & 1 deletion
@@ -403,7 +403,10 @@ def _get_t5_prompt_embeds(
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
-
+   print("Prompt embeds shape:")
+   print(prompt_embeds.shape)
+   print("Prompt embeds:")
+   print(prompt_embeds)
    return prompt_embeds

  def encode_prompt(
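The values dumped by these prints come from the tile-and-reshape step just above them. A minimal sketch of that step, under the assumption that prompt_embeds is a torch tensor (as the repeat/view calls suggest), with example sizes that are not from the commit:

import torch

batch_size, seq_len, dim = 2, 226, 4096  # assumed example sizes
num_videos_per_prompt = 3

prompt_embeds = torch.randn(batch_size, seq_len, dim)
# Tile along the sequence axis, then fold the copies into the batch axis,
# mirroring the two lines above the new prints.
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)                     # (2, 678, 4096)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)  # (6, 226, 4096)

print("Prompt embeds shape:")
print(prompt_embeds.shape)  # torch.Size([6, 226, 4096])

The reshape works because repeat returns a contiguous tensor, and it leaves each video's copy of the prompt as its own batch row.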
