
Commit 77cfb55

Merge pull request #3073 from AI-Hypercomputer:mohit/fix_lint
PiperOrigin-RevId: 865048339
2 parents 1a44692 + 96d181e commit 77cfb55

4 files changed

Lines changed: 26 additions & 65 deletions


src/MaxText/kernels/sort_activations.py

Lines changed: 5 additions & 12 deletions
@@ -73,9 +73,7 @@ def _unroute_fwd(
  )


-def _unroute_bwd(
-    use_custom_mosaic_kernel: bool, residuals: jax.Array, grads: jax.Array
-) -> tuple[jax.Array, None]:
+def _unroute_bwd(use_custom_mosaic_kernel: bool, residuals: jax.Array, grads: jax.Array) -> tuple[jax.Array, None]:
  selected_experts = residuals
  return _route_impl(grads, selected_experts, use_custom_mosaic_kernel), None

@@ -90,8 +88,7 @@ def _route_impl(
) -> jax.Array:
  """Gather `tokens` according to `selected_experts`."""
  assert (
-      tokens.shape[0] == selected_experts.shape[0]
-      and selected_experts.ndim == 2
+      tokens.shape[0] == selected_experts.shape[0] and selected_experts.ndim == 2
  ), f"{tokens.shape=}, {selected_experts.shape=}"
  if use_custom_mosaic_kernel:
    raise NotImplementedError("Custom Mosaic kernel not implemented.")
@@ -104,10 +101,8 @@ def _unroute_impl(
    selected_experts: jax.Array,
    use_custom_mosaic_kernel: bool,
) -> jax.Array:
-  assert (
-      tokens.shape[0] == selected_experts.shape[0] * selected_experts.shape[1]
-      and selected_experts.ndim == 2
-  )
+  """Reverse the routing operation, restoring tokens to their original order."""
+  assert tokens.shape[0] == selected_experts.shape[0] * selected_experts.shape[1] and selected_experts.ndim == 2
  inds = jnp.argsort(jnp.argsort(jnp.ravel(selected_experts)))
  return jnp.sum(
      jnp.reshape(
@@ -118,9 +113,7 @@ def _unroute_impl(
  )


-def _sort_impl(
-    tokens: jax.Array, inds: jax.Array, use_custom_mosaic_kernel: bool
-) -> jax.Array:
+def _sort_impl(tokens: jax.Array, inds: jax.Array, use_custom_mosaic_kernel: bool) -> jax.Array:
  if use_custom_mosaic_kernel:
    raise NotImplementedError("Custom Mosaic kernel not implemented.")
  else:
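Note: the reformatted `_unroute_impl` keeps the double-argsort trick, `jnp.argsort(jnp.argsort(...))`, which yields the inverse of the sorting permutation used when routing. A minimal sketch of why this restores the original token order (the expert-assignment values below are hypothetical):

import jax.numpy as jnp

experts = jnp.array([3, 1, 2, 0])            # hypothetical per-token expert ids, ravelled
order = jnp.argsort(experts)                 # permutation that groups tokens by expert
routed = experts[order]                      # values in expert-sorted order
inv = jnp.argsort(order)                     # double argsort = inverse permutation
print(bool((routed[inv] == experts).all()))  # True: original order restored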

src/MaxText/layers/deepseek.py

Lines changed: 1 addition & 2 deletions
@@ -24,7 +24,6 @@
from jax.sharding import Mesh
from MaxText.common_types import Config
from MaxText.common_types import MODEL_MODE_PREFILL
-from maxtext.inference import page_manager
from MaxText.layers import attention_mla
from MaxText.layers import deepseek_batchsplit
from MaxText.layers import initializers
@@ -37,7 +36,7 @@
from MaxText.sharding import create_sharding
from MaxText.sharding import maybe_shard_with_logical
from maxtext.utils import max_utils
-
+from maxtext.inference import page_manager

# -----------------------------------------
# The Decoder Layer for DeepSeek v3

src/MaxText/layers/deepseek_batchsplit.py

Lines changed: 18 additions & 49 deletions
@@ -58,15 +58,9 @@ def fetch_weights(params, dtype):
                params["DeepSeekMoeBlock_0"]["MoeBlock_0"]["wo"],
            ),
            (
-                params["DeepSeekMoeBlock_0"]["shared_experts"]["wi_0"][
-                    "kernel"
-                ],
-                params["DeepSeekMoeBlock_0"]["shared_experts"]["wi_1"][
-                    "kernel"
-                ],
-                params["DeepSeekMoeBlock_0"]["shared_experts"]["wo"][
-                    "kernel"
-                ],
+                params["DeepSeekMoeBlock_0"]["shared_experts"]["wi_0"]["kernel"],
+                params["DeepSeekMoeBlock_0"]["shared_experts"]["wi_1"]["kernel"],
+                params["DeepSeekMoeBlock_0"]["shared_experts"]["wo"]["kernel"],
            ),
        ),
    ),
@@ -201,11 +195,11 @@ def batch_split_schedule(


def staggered_call(fn, xs):
-  for i in range(len(xs)):
+  for i, x in enumerate(xs):
    if i == len(xs) - 1:
-      xs[i] = fn(xs[i])
+      xs[i] = fn(x)
    else:
-      xs[i], xs[i + 1] = jax.lax.optimization_barrier((fn(xs[i]), xs[i + 1]))
+      xs[i], xs[i + 1] = jax.lax.optimization_barrier((fn(x), xs[i + 1]))
  return xs


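Note: `staggered_call` is behaviorally unchanged; the loop now binds each element via `enumerate`. A minimal usage sketch under assumed inputs (the splits and `per_split_fn` are hypothetical); the optimization barrier keeps XLA from interleaving one split's work with the next split's input:

import jax
import jax.numpy as jnp

def staggered_call(fn, xs):  # mirrors the function above
  for i, x in enumerate(xs):
    if i == len(xs) - 1:
      xs[i] = fn(x)
    else:
      # Tie fn(x) to the next split's input so XLA cannot reorder the per-split work.
      xs[i], xs[i + 1] = jax.lax.optimization_barrier((fn(x), xs[i + 1]))
  return xs

per_split_fn = lambda t: jnp.tanh(t) * 2.0     # stand-in for the per-batch-split computation
splits = [jnp.ones((2, 4)), jnp.zeros((2, 4))]
print([o.shape for o in staggered_call(per_split_fn, splits)])  # [(2, 4), (2, 4)]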
@@ -215,9 +209,7 @@ def with_data_parallel_constraint(x, mesh):
      None,
      None,
  )
-  return jax.lax.with_sharding_constraint(
-      x, jax.NamedSharding(mesh, activation_pspec)
-  )
+  return jax.lax.with_sharding_constraint(x, jax.NamedSharding(mesh, activation_pspec))


def dot(x, y, axes=1):
@@ -290,9 +282,7 @@ def fn(args):
        dtype=dtype,
    )

-  return staggered_call(
-      fn, list(zip(inputs, decoder_segment_ids, decoder_positions))
-  )
+  return staggered_call(fn, list(zip(inputs, decoder_segment_ids, decoder_positions)))


def mla(
@@ -484,9 +474,7 @@ def kv_projection(
  )


-def get_key_value(
-    low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads
-):
+def get_key_value(low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads):
  """Gets key and value from compressed KV latent vector and key rope."""
  kv_out = dot(low_rank_main, wkv_b_weights)

@@ -541,20 +529,13 @@ def yarn(
  half_dim = embedding_dims // 2
  # Compute base frequencies for each (even-indexed) dimension.
  # (Note: We use jnp.arange with float32 for precision.)
-  freqs = 1.0 / (
-      rope_theta
-      ** (2.0 * jnp.arange(0, half_dim, dtype=jnp.float32) / embedding_dims)
-  )
+  freqs = 1.0 / (rope_theta ** (2.0 * jnp.arange(0, half_dim, dtype=jnp.float32) / embedding_dims))

  low = (
-      embedding_dims
-      * math.log(original_max_position_embeddings / (beta_fast * 2 * math.pi))
-      / (2 * math.log(rope_theta))
+      embedding_dims * math.log(original_max_position_embeddings / (beta_fast * 2 * math.pi)) / (2 * math.log(rope_theta))
  )
  high = (
-      embedding_dims
-      * math.log(original_max_position_embeddings / (beta_slow * 2 * math.pi))
-      / (2 * math.log(rope_theta))
+      embedding_dims * math.log(original_max_position_embeddings / (beta_slow * 2 * math.pi)) / (2 * math.log(rope_theta))
  )
  low = max(math.floor(low), 0)
  high = min(math.ceil(high), embedding_dims - 1)
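Note: the `low`/`high` expressions above changed only in formatting. As a worked example with hypothetical YaRN settings (embedding_dims=64, rope_theta=10000, original_max_position_embeddings=4096, beta_fast=32, beta_slow=1), the ramp bounds come out to roughly dimensions 10 through 23:

import math

embedding_dims, rope_theta = 64, 10000.0
original_max_position_embeddings, beta_fast, beta_slow = 4096, 32, 1

low = embedding_dims * math.log(original_max_position_embeddings / (beta_fast * 2 * math.pi)) / (2 * math.log(rope_theta))
high = embedding_dims * math.log(original_max_position_embeddings / (beta_slow * 2 * math.pi)) / (2 * math.log(rope_theta))
print(max(math.floor(low), 0), min(math.ceil(high), embedding_dims - 1))  # 10 23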
@@ -565,9 +546,7 @@ def yarn(
  freqs = freqs / rope_factor * (1 - smooth) + freqs * smooth

  # Precompute frequencies for all positions by taking the outer product.
-  t = jnp.arange(
-      max_position_embeddings, dtype=jnp.float32
-  )  # shape [max_position_embeddings]
+  t = jnp.arange(max_position_embeddings, dtype=jnp.float32)  # shape [max_position_embeddings]
  # This gives a [max_position_embeddings, half_dim] tensor with rows as time steps.
  freqs = jnp.outer(t, freqs)

@@ -578,9 +557,7 @@ def yarn(
  freqs = freqs[:, :, jnp.newaxis, :]  # shape: [B, S, 1, half_dim]
  freqs = jnp.repeat(freqs, 2, axis=-1)  # shape: [B, S, 1, embedding_dims]
  # inputs @ mask: [B, S, N, embedding_dims] @ [embedding_dims, embedding_dims] -> [B, S, N, embedding_dims]
-  output = inputs * jnp.cos(freqs) + jnp.matmul(
-      inputs, pairwise_swap_and_negate_mask
-  ) * jnp.sin(freqs)
+  output = inputs * jnp.cos(freqs) + jnp.matmul(inputs, pairwise_swap_and_negate_mask) * jnp.sin(freqs)
  return output.astype(fprop_dtype)


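Note: `pairwise_swap_and_negate_mask` is not shown in this diff. The sketch below assumes it is the matrix that sends each adjacent pair (x0, x1) to (-x1, x0); combined with the cos/sin factors above this is the usual RoPE rotation:

import jax.numpy as jnp

def pairwise_swap_and_negate(dim):
  # Assumed construction: x @ m maps (x0, x1, x2, x3, ...) to (-x1, x0, -x3, x2, ...).
  m = jnp.zeros((dim, dim))
  for i in range(0, dim, 2):
    m = m.at[i + 1, i].set(-1.0)  # -x_{i+1} lands in slot i
    m = m.at[i, i + 1].set(1.0)   #  x_i     lands in slot i+1
  return m

x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(x @ pairwise_swap_and_negate(4))  # [-2.  1. -4.  3.]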
@@ -671,9 +648,7 @@ def route(
  # Communicate local results across the expert axis.
  x = jax.lax.all_gather(x, axis_name=expert_axis_name, tiled=True)
  weights = jax.lax.all_gather(weights, axis_name=expert_axis_name, tiled=True)
-  selected_experts = jax.lax.all_gather(
-      selected_experts, axis_name=expert_axis_name, tiled=True
-  )
+  selected_experts = jax.lax.all_gather(selected_experts, axis_name=expert_axis_name, tiled=True)
  group_sizes = jax.lax.psum(group_sizes, axis_name=expert_axis_name)

  # Sort the gathered tokens and weights.
@@ -703,14 +678,10 @@ def unroute(
  )

  # Sum across expert shards.
-  return jax.lax.psum_scatter(
-      x, expert_axis_name, scatter_dimension=0, tiled=True
-  )
+  return jax.lax.psum_scatter(x, expert_axis_name, scatter_dimension=0, tiled=True)


-def compute(
-    x, w0, w1, wo, group_sizes, weights, *, wi_tile_size, wo_tile_size, dtype
-):
+def compute(x, w0, w1, wo, group_sizes, weights, *, wi_tile_size, wo_tile_size, dtype):
  """Processes routed tokens through the MLP."""
  gmm_fn = functools.partial(
      megablox.gmm,
@@ -747,9 +718,7 @@ def route_compute_unroute(

  def route_fn(inputs):
    # Shared expert.
-    y = dot(
-        jax.nn.silu(dot(inputs, shared_w0)) * dot(inputs, shared_w1), shared_wo
-    )
+    y = dot(jax.nn.silu(dot(inputs, shared_w0)) * dot(inputs, shared_w1), shared_wo)

    inputs = jnp.reshape(inputs, (-1, inputs.shape[-1]))
    selected_experts, weights, group_sizes = expert_selection(
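Note: the collapsed `y = dot(...)` line is the shared-expert MLP, a SwiGLU-style gated feed-forward. A minimal sketch with hypothetical shapes (d_model=8, d_ff=16); the weight names mirror the ones used above:

import jax
import jax.numpy as jnp

k0, k1, k2, kx = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (4, 8))            # [tokens, d_model]
shared_w0 = jax.random.normal(k0, (8, 16))   # gate projection
shared_w1 = jax.random.normal(k1, (8, 16))   # up projection
shared_wo = jax.random.normal(k2, (16, 8))   # down projection

# Same computation as route_fn's one-liner: silu(x @ w0) * (x @ w1), projected back down by wo.
y = jnp.dot(jax.nn.silu(jnp.dot(x, shared_w0)) * jnp.dot(x, shared_w1), shared_wo)
print(y.shape)  # (4, 8)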

src/MaxText/layers/moe.py

Lines changed: 2 additions & 2 deletions
@@ -951,9 +951,9 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
      # fusion (max reduce over contracting dimension).
      tiling = (tiling[0], k, tiling[2])

-      is_tpu = (self.mesh.devices.flat[0] == "tpu")
+      is_tpu = self.mesh.devices.flat[0] == "tpu"
      # TPU needs random mosaic_fusion_group; GPU/CPU needs deterministic ID for autotuner sync
-      mosaic_group_id = f"{random.randint(0, 1000000000)}" if is_tpu else '0'
+      mosaic_group_id = f"{random.randint(0, 1000000000)}" if is_tpu else "0"
      with set_xla_metadata(
          ragged_dot_tiling=",".join([str(t) for t in tiling]),
          mosaic_fusion_group=mosaic_group_id,
