Skip to content

Commit 525b399

Browse files
Merge pull request #3032 from AI-Hypercomputer:move-maxtext-kernels
PiperOrigin-RevId: 865669628
2 parents 494b3b0 + 05700d5 commit 525b399

14 files changed

Lines changed: 169 additions & 80 deletions

src/MaxText/layers/attention_op.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -68,16 +68,15 @@
6868
Q_LENGTH,
6969
Q_LENGTH_NO_EXP,
7070
)
71-
72-
from MaxText.kernels import jax_flash_attention
73-
from MaxText.kernels.ragged_attention import ragged_gqa
74-
from MaxText.kernels.ragged_attention import ragged_mha
71+
from maxtext.inference import page_manager
72+
from maxtext.inference.kvcache import KVQuant, KVTensor
73+
from maxtext.kernels.attention import jax_flash_attention
74+
from maxtext.kernels.attention.ragged_attention import ragged_gqa
75+
from maxtext.kernels.attention.ragged_attention import ragged_mha
7576
from MaxText.layers import nnx_wrappers
7677
from MaxText.layers.initializers import variable_to_logically_partitioned
7778
from MaxText.layers.quantizations import AqtQuantization as Quant
7879
from MaxText.sharding import logical_to_mesh_axes, maybe_shard_with_name
79-
from maxtext.inference import page_manager
80-
from maxtext.inference.kvcache import KVQuant, KVTensor
8180
from maxtext.utils import max_utils
8281
import numpy as np
8382
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel

src/MaxText/layers/deepseek_batchsplit.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,8 @@
2121

2222
import jax
2323
import jax.numpy as jnp
24-
from MaxText.kernels import megablox
25-
from MaxText.kernels import sort_activations
24+
from maxtext.kernels import megablox
25+
from maxtext.kernels import sort_activations
2626
from MaxText.layers import attention_op
2727
from MaxText.layers import quantizations
2828

src/MaxText/layers/moe.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@
3232
from MaxText import common_types as ctypes
3333
from MaxText.common_types import ShardMode
3434
from MaxText.sharding import maybe_shard_with_logical, create_sharding
35-
from MaxText.kernels import megablox as mblx
35+
from maxtext.kernels import megablox as mblx
3636
from MaxText.sharding import logical_to_mesh_axes
3737
from MaxText.layers import attentions, linears, nnx_wrappers, quantizations
3838
from MaxText.layers.initializers import NdInitializer, default_bias_init, nd_dense_init, variable_to_logically_partitioned
Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

src/MaxText/kernels/jax_flash_attention.py renamed to src/maxtext/kernels/attention/jax_flash_attention.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Google LLC
1+
# Copyright 2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
1717

1818
import jax
1919
import jax.numpy as jnp
20-
from MaxText.kernels import splash_attention_kernel
20+
from maxtext.kernels.attention import splash_attention_kernel
2121

2222
SegmentIds = splash_attention_kernel.SegmentIds
2323

src/MaxText/kernels/ragged_attention.py renamed to src/maxtext/kernels/attention/ragged_attention.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

src/MaxText/kernels/splash_attention_kernel.py renamed to src/maxtext/kernels/attention/splash_attention_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
# pylint: skip-file
22
from __future__ import annotations
33

4-
# Copyright 2023–2025 Google LLC
4+
# Copyright 2023–2026 Google LLC
55
#
66
# Licensed under the Apache License, Version 2.0 (the "License");
77
# you may not use this file except in compliance with the License.
Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -13,4 +13,4 @@
1313
# limitations under the License.
1414
"""Megablox kernel"""
1515

16-
from MaxText.kernels.megablox.ops import gmm
16+
from maxtext.kernels.megablox.ops import gmm
Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

0 commit comments

Comments (0)