
Commit 05700d5

Move MaxText Kernels
1 parent f938dcb commit 05700d5

14 files changed: 46 additions & 77 deletions

src/MaxText/layers/attention_op.py

Lines changed: 3 additions & 3 deletions
@@ -69,15 +69,15 @@
     Q_LENGTH_NO_EXP,
 )
 
-from MaxText.kernels import jax_flash_attention
-from MaxText.kernels.ragged_attention import ragged_gqa
-from MaxText.kernels.ragged_attention import ragged_mha
 from MaxText.layers import nnx_wrappers
 from MaxText.layers.initializers import variable_to_logically_partitioned
 from MaxText.layers.quantizations import AqtQuantization as Quant
 from MaxText.sharding import logical_to_mesh_axes, maybe_shard_with_name
 from maxtext.inference import page_manager
 from maxtext.inference.kvcache import KVQuant, KVTensor
+from maxtext.kernels.attention import jax_flash_attention
+from maxtext.kernels.attention.ragged_attention import ragged_gqa
+from maxtext.kernels.attention.ragged_attention import ragged_mha
 from maxtext.utils import max_utils
 import numpy as np
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel

src/MaxText/layers/deepseek_batchsplit.py

Lines changed: 2 additions & 2 deletions
@@ -21,8 +21,8 @@
 
 import jax
 import jax.numpy as jnp
-from MaxText.kernels import megablox
-from MaxText.kernels import sort_activations
+from maxtext.kernels import megablox
+from maxtext.kernels import sort_activations
 from MaxText.layers import attention_op
 from MaxText.layers import quantizations
 
src/MaxText/layers/moe.py

Lines changed: 1 addition & 1 deletion
@@ -32,10 +32,10 @@
 from MaxText import common_types as ctypes
 from MaxText.common_types import ShardMode
 from MaxText.sharding import maybe_shard_with_logical, create_sharding
-from MaxText.kernels import megablox as mblx
 from MaxText.sharding import logical_to_mesh_axes
 from MaxText.layers import attentions, linears, nnx_wrappers, quantizations
 from MaxText.layers.initializers import NdInitializer, default_bias_init, nd_dense_init, variable_to_logically_partitioned
+from maxtext.kernels import megablox as mblx
 from maxtext.utils import max_logging
 from maxtext.utils import max_utils
 import numpy as np

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

src/MaxText/kernels/jax_flash_attention.py renamed to src/maxtext/kernels/attention/jax_flash_attention.py

Lines changed: 28 additions & 56 deletions
@@ -1,4 +1,4 @@
-# Copyright 2025 Google LLC
+# Copyright 2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
 
 import jax
 import jax.numpy as jnp
-from MaxText.kernels import splash_attention_kernel
+from maxtext.kernels.attention import splash_attention_kernel
 
 SegmentIds = splash_attention_kernel.SegmentIds
 
@@ -77,38 +77,34 @@ def flash_attention_block_masked(
   v_head_dim_size = v.shape[-1]
   data_type = q.dtype
   q_groups = num_q_heads // num_kv_heads
-  q = q.reshape((
-      batch_size,
-      num_kv_heads,
-      q_groups,
-      q_seq_len,
-      qk_head_dim_size,
-  ))
+  q = q.reshape(
+      (
+          batch_size,
+          num_kv_heads,
+          q_groups,
+          q_seq_len,
+          qk_head_dim_size,
+      )
+  )
 
   # Calculate the number of key/value and query blocks.
   num_kv_blocks = kv_seq_len // block_kv
   num_q_blocks = q_seq_len // block_q
 
   # Before applying the segment mask, we need to broadcast the mask in batch
   # dimension since we have same logic for all batches.
-  mask_full = jnp.broadcast_to(
-      mask[None, :, :], (batch_size, q_seq_len, kv_seq_len)
-  )
+  mask_full = jnp.broadcast_to(mask[None, :, :], (batch_size, q_seq_len, kv_seq_len))
 
   if segment_ids is not None:
     segment_ids_q = segment_ids.q[:, :, None]
     segment_ids_kv = segment_ids.kv[:, None, :]
     mask_full = jnp.logical_and(mask_full, segment_ids_q == segment_ids_kv)
-  mask_blocked = jax.jit(mask_blocker, static_argnums=[1, 2])(
-      mask_full, block_q, block_kv
-  )
+  mask_blocked = jax.jit(mask_blocker, static_argnums=[1, 2])(mask_full, block_q, block_kv)
 
   # Initialize `l` (logsumexp) and `m` (max_logits) for the online softmax.
   # `l` is initialized to 0 since no blocks have been processed yet and the sum
   # is 0.
-  l = jnp.zeros(
-      (batch_size, num_kv_heads, q_groups, q_seq_len), dtype=data_type
-  )
+  l = jnp.zeros((batch_size, num_kv_heads, q_groups, q_seq_len), dtype=data_type)
   # `m` is initialized to the mask_value so that the first block's maximum logit
   # correctly becomes the running maximum.
   m = jnp.full(
@@ -144,15 +140,9 @@ def inner_loop_body(i, carried_inner):
       # Calculates the attention computation (Q@K.T)@V with online softmax for
       # the current query and key/value blocks.
       def compute_attention_block(output, l, m):
-        output_i_slice = jax.lax.dynamic_slice_in_dim(
-            output, i * block_q, block_q, axis=-2
-        )
-        l_i_slice = jax.lax.dynamic_slice_in_dim(
-            l, i * block_q, block_q, axis=-1
-        )
-        m_i_slice = jax.lax.dynamic_slice_in_dim(
-            m, i * block_q, block_q, axis=-1
-        )
+        output_i_slice = jax.lax.dynamic_slice_in_dim(output, i * block_q, block_q, axis=-2)
+        l_i_slice = jax.lax.dynamic_slice_in_dim(l, i * block_q, block_q, axis=-1)
+        m_i_slice = jax.lax.dynamic_slice_in_dim(m, i * block_q, block_q, axis=-1)
         s_i_j = jnp.einsum(
             "bxhqc,bxkc->bxhqk",
             q_slice,
@@ -183,25 +173,19 @@ def compute_attention_block(output, l, m):
         l_i_new = m_i_difference * l_i_slice + m_i_j_difference * l_i_j
 
         divider = l_i_new[..., None]
-        numerator = l_i_slice[..., None] * m_i_difference[
+        numerator = l_i_slice[..., None] * m_i_difference[..., None] * output_i_slice + m_i_j_difference[
             ..., None
-        ] * output_i_slice + m_i_j_difference[..., None] * jnp.einsum(
+        ] * jnp.einsum(
             "bxhqk,bxkc->bxhqc",
             p_i_j,
             v_j_slice,
             preferred_element_type=data_type,
         )
 
         output_i_slice_new = numerator / divider
-        output = jax.lax.dynamic_update_index_in_dim(
-            output, output_i_slice_new, i * block_q, axis=-2
-        )
-        l = jax.lax.dynamic_update_index_in_dim(
-            l, l_i_new, i * block_q, axis=-1
-        )
-        m = jax.lax.dynamic_update_index_in_dim(
-            m, m_i_new, i * block_q, axis=-1
-        )
+        output = jax.lax.dynamic_update_index_in_dim(output, output_i_slice_new, i * block_q, axis=-2)
+        l = jax.lax.dynamic_update_index_in_dim(l, l_i_new, i * block_q, axis=-1)
+        m = jax.lax.dynamic_update_index_in_dim(m, m_i_new, i * block_q, axis=-1)
         return output, l, m
 
       def identity(output, l, m):
@@ -210,9 +194,7 @@ def identity(output, l, m):
         return output, l, m
 
       batch_size = mask_blocked.shape[0]
-      mask_i_j_slice = jax.lax.dynamic_slice(
-          mask_blocked, (0, i, j), (batch_size, 1, 1)
-      )
+      mask_i_j_slice = jax.lax.dynamic_slice(mask_blocked, (0, i, j), (batch_size, 1, 1))
       # The compute_attention_block should be executed if at least one element
       # in the slice is non-zero, meaning at least one batch requires work for
       # this block.
@@ -227,15 +209,11 @@ def identity(output, l, m):
 
       return output, l, m
 
-    output, l, m = jax.lax.fori_loop(
-        0, num_q_blocks, inner_loop_body, (output, l, m), unroll=True
-    )
+    output, l, m = jax.lax.fori_loop(0, num_q_blocks, inner_loop_body, (output, l, m), unroll=True)
 
     return (output, l, m)
 
-  output, l, m = jax.lax.fori_loop(
-      0, num_kv_blocks, outer_loop_body, (output, l, m), unroll=True
-  )
+  output, l, m = jax.lax.fori_loop(0, num_kv_blocks, outer_loop_body, (output, l, m), unroll=True)
 
   # Reshape the output to drop the size one dimension at index 2,
   # which corresponds to `num_q_heads // num_kv_heads` when
@@ -268,17 +246,11 @@ def mask_blocker(mask: jnp.ndarray, block_q: int, block_kv: int) -> jnp.ndarray:
   batch_size, q_seq_len, kv_seq_len = mask.shape
 
   if q_seq_len % block_q != 0:
-    raise ValueError(
-        f"q_seq_len {q_seq_len} must be divisible by block_q {block_q}"
-    )
+    raise ValueError(f"q_seq_len {q_seq_len} must be divisible by block_q {block_q}")
   if kv_seq_len % block_kv != 0:
-    raise ValueError(
-        f"kv_seq_len {kv_seq_len} must be divisible by block_kv {block_kv}"
-    )
+    raise ValueError(f"kv_seq_len {kv_seq_len} must be divisible by block_kv {block_kv}")
   q_blocks = q_seq_len // block_q
   kv_blocks = kv_seq_len // block_kv
 
-  blocked_mask = mask.reshape(
-      batch_size, q_blocks, block_q, kv_blocks, block_kv
-  )
+  blocked_mask = mask.reshape(batch_size, q_blocks, block_q, kv_blocks, block_kv)
   return jnp.count_nonzero(blocked_mask, axis=(2, 4)).astype(jnp.int32)
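
Note (not part of the diff): the reformatted lines in compute_attention_block implement the standard online-softmax update used by flash attention, where the running max `m` and running sum `l` are rescaled by exp(old_max - new_max) each time a new key/value block is folded in. Below is a minimal, self-contained sketch of that per-block update with illustrative names and single-query-block, 1-D statistics; it is not code from this repository.

# Sketch of the online-softmax step mirrored by compute_attention_block above.
# All names here (combine_block, s_block, v_block, etc.) are illustrative.
import jax
import jax.numpy as jnp

def combine_block(output, l, m, s_block, v_block):
  """Fold scores s_block (q, k) and values v_block (k, d) into the running
  normalized output (q, d), running sum l (q,) and running max m (q,)."""
  m_block = s_block.max(axis=-1)                     # per-row max of this block
  p = jnp.exp(s_block - m_block[..., None])          # unnormalized block probabilities
  l_block = p.sum(axis=-1)
  m_new = jnp.maximum(m, m_block)                    # new running max
  alpha = jnp.exp(m - m_new)                         # rescales what was already accumulated
  beta = jnp.exp(m_block - m_new)                    # rescales the current block
  l_new = alpha * l + beta * l_block
  numerator = (alpha * l)[..., None] * output + beta[..., None] * (p @ v_block)
  return numerator / l_new[..., None], l_new, m_new

# Processing key/value blocks one at a time reproduces a full softmax:
q_len, kv_len, head_dim, block_kv = 4, 8, 16, 2
s = jax.random.normal(jax.random.PRNGKey(0), (q_len, kv_len))
v = jax.random.normal(jax.random.PRNGKey(1), (kv_len, head_dim))
out = jnp.zeros((q_len, head_dim))
l = jnp.zeros((q_len,))
m = jnp.full((q_len,), -1e9)                         # plays the role of mask_value
for j in range(0, kv_len, block_kv):
  out, l, m = combine_block(out, l, m, s[:, j:j + block_kv], v[j:j + block_kv])
# `out` now matches jax.nn.softmax(s, axis=-1) @ v up to floating-point error.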

src/MaxText/kernels/ragged_attention.py renamed to src/maxtext/kernels/attention/ragged_attention.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

src/MaxText/kernels/splash_attention_kernel.py renamed to src/maxtext/kernels/attention/splash_attention_kernel.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 # pylint: skip-file
 from __future__ import annotations
 
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,4 +13,4 @@
 # limitations under the License.
 """Megablox kernel"""
 
-from MaxText.kernels.megablox.ops import gmm
+from maxtext.kernels.megablox.ops import gmm
Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
