@@ -297,6 +297,9 @@ def flash_attention_kernel(
     mask_ref,
     q_sequence_ref,
     max_logit_value_ref,
+    m_init_ref,
+    l_init_ref,
+    o_init_ref,
     # Outputs
     o_ref,
     logsumexp_ref,
@@ -348,15 +351,21 @@ def flash_attention_kernel(
 
   @pl.when(should_initialize)
   def init():
-    o_scratch_ref[...] = jnp.zeros_like(o_scratch_ref)
-
+    if o_init_ref is not None:
+      o_scratch_ref[...] = o_init_ref[...]
+    else:
+      o_scratch_ref[...] = jnp.zeros_like(o_scratch_ref)
     sink = None
     if sinks_ref is not None:
       sink = sinks_ref[0, h].astype(m_scratch_ref.dtype)
 
     if sinks_ref is None and max_logit_estimate is None:
-      m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)
-      l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)
+      if m_init_ref is not None and l_init_ref is not None:
+        m_scratch_ref[...] = m_init_ref[...]
+        l_scratch_ref[...] = l_init_ref[...]
+      else:
+        m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)
+        l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)
     elif sinks_ref is None and max_logit_estimate is not None:
       m_scratch_ref[...] = jnp.full_like(m_scratch_ref, max_logit_estimate)
       l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)
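The new branch makes the kernel resumable: instead of cold-starting the running max `m`, the softmax normalizer `l`, and the output accumulator `o`, it can seed them from a previous partial pass, which is what a ring step needs. Below is a minimal sketch of why that is sound, using plain `jnp` arrays instead of the kernel's blocked refs, and `-inf` where the kernel uses `mask_value`; all names in it are illustrative, not the kernel's.

```python
import jax
import jax.numpy as jnp

def online_softmax_chunk(q, k, v, m, l, o):
    # Fold one K/V chunk into the running accumulators (m, l, o). This is
    # the same update the kernel applies blockwise; (m, l, o) are seeded
    # either from the cold-start constants or from m/l/o_init_ref.
    s = q @ k.T                              # [q_len, chunk_len] logits
    m_new = jnp.maximum(m, s.max(axis=-1))   # updated running max
    alpha = jnp.exp(m - m_new)               # rescale old accumulators
    p = jnp.exp(s - m_new[:, None])
    l_new = l * alpha + p.sum(axis=-1)
    o_new = o * alpha[:, None] + p @ v
    return m_new, l_new, o_new

q = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (32, 16))
v = jax.random.normal(jax.random.PRNGKey(2), (32, 16))

# Cold start on the first half, then resume on the second half as if the
# first half's (m, l, o) had been handed back in via the new init refs.
m0, l0, o0 = jnp.full((8,), -jnp.inf), jnp.zeros((8,)), jnp.zeros((8, 16))
m1, l1, o1 = online_softmax_chunk(q, k[:16], v[:16], m0, l0, o0)
m2, l2, o2 = online_softmax_chunk(q, k[16:], v[16:], m1, l1, o1)

full = jax.nn.softmax(q @ k.T, axis=-1) @ v
assert jnp.allclose(o2 / l2[:, None], full, atol=1e-5)
```

Feeding the first call's `(m, l, o)` back in is the host-level analogue of what `m_init_ref`/`l_init_ref`/`o_init_ref` do inside `init()`.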
@@ -680,6 +689,7 @@ def mask_index_map(h, grid_idx, rows_ref, cols_ref, mask_next_ref=None, *_):
   else:
     in_specs.append(None)
 
+  in_specs += [None, None, None]  # m_init, l_init, o_init: unused on this path
   out_shapes = [
       jax.ShapeDtypeStruct((num_q_heads, q_seq_len, head_dim_v), q.dtype),
   ]
@@ -815,6 +825,9 @@ def _fwd_cost_estimate(
       mask_info.partial_mask_blocks,
       q_sequence,
       max_logit_value,
+      None,  # m_init
+      None,  # l_init
+      None,  # o_init
   )
   out, logsumexp, l_linear, max_logits = all_out
 
@@ -872,6 +885,9 @@ def _splash_attention_forward_ring_raw(
     mask_function: MaskFunctionType | None,
     fwd_mask_sparsity: float,
     max_logit_value: jax.Array | None = None,
+    m_init: jax.Array | None = None,
+    l_init: jax.Array | None = None,
+    o_init: jax.Array | None = None,
 ) -> tuple[jax.Array, dict[str, jax.Array]]:
   """Ring-specific forward path that returns pre-reciprocal fp32 accumulators.
 
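"Pre-reciprocal" here means the numerator and the softmax normalizer come back separately in fp32, so the caller can keep folding in further ring steps and divide exactly once at the end. A hedged sketch of that final step, assuming (from the BlockSpecs below) that `l_linear` replicates the per-row normalizer across `NUM_LANES`, so any single lane works; `finalize` is a hypothetical helper, not part of this API:

```python
import jax
import jax.numpy as jnp

def finalize(out_linear: jax.Array, l_linear: jax.Array, out_dtype) -> jax.Array:
    # out_linear: [heads, q_len, head_dim_v] fp32 numerator
    # l_linear:   [heads, q_len, NUM_LANES] fp32 normalizer, replicated per lane
    # Apply the deferred reciprocal once, after the last ring step.
    return (out_linear / l_linear[..., :1]).astype(out_dtype)
```

Deferring the reciprocal keeps partial results exactly combinable; dividing per step would bake rounding error into every merge.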
@@ -1039,6 +1055,19 @@ def mask_index_map(h, grid_idx, rows_ref, cols_ref, mask_next_ref=None, *_):
     in_specs.append(None)
 
   logsumexp_index_map = unravel(lambda h, i, j, *_: (h, i, 0))
+  init = m_init is not None and l_init is not None and o_init is not None
+  if init:
+    m_index_map = unravel(lambda h, i, j, *_: (h, i, 0))
+    l_index_map = unravel(lambda h, i, j, *_: (h, i, 0))
+    out_init_index_map = out_index_map
+    in_specs += [
+        pl.BlockSpec((None, bq, NUM_LANES), m_index_map),
+        pl.BlockSpec((None, bq, NUM_LANES), l_index_map),
+        pl.BlockSpec((None, bq, head_dim_v), out_init_index_map),
+    ]
+  else:
+    in_specs += [None, None, None]
+
   out_shapes = [
       jax.ShapeDtypeStruct((num_q_heads, q_seq_len, head_dim_v), jnp.float32),
       None,
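The three new `in_specs` mirror the accumulator layouts: the `m`/`l` carries are blocked like logsumexp, the `o` carry like the output. Two Pallas mechanics do the work here: a `None` entry in a `BlockSpec` block shape squeezes that dimension away inside the kernel, and the index map turns grid coordinates into block coordinates. A self-contained toy (independent of splash attention) showing both; `interpret=True` lets it run off-TPU:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_block_kernel(x_ref, o_ref):
    # x_ref sees one (4, 128) tile: the leading None in the BlockSpec
    # squeezed the head dimension away inside the kernel.
    o_ref[...] = x_ref[...]

x = jnp.arange(2 * 8 * 128, dtype=jnp.float32).reshape(2, 8, 128)

y = pl.pallas_call(
    copy_block_kernel,
    grid=(2, 2),  # (head, q_block), like (h, i) in the index maps above
    in_specs=[pl.BlockSpec((None, 4, 128), lambda h, i: (h, i, 0))],
    out_specs=pl.BlockSpec((None, 4, 128), lambda h, i: (h, i, 0)),
    out_shape=jax.ShapeDtypeStruct((2, 8, 128), jnp.float32),
    interpret=True,  # run anywhere, not just TPU
)(x)
assert (y == x).all()
```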
@@ -1143,6 +1172,9 @@ def _fwd_cost_estimate(
       mask_info.partial_mask_blocks,
       q_sequence,
       max_logit_value,
+      m_init,
+      l_init,
+      o_init,
   )
   out_linear, _, l_linear, max_logits = all_out
 
@@ -1151,11 +1183,11 @@ def init_if_empty(x: jax.Array, value: float) -> jax.Array:
       return x
     return jnp.where(is_empty_attention_block, value, x)
 
   out_linear = init_if_empty(out_linear, 0.0)
   assert l_linear is not None
   assert max_logits is not None
-  l_linear = init_if_empty(l_linear[..., 0], 0.0)
-  max_logits = init_if_empty(max_logits[..., 0], mask_value)
+  l_linear = init_if_empty(l_linear, 0.0)
+  max_logits = init_if_empty(max_logits, mask_value)
 
   stats = {"l_linear": l_linear, "max_logits": max_logits}
   stats = jax.tree.map(jax.lax.stop_gradient, stats)
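With `l_linear` and `max_logits` kept at full rank (no more `[..., 0]`) and exposed as stop-gradiented stats, a caller can also combine two partial results pairwise rather than only resuming sequentially. A sketch of that standard numerically stable merge, reduced to per-row vectors for clarity; the names are hypothetical, not this module's:

```python
import jax.numpy as jnp

def merge_partials(o1, l1, m1, o2, l2, m2):
    # Rescale both sides to a shared running max, then add numerators
    # and normalizers; the reciprocal is still deferred.
    m = jnp.maximum(m1, m2)
    a1, a2 = jnp.exp(m1 - m), jnp.exp(m2 - m)
    o = o1 * a1[..., None] + o2 * a2[..., None]
    l = l1 * a1 + l2 * a2
    return o, l, m  # final output is o / l[..., None]
```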