Skip to content

Commit a894502

Browse files
levskaya authored and Google-ML-Automation committed
replace uses of pltpu.repeat with jnp.tile in pallas.
PiperOrigin-RevId: 864939936
1 parent c841dae commit a894502

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

src/MaxText/kernels/splash_attention_kernel.py

Lines changed: 6 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -622,7 +622,7 @@ def _apply_mask_and_soft_cap(
622622

623623
repeats, rem = divmod(k_slice.size, NUM_LANES)
624624
assert rem == 0
625-
q_sequence = pltpu.repeat(q_sequence_ref[...], repeats, axis=1) # [bq, k_slice.size]
625+
q_sequence = jnp.tile(q_sequence_ref[...], (1, repeats)) # [bq, k_slice.size]
626626
else:
627627
assert q_sequence_ref.shape == (NUM_SUBLANES, bq)
628628

@@ -642,13 +642,13 @@ def _apply_mask_and_soft_cap(
642642
repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
643643
if rem:
644644
raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
645-
q_ids = pltpu.repeat(q_segment_ids_ref[:], repeats, axis=1) # [bq, bkv]
645+
q_ids = jnp.tile(q_segment_ids_ref[:], (1, repeats)) # [bq, bkv]
646646
else:
647647
assert bq == q_segment_ids_ref.shape[-1]
648648
repeats, rem = divmod(bq, NUM_LANES)
649649
if rem:
650650
raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
651-
kv_ids = pltpu.repeat(kv_segment_ids_ref[k_slice, :], repeats, axis=1) # [k_slice, bq]
651+
kv_ids = jnp.tile(kv_segment_ids_ref[k_slice, :], (1, repeats)) # [k_slice, bq]
652652
q_ids = q_segment_ids_ref[:1, :] # [1, bq]
653653
masks.append(q_ids == kv_ids)
654654

@@ -771,7 +771,7 @@ def body(kv_compute_index, _):
771771
if rem != 0:
772772
raise NotImplementedError(f"{bkv_compute=} should be a multiple of {NUM_LANES}")
773773

774-
s_curr = jnp.exp(qk - pltpu.repeat(m_next, bkv_repeats, axis=1))
774+
s_curr = jnp.exp(qk - jnp.tile(m_next, (1, bkv_repeats)))
775775
assert s_curr.shape == (bq, bkv_compute)
776776

777777
l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,))
@@ -789,7 +789,7 @@ def body(kv_compute_index, _):
789789
v = v.astype(float32)
790790
o_curr = lax.dot_general(s_curr, v, sv_dims)
791791

792-
alpha_o = pltpu.repeat(alpha, head_dim_v_repeats, axis=1)
792+
alpha_o = jnp.tile(alpha, (1, head_dim_v_repeats))
793793
o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr
794794

795795
@pl.when(should_run)
@@ -801,7 +801,7 @@ def run():
801801
@pl.when(j == grid_width - 1)
802802
def end():
803803
l = l_scratch_ref[...]
804-
l_inv = pltpu.repeat(1.0 / l, head_dim_v_repeats, axis=1)
804+
l_inv = jnp.tile(1.0 / l, (1, head_dim_v_repeats))
805805
o_ref[...] = (o_scratch_ref[...] * l_inv).astype(o_ref.dtype)
806806
if logsumexp_ref is not None:
807807
assert logsumexp_ref.shape == (bq, NUM_LANES)

0 commit comments

Comments (0)