@@ -622,7 +622,7 @@ def _apply_mask_and_soft_cap(
622622
623623 repeats , rem = divmod (k_slice .size , NUM_LANES )
624624 assert rem == 0
625- q_sequence = pltpu . repeat (q_sequence_ref [...], repeats , axis = 1 ) # [bq, k_slice.size]
625+ q_sequence = jnp . tile (q_sequence_ref [...], ( 1 , repeats ) ) # [bq, k_slice.size]
626626 else :
627627 assert q_sequence_ref .shape == (NUM_SUBLANES , bq )
628628
@@ -642,13 +642,13 @@ def _apply_mask_and_soft_cap(
642642 repeats , rem = divmod (kv_ids .shape [1 ], NUM_LANES )
643643 if rem :
644644 raise NotImplementedError (f"block_kv must be a multiple of { NUM_LANES } " )
645- q_ids = pltpu . repeat (q_segment_ids_ref [:], repeats , axis = 1 ) # [bq, bkv]
645+ q_ids = jnp . tile (q_segment_ids_ref [:], ( 1 , repeats ) ) # [bq, bkv]
646646 else :
647647 assert bq == q_segment_ids_ref .shape [- 1 ]
648648 repeats , rem = divmod (bq , NUM_LANES )
649649 if rem :
650650 raise NotImplementedError (f"block_q must be a multiple of { NUM_LANES } " )
651- kv_ids = pltpu . repeat (kv_segment_ids_ref [k_slice , :], repeats , axis = 1 ) # [k_slice, bq]
651+ kv_ids = jnp . tile (kv_segment_ids_ref [k_slice , :], ( 1 , repeats ) ) # [k_slice, bq]
652652 q_ids = q_segment_ids_ref [:1 , :] # [1, bq]
653653 masks .append (q_ids == kv_ids )
654654
@@ -771,7 +771,7 @@ def body(kv_compute_index, _):
771771 if rem != 0 :
772772 raise NotImplementedError (f"{ bkv_compute = } should be a multiple of { NUM_LANES } " )
773773
774- s_curr = jnp .exp (qk - pltpu . repeat (m_next , bkv_repeats , axis = 1 ))
774+ s_curr = jnp .exp (qk - jnp . tile (m_next , ( 1 , bkv_repeats ) ))
775775 assert s_curr .shape == (bq , bkv_compute )
776776
777777 l_curr = jax .lax .broadcast_in_dim (s_curr .sum (axis = - 1 ), l_prev .shape , (0 ,))
@@ -789,7 +789,7 @@ def body(kv_compute_index, _):
789789 v = v .astype (float32 )
790790 o_curr = lax .dot_general (s_curr , v , sv_dims )
791791
792- alpha_o = pltpu . repeat (alpha , head_dim_v_repeats , axis = 1 )
792+ alpha_o = jnp . tile (alpha , ( 1 , head_dim_v_repeats ) )
793793 o_scratch_ref [:] = alpha_o * o_scratch_ref [:] + o_curr
794794
795795 @pl .when (should_run )
@@ -801,7 +801,7 @@ def run():
801801 @pl .when (j == grid_width - 1 )
802802 def end ():
803803 l = l_scratch_ref [...]
804- l_inv = pltpu . repeat (1.0 / l , head_dim_v_repeats , axis = 1 )
804+ l_inv = jnp . tile (1.0 / l , ( 1 , head_dim_v_repeats ) )
805805 o_ref [...] = (o_scratch_ref [...] * l_inv ).astype (o_ref .dtype )
806806 if logsumexp_ref is not None :
807807 assert logsumexp_ref .shape == (bq , NUM_LANES )
0 commit comments