Commit 195b451

Fuse normalization in kernel

Signed-off-by: Kunjan Patel <kunjan@ucla.edu>

1 parent 6fd09fe

3 files changed
Lines changed: 2272 additions & 19 deletions
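For context: the "normalization" being fused is the online-softmax merge that the ring loop previously applied on the host after every kernel call (the alpha/beta rescaling removed in ring_attention_kernel.py below). Seeding the kernel's accumulators with the previous step's (m, l, o) computes the same result inside the kernel. Below is a minimal standalone JAX sketch of that merge and a numerical check against a plain softmax; partial_attention and merge are hypothetical helper names, not code from this repository:

import jax
import jax.numpy as jnp

def partial_attention(q, k, v):
  # Unnormalized attention over one key/value block: running max (m),
  # sum of exponentials (l), and exp-weighted values (o).
  s = q @ k.T
  m = s.max(axis=-1)
  p = jnp.exp(s - m[:, None])
  return m, p.sum(axis=-1), p @ v

def merge(m_prev, l_prev, o_prev, m_curr, l_curr, o_curr):
  # The host-side online-softmax merge this commit moves into the kernel.
  m_next = jnp.maximum(m_prev, m_curr)
  alpha = jnp.exp(m_prev - m_next)
  beta = jnp.exp(m_curr - m_next)
  l_next = alpha * l_prev + beta * l_curr
  o_next = alpha[..., None] * o_prev + beta[..., None] * o_curr
  return m_next, l_next, o_next

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (4, 8))
k = jax.random.normal(kk, (16, 8))
v = jax.random.normal(kv, (16, 16))

ref = jax.nn.softmax(q @ k.T, axis=-1) @ v          # softmax over all keys at once
m1, l1, o1 = partial_attention(q, k[:8], v[:8])     # first key/value block
m2, l2, o2 = partial_attention(q, k[8:], v[8:])     # second key/value block
m, l, o = merge(m1, l1, o1, m2, l2, o2)
print(jnp.allclose(o / l[:, None], ref, atol=1e-5))  # True

This merge is what each ring step used to do on the host; folding the previous accumulators in as m_init/l_init/o_init lets the kernel perform the same accumulation internally.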

src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py

Lines changed: 32 additions & 13 deletions
@@ -119,7 +119,7 @@ def _ring_attention_forward(
   # Initial accumulator values
   o_shape = q.shape
   o_init = jnp.zeros(o_shape, dtype=jnp.float32)
-  l_init = jnp.zeros((o_shape[0], o_shape[1]), jnp.float32)
+  l_init = jnp.zeros((o_shape[0], o_shape[1], splash_kernel.NUM_LANES), jnp.float32)
   m_init = jnp.full_like(l_init, mask_value, dtype=jnp.float32)

   def body(carry, i: int) -> tuple[tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, SegmentIds | None], None]:
@@ -143,15 +143,21 @@ def body(carry, i: int) -> tuple[tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, SegmentIds | None], None]:
         v_current,
         segment_ids=segment_ids_current,
         sinks=sinks,
+        m_init=m_prev,
+        l_init=l_prev,
+        o_init=o_prev
     )
-    m_curr = stats["max_logits"].astype(jnp.float32)
-    l_curr = stats["l_linear"].astype(jnp.float32)
-    o_curr = out_curr.astype(jnp.float32)
-    m_next = jnp.maximum(m_prev, m_curr)
-    alpha = exp_fn(m_prev - m_next)
-    beta = exp_fn(m_curr - m_next)
-    l_next = alpha * l_prev + beta * l_curr
-    o_next = alpha[..., None] * o_prev + beta[..., None] * o_curr
+    m_next = stats["max_logits"].astype(jnp.float32)
+    l_next = stats["l_linear"].astype(jnp.float32)
+    o_next = out_curr.astype(jnp.float32)
+    # m_curr = stats["max_logits"].astype(jnp.float32)
+    # l_curr = stats["l_linear"].astype(jnp.float32)
+    # o_curr = out_curr.astype(jnp.float32)
+    # m_next = jnp.maximum(m_prev, m_curr)
+    # alpha = exp_fn(m_prev - m_next)
+    # beta = exp_fn(m_curr - m_next)
+    # l_next = alpha * l_prev + beta * l_curr
+    # o_next = alpha[..., None] * o_prev + beta[..., None] * o_curr
     return (m_next, l_next, o_next, k_next, v_next, segment_ids_next), None

   # Use lax.scan to get the final carry AND the collected sequence of (k,v)
@@ -165,12 +171,25 @@ def body(carry, i: int) -> tuple[tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, SegmentIds | None], None]:
       unroll=True,
   ) # type: ignore[arg-type]
   # Final normalization
-  assert l_final.dtype == jnp.float32
-  l_inv = jnp.where(l_final == 0.0, 0.0, 1.0 / l_final)
+  # assert l_final.dtype == jnp.float32
+  # l_inv = jnp.where(l_final == 0.0, 0.0, 1.0 / l_final)
+  # out = (o_final * l_inv[..., None]).astype(q.dtype)
+  # # Final logsumexp for residuals
+  # lse = log_fn(l_final) + m_final
+  # lse = jnp.where(l_final == 0.0, mask_value, lse)
+  # Final normalization (Slice off NUM_LANES down to 2D)
+  l_final_2d = l_final[..., 0]
+  m_final_2d = m_final[..., 0]
+
+  assert l_final_2d.dtype == jnp.float32
+  l_inv = jnp.where(l_final_2d == 0.0, 0.0, 1.0 / l_final_2d)
   out = (o_final * l_inv[..., None]).astype(q.dtype)
+
   # Final logsumexp for residuals
-  lse = log_fn(l_final) + m_final
-  lse = jnp.where(l_final == 0.0, mask_value, lse)
+  lse = log_fn(l_final_2d) + m_final_2d
+  lse = jnp.where(l_final_2d == 0.0, mask_value, lse)
+
+

   return out, (lse, m_final)
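A note on the finalization above: after this change, l_final and m_final carry a trailing splash_kernel.NUM_LANES axis that (by assumption here) holds the same value in every lane, so lane 0 is sliced off before normalizing. A minimal sketch of that final step, with the lane width and mask_value default hedged as assumptions:

import jax.numpy as jnp

NUM_LANES = 128                                         # assumed TPU lane width
mask_value = -0.7 * float(jnp.finfo(jnp.float32).max)   # assumed default mask value

def finalize(o_final, l_final, m_final, out_dtype=jnp.bfloat16):
  # o_final: [heads, q_len, head_dim] fp32 unnormalized output.
  # l_final, m_final: [heads, q_len, NUM_LANES] fp32, replicated across lanes.
  l_2d = l_final[..., 0]
  m_2d = m_final[..., 0]
  # Rows with l == 0 were fully masked: avoid dividing by zero and
  # flag them with mask_value in the returned logsumexp.
  l_inv = jnp.where(l_2d == 0.0, 0.0, 1.0 / l_2d)
  out = (o_final * l_inv[..., None]).astype(out_dtype)
  lse = jnp.where(l_2d == 0.0, mask_value, jnp.log(l_2d) + m_2d)
  return out, lse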

src/maxdiffusion/kernels/splash_attention/splash_attention_kernel.py

Lines changed: 46 additions & 6 deletions
@@ -297,6 +297,9 @@ def flash_attention_kernel(
     mask_ref,
     q_sequence_ref,
     max_logit_value_ref,
+    m_init_ref,
+    l_init_ref,
+    o_init_ref,
     # Outputs
     o_ref,
     logsumexp_ref,
@@ -348,15 +351,24 @@ def flash_attention_kernel(

   @pl.when(should_initialize)
   def init():
-    o_scratch_ref[...] = jnp.zeros_like(o_scratch_ref)
-
+
+    if o_init_ref is not None:
+      o_scratch_ref[...] = o_init_ref[...]
+    else:
+      o_scratch_ref[...] = jnp.zeros_like(o_scratch_ref)
     sink = None
     if sinks_ref is not None:
       sink = sinks_ref[0, h].astype(m_scratch_ref.dtype)

     if sinks_ref is None and max_logit_estimate is None:
-      m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)
-      l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)
+      # m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)
+      # l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)
+      if m_init_ref is not None and l_init_ref is not None:
+        m_scratch_ref[...] = m_init_ref[...]
+        l_scratch_ref[...] = l_init_ref[...]
+      else:
+        m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)
+        l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)
     elif sinks_ref is None and max_logit_estimate is not None:
       m_scratch_ref[...] = jnp.full_like(m_scratch_ref, max_logit_estimate)
       l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)
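The @pl.when(should_initialize) block above is the standard Pallas idiom for running code only on selected grid steps (here, seeding the scratch accumulators either from the new *_init refs or from the usual zeros/mask_value). A minimal standalone sketch of that idiom, unrelated to this kernel and using interpret mode so it runs without a TPU:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  i = pl.program_id(0)

  @pl.when(i == 0)          # body runs only on the first grid step
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)

  @pl.when(i != 0)          # body runs on every other step
  def _():
    o_ref[...] = x_ref[...] * 2.0

out = pl.pallas_call(
    kernel,
    grid=(4,),
    in_specs=[pl.BlockSpec((1, 128), lambda i: (i, 0))],
    out_specs=pl.BlockSpec((1, 128), lambda i: (i, 0)),
    out_shape=jax.ShapeDtypeStruct((4, 128), jnp.float32),
    interpret=True,
)(jnp.ones((4, 128), jnp.float32))

print(out[:, 0])  # [0. 2. 2. 2.]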
@@ -680,6 +692,7 @@ def mask_index_map(h, grid_idx, rows_ref, cols_ref, mask_next_ref=None, *_):
   else:
     in_specs.append(None)

+  in_specs += [None, None, None]
   out_shapes = [
       jax.ShapeDtypeStruct((num_q_heads, q_seq_len, head_dim_v), q.dtype),
   ]
@@ -815,6 +828,9 @@ def _fwd_cost_estimate(
       mask_info.partial_mask_blocks,
       q_sequence,
       max_logit_value,
+      None,  # m_init
+      None,  # l_init
+      None,  # o_init
   )
   out, logsumexp, l_linear, max_logits = all_out

@@ -872,6 +888,9 @@ def _splash_attention_forward_ring_raw(
     mask_function: MaskFunctionType | None,
     fwd_mask_sparsity: float,
     max_logit_value: jax.Array | None = None,
+    m_init: jax.Array = None,
+    l_init: jax.Array = None,
+    o_init: jax.Array = None,
 ) -> tuple[jax.Array, dict[str, jax.Array]]:
   """Ring-specific forward path that returns pre-reciprocal fp32 accumulators.

@@ -1039,6 +1058,19 @@ def mask_index_map(h, grid_idx, rows_ref, cols_ref, mask_next_ref=None, *_):
     in_specs.append(None)

   logsumexp_index_map = unravel(lambda h, i, j, *_: (h, i, 0))
+  init = m_init is not None and l_init is not None and o_init is not None
+  if init:
+    m_index_map = unravel(lambda h, i, j: (h, i, 0))
+    l_index_map = unravel(lambda h, i, j: (h, i, 0))
+    out_init_index_map = out_index_map
+    in_specs += [
+        pl.BlockSpec((None, bq, NUM_LANES), m_index_map),
+        pl.BlockSpec((None, bq, NUM_LANES), l_index_map),
+        pl.BlockSpec((None, bq, head_dim_v), out_init_index_map),
+    ]
+  else:
+    in_specs += [None, None, None]
+
   out_shapes = [
       jax.ShapeDtypeStruct((num_q_heads, q_seq_len, head_dim_v), jnp.float32),
       None,
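The new in_specs entries above follow the usual Pallas convention: a None in the block shape drops the head axis from the kernel's view of the operand, and the index map picks which (head, q-block, lane-block) tile each grid step sees (the real code additionally wraps its index maps with unravel for its flattened grid). A minimal standalone sketch of that convention, with hypothetical tile sizes and names, in interpret mode so it runs without a TPU:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

NUM_LANES = 128   # assumed lane width
bq = 256          # assumed query block size

def copy_kernel(m_init_ref, o_ref):
  # The ref is a (bq, NUM_LANES) tile; the head axis was dropped by the
  # None entry in the block shape.
  o_ref[...] = m_init_ref[...]

def tiled_copy(m_init):
  num_heads, q_seq_len, _ = m_init.shape
  spec = pl.BlockSpec((None, bq, NUM_LANES), lambda h, i: (h, i, 0))
  return pl.pallas_call(
      copy_kernel,
      grid=(num_heads, q_seq_len // bq),
      in_specs=[spec],
      out_specs=spec,
      out_shape=jax.ShapeDtypeStruct(m_init.shape, m_init.dtype),
      interpret=True,
  )(m_init)

x = jnp.arange(2 * 512 * NUM_LANES, dtype=jnp.float32).reshape(2, 512, NUM_LANES)
print(jnp.array_equal(tiled_copy(x), x))  # True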
@@ -1143,6 +1175,9 @@ def _fwd_cost_estimate(
       mask_info.partial_mask_blocks,
       q_sequence,
       max_logit_value,
+      m_init,
+      l_init,
+      o_init,
   )
   out_linear, _, l_linear, max_logits = all_out

@@ -1151,11 +1186,16 @@ def init_if_empty(x: jax.Array, value: float) -> jax.Array:
       return x
     return jnp.where(is_empty_attention_block, value, x)

+  # out_linear = init_if_empty(out_linear, 0.0)
+  # assert l_linear is not None
+  # assert max_logits is not None
   out_linear = init_if_empty(out_linear, 0.0)
   assert l_linear is not None
   assert max_logits is not None
-  l_linear = init_if_empty(l_linear[..., 0], 0.0)
-  max_logits = init_if_empty(max_logits[..., 0], mask_value)
+  l_linear = init_if_empty(l_linear, 0.0)
+  max_logits = init_if_empty(max_logits, mask_value)
+  # l_linear = init_if_empty(l_linear[..., 0], 0.0)
+  # max_logits = init_if_empty(max_logits[..., 0], mask_value)

   stats = {"l_linear": l_linear, "max_logits": max_logits}
   stats = jax.tree.map(jax.lax.stop_gradient, stats)
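Side note on the last context line above: wrapping the stats pytree in jax.lax.stop_gradient presumably lets the ring loop consume l_linear and max_logits as plain forward-pass statistics without adding extra gradient paths. A tiny self-contained illustration of that pattern (toy function, not from this repo):

import jax

def f(x):
  stats = {"l_linear": x * 2.0, "max_logits": x + 1.0}
  stats = jax.tree.map(jax.lax.stop_gradient, stats)  # values pass through, gradients do not
  return x ** 2 + stats["l_linear"]

print(jax.grad(f)(3.0))  # 6.0 (only the x**2 term is differentiated), not 8.0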
