
Commit 71a9a76

added debug for values
1 parent 2c1fe43 commit 71a9a76

1 file changed: 4 additions & 3 deletions

File tree

src/maxdiffusion/models/attention_flax.py

@@ -294,11 +294,12 @@ def wrap_flash_attention(query, key, value):
     q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)

     kv_padded_len = key.shape[2]
-    print("KV padded len:", kv_padded_len)
     kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-    jax.debug.print("KV indices: {x}", x=kv_indices)
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
-    jax.debug.print("KV segment ids: {x}", x=kv_segment_ids)
+    num_ones_in_kv_segment_ids = jnp.sum(kv_segment_ids)
+    jax.debug.print("KV padded len: {kv_len}, Number of 1s in KV segment ids: {num_ones}", kv_len=kv_padded_len, num_ones=num_ones_in_kv_segment_ids)
+
+
     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)

     # make_splash_mha is wrapped around shardmap and seq and head is already
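
The change replaces three separate debug prints with a single one that reduces the KV segment-id array to a count of its ones, which is much cheaper to read back than printing the whole array. As a minimal, self-contained sketch of how the padding mask above behaves (not part of this commit; the lengths 5 and 8 are made-up values): positions below the true KV length get segment id 1, padded positions get 0, so jnp.sum recovers the unpadded length.

    import jax
    import jax.numpy as jnp

    key_seq_len = 5    # hypothetical true (unpadded) KV sequence length
    kv_padded_len = 8  # hypothetical padded KV length

    # Same construction as in the diff: index vector [0, 1, ..., 7],
    # then a 0/1 mask marking positions inside the real sequence.
    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)  # [1 1 1 1 1 0 0 0]

    num_ones = jnp.sum(kv_segment_ids)
    jax.debug.print("KV padded len: {kv_len}, Number of 1s in KV segment ids: {num_ones}",
                    kv_len=kv_padded_len, num_ones=num_ones)
    # prints: KV padded len: 8, Number of 1s in KV segment ids: 5

If the printed count ever differs from the expected key_seq_len, the mask is being built over the wrong axis or the padded shape is not what was assumed, which is presumably what this debug output is meant to catch.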
