
Commit 58038ec

Author: Elisa Tsai (committed)
Commit message: ruff fix
1 parent 6cbbf28 · commit 58038ec

5 files changed: 48 additions & 45 deletions


src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py

Lines changed: 3 additions & 2 deletions
@@ -544,7 +544,7 @@ class RingSplashAttentionKernel:
   """Implements Ring Attention using SplashAttention for sequence parallelism.
 
   This kernel computes global attention by keeping Keys and Values distributed
-  across the `ring_axis`. Instead of gathering full sequences, it rotates K/V
+  across the `ring_axis`. Instead of gathering full sequences, it rotates K/V
   shards between devices and accumulates results incrementally. This allows
   processing sequence lengths that exceed single-device memory limits.
 
@@ -590,7 +590,8 @@ def manual_sharding_spec(self):
     """
 
     spec = jax.sharding.PartitionSpec(self.ring_axis)
-    _resolve_spec = lambda x: spec if x is not None else None
+    def _resolve_spec(x):
+      return spec if x is not None else None
 
     mask_info_specs = MaskInfo(  # pytype: disable=wrong-arg-types
         mask_next=_resolve_spec(self.fwd_mask_info.mask_next),
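
The hunks above bind what used to be a named lambda to a small def instead, the rewrite ruff's lambda-assignment rule (presumably E731) asks for. A minimal sketch of the equivalence, using a stand-in value for `spec` rather than the kernel's real `jax.sharding.PartitionSpec`:

    spec = "ring"  # stand-in value for illustration only

    # Before: a lambda assigned to a name (what ruff flags).
    _resolve_spec = lambda x: spec if x is not None else None  # noqa: E731

    # After: the equivalent named function, as applied in this commit.
    def _resolve_spec(x):
      return spec if x is not None else None

    assert _resolve_spec("kv_shard") == "ring" and _resolve_spec(None) is None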

src/maxdiffusion/kernels/splash_attention/splash_attention_kernel.py

Lines changed: 24 additions & 21 deletions
@@ -196,7 +196,8 @@ def get_default(cls):
     )
 
 
-to_i32 = lambda x: x.astype(jnp.int32)
+def to_i32(x):
+  return x.astype(jnp.int32)
 
 
 def _apply_mask_and_soft_cap(
@@ -1471,7 +1472,8 @@ def mask_index_map(h, grid_idx, rows_ref, cols_ref, mask_next_ref=None, *_):
       return next_m, 0, 0
 
   else:
-    unravel = lambda f: lambda j, h, i, *_: f(h, i, j)
+    def unravel(f):
+      return lambda j, h, i, *_: f(h, i, j)
     grid = (kv_steps, num_q_heads, q_steps)
 
     def mask_index_map(j, h, i, rows_ref, cols_ref, mask_next_ref=None, *_):
@@ -1656,15 +1658,15 @@ def create_dkv_index_map(h, i, j, *_):
   )
   metadata = {
       "xprof_metadata": json.dumps(
-          dict(
-              block_q_dkv=bq,
-              block_kv_dkv=bkv,
-              block_kv_dkv_compute=bkv_compute,
-              q_layout=config.q_layout,
-              k_layout=config.k_layout,
-              v_layout=config.v_layout,
-              use_experimental_scheduler=config.use_experimental_scheduler,
-          ),
+          {
+              "block_q_dkv": bq,
+              "block_kv_dkv": bkv,
+              "block_kv_dkv_compute": bkv_compute,
+              "q_layout": config.q_layout,
+              "k_layout": config.k_layout,
+              "v_layout": config.v_layout,
+              "use_experimental_scheduler": config.use_experimental_scheduler,
+          },
       )
   }
   args = [
@@ -1970,7 +1972,8 @@ def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding):
     if len(sharding.spec) != 1:
      raise ValueError("Only q sequence sharding is supported.")
 
-    _resolve_spec = lambda x: sharding.spec if x is not None else None
+    def _resolve_spec(x):
+      return sharding.spec if x is not None else None
     mask_info_specs = MaskInfo(  # pytype: disable=wrong-arg-types
         mask_next=_resolve_spec(self.fwd_mask_info.mask_next),
         active_rows=_resolve_spec(self.fwd_mask_info.active_rows),
@@ -2115,15 +2118,15 @@ def process_mask_shard(mask):
 
      return fwd_mask_info, dkv_mask_info
 
-    kwargs = dict(
-        config=config,
-        is_mqa=is_mqa,
-        save_residuals=save_residuals,
-        mask_value=mask_value,
-        mask_function=None,
-        fwd_mask_sparsity=1.0,
-        dkv_mask_sparsity=1.0,
-    )
+    kwargs = {
+        "config": config,
+        "is_mqa": is_mqa,
+        "save_residuals": save_residuals,
+        "mask_value": mask_value,
+        "mask_function": None,
+        "fwd_mask_sparsity": 1.0,
+        "dkv_mask_sparsity": 1.0,
+    }
 
     # If the input mask is replicated we don't need to call shard_map.
     if mask_spec is None:
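
The metadata and kwargs hunks above swap `dict(...)` keyword calls for dict literals, the rewrite suggested by ruff's unnecessary-dict-call rule (C408, from flake8-comprehensions). Both forms build the same mapping; a minimal standalone sketch:

    # dict() call with keyword arguments versus the equivalent literal.
    kwargs_call = dict(atol=1e-3, rtol=3e-3)
    kwargs_literal = {"atol": 1e-3, "rtol": 3e-3}
    assert kwargs_call == kwargs_literal

The literal spells keys out as strings and avoids a name lookup and call, while the keyword form is limited to keys that are valid identifiers; behavior is otherwise identical.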

src/maxdiffusion/kernels/splash_attention/splash_attention_kernel_test.py

Lines changed: 9 additions & 9 deletions
@@ -279,7 +279,7 @@ def block_sizes_strategy(
   q_layout = draw(hps.sampled_from(splash.QKVLayout))
   k_layout = draw(hps.sampled_from(splash.QKVLayout))
   v_layout = draw(hps.sampled_from(splash.QKVLayout))
-  layouts = dict(q_layout=q_layout, k_layout=k_layout, v_layout=v_layout)
+  layouts = {"q_layout": q_layout, "k_layout": k_layout, "v_layout": v_layout}
   q_valid_block_shapes = [bs for bs in all_block_shapes if bs <= q_seq_len]
   kv_valid_block_shapes = [bs for bs in all_block_shapes if bs <= kv_seq_len]
   bq, bkv = (
@@ -494,16 +494,16 @@ def test_splash_attention_fwd(self, is_mqa, is_segmented, is_dynamic_mask,
         sinks,
     )
 
-    lse_tol = dict(atol=1e-3, rtol=3e-3)
-    max_logits_tol = dict(atol=1e-3, rtol=4e-3)
+    lse_tol = {"atol": 1e-3, "rtol": 3e-3}
+    max_logits_tol = {"atol": 1e-3, "rtol": 4e-3}
     if use_sinks:
-      o_tol = dict(atol=8e-2, rtol=1e-1)
+      o_tol = {"atol": 8e-2, "rtol": 1e-1}
      lse_tol['rtol'] = 6e-2
     elif (use_base2_exp or use_max_logit_estimate is not None
           or not fuse_reciprocal):
-      o_tol = dict(atol=8e-3, rtol=3e-3)
+      o_tol = {"atol": 8e-3, "rtol": 3e-3}
     else:
-      o_tol = dict(atol=4e-3, rtol=3e-3)
+      o_tol = {"atol": 4e-3, "rtol": 3e-3}
 
     self._assert_allclose(o, o_ref, **o_tol)
     self._assert_allclose(stats["logsumexp"],
@@ -598,12 +598,12 @@ def test_splash_attention_bwd(
         attn_logits_soft_cap=attn_logits_soft_cap,
     )
     if use_sinks:
-      o_tol = dict(atol=1e-2, rtol=1e-1)
+      o_tol = {"atol": 1e-2, "rtol": 1e-1}
     elif (use_base2_exp or use_max_logit_estimate is not None
           or not fuse_reciprocal):
-      o_tol = dict(atol=8e-3, rtol=1e-2)
+      o_tol = {"atol": 8e-3, "rtol": 1e-2}
     else:
-      o_tol = dict(atol=4e-3, rtol=3e-3)
+      o_tol = {"atol": 4e-3, "rtol": 3e-3}
     self._assert_allclose(o, o_ref, **o_tol)
 
     dq, dk, dv, _, dsinks = attn_vjp(do)
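
For context, the tolerance dicts built above are unpacked as keyword arguments into the comparison helper. A hypothetical usage sketch, with NumPy's assert_allclose standing in for the test class's `_assert_allclose` wrapper and made-up array values:

    import numpy as np

    o_tol = {"atol": 4e-3, "rtol": 3e-3}
    o = np.ones(4)
    o_ref = np.ones(4) + 1e-4  # well within the absolute tolerance
    np.testing.assert_allclose(o, o_ref, **o_tol)  # raises AssertionError if outside o_tol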

src/maxdiffusion/kernels/splash_attention/splash_attention_mask_info.py

Lines changed: 2 additions & 3 deletions
@@ -521,9 +521,8 @@ def _process_mask(
   if return_dynamic_grid:
     # Pad each slice to the largest number of active blocks in any shard.
     max_size = max(num_active_blocks)
-    pad_slice = lambda arr: np.pad(
-        arr, (0, max_size - arr.shape[0]), mode='constant', constant_values=-1
-    )
+    def pad_slice(arr):
+      return np.pad(arr, (0, max_size - arr.shape[0]), mode='constant', constant_values=-1)
     active_rows_slices = list(map(pad_slice, active_rows_slices))
     active_cols_slices = list(map(pad_slice, active_cols_slices))
     mask_next_slices = list(map(pad_slice, mask_next_slices))
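
The rewritten `pad_slice` right-pads each 1-D slice with -1 up to the largest active-block count across shards. A small self-contained sketch of that behavior with an assumed `max_size`:

    import numpy as np

    max_size = 4  # assumed largest number of active blocks in any shard

    def pad_slice(arr):
      return np.pad(arr, (0, max_size - arr.shape[0]), mode='constant', constant_values=-1)

    print(pad_slice(np.array([3, 7])))  # -> [ 3  7 -1 -1]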

src/maxdiffusion/kernels/splash_attention/splash_attention_mask_test.py

Lines changed: 10 additions & 10 deletions
@@ -1362,16 +1362,16 @@ def test_two_qseq_shards_causal_local_stacked(self):
     self._assert_mask_info_match(mask_info_dkv, expected_mask_info_dkv)
 
   @parameterized.named_parameters(
-      dict(
-          testcase_name="q_seq_shards_2",
-          q_seq_shards=2,
-          kv_seq_shards=1,
-      ),
-      dict(
-          testcase_name="kv_seq_shards_2",
-          q_seq_shards=1,
-          kv_seq_shards=2,
-      ),
+      {
+          "testcase_name": "q_seq_shards_2",
+          "q_seq_shards": 2,
+          "kv_seq_shards": 1,
+      },
+      {
+          "testcase_name": "kv_seq_shards_2",
+          "q_seq_shards": 1,
+          "kv_seq_shards": 2,
+      },
   )
   def test_two_shards_local_wide_local_narrow_stacked(
       self, q_seq_shards, kv_seq_shards
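
absl-py's `parameterized.named_parameters` accepts plain dicts keyed by `testcase_name`, so the literal form above is a drop-in replacement for the previous `dict(...)` calls. A minimal sketch with a hypothetical test class (not part of this repository):

    from absl.testing import absltest, parameterized

    class ShardCountTest(parameterized.TestCase):

      @parameterized.named_parameters(
          {"testcase_name": "q_seq_shards_2", "q_seq_shards": 2, "kv_seq_shards": 1},
          {"testcase_name": "kv_seq_shards_2", "q_seq_shards": 1, "kv_seq_shards": 2},
      )
      def test_shard_counts(self, q_seq_shards, kv_seq_shards):
        self.assertEqual(q_seq_shards * kv_seq_shards, 2)

    if __name__ == "__main__":
      absltest.main()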
