Skip to content

Commit 352dd58

Browse files
Merge pull request #3093 from AI-Hypercomputer:aireen/fix_ungroup
PiperOrigin-RevId: 866061982
2 parents 863d2a3 + dadc275 commit 352dd58

2 files changed

Lines changed: 5 additions & 5 deletions

File tree

src/MaxText/layers/attention_op.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,15 @@
6868
Q_LENGTH,
6969
Q_LENGTH_NO_EXP,
7070
)
71+
from MaxText.layers import nnx_wrappers
72+
from MaxText.layers.initializers import variable_to_logically_partitioned
73+
from MaxText.layers.quantizations import AqtQuantization as Quant
74+
from MaxText.sharding import logical_to_mesh_axes, maybe_shard_with_name
7175
from maxtext.inference import page_manager
7276
from maxtext.inference.kvcache import KVQuant, KVTensor
7377
from maxtext.kernels.attention import jax_flash_attention
7478
from maxtext.kernels.attention.ragged_attention import ragged_gqa
7579
from maxtext.kernels.attention.ragged_attention import ragged_mha
76-
from MaxText.layers import nnx_wrappers
77-
from MaxText.layers.initializers import variable_to_logically_partitioned
78-
from MaxText.layers.quantizations import AqtQuantization as Quant
79-
from MaxText.sharding import logical_to_mesh_axes, maybe_shard_with_name
8080
from maxtext.utils import max_utils
8181
import numpy as np
8282
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel

src/MaxText/layers/moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232
from MaxText import common_types as ctypes
3333
from MaxText.common_types import ShardMode
3434
from MaxText.sharding import maybe_shard_with_logical, create_sharding
35-
from maxtext.kernels import megablox as mblx
3635
from MaxText.sharding import logical_to_mesh_axes
3736
from MaxText.layers import attentions, linears, nnx_wrappers, quantizations
3837
from MaxText.layers.initializers import NdInitializer, default_bias_init, nd_dense_init, variable_to_logically_partitioned
38+
from maxtext.kernels import megablox as mblx
3939
from maxtext.utils import max_logging
4040
from maxtext.utils import max_utils
4141
import numpy as np

0 commit comments

Comments
 (0)