
Commit 0299786

Fixed pylint errors
1 parent 58038ec commit 0299786

35 files changed

Lines changed: 978 additions & 1344 deletions

src/maxdiffusion/__init__.py

Lines changed: 196 additions & 182 deletions
Large diffs are not rendered by default.

src/maxdiffusion/kernels/splash_attention/base.py

Lines changed: 8 additions & 25 deletions
@@ -25,9 +25,7 @@
 MaskInfo = mask_info_lib.MaskInfo


-DEFAULT_MASK_VALUE: Final[float] = -0.7 * float(
-    np.finfo(np.dtype("float32")).max
-)
+DEFAULT_MASK_VALUE: Final[float] = -0.7 * float(np.finfo(np.dtype("float32")).max)


 class SegmentIds(NamedTuple):
@@ -55,9 +53,7 @@ class SegmentIds(NamedTuple):


 # Return type of SplashAttention function that implements the custom vjp rule.
-SplashCustomReturnType: TypeAlias = (
-    jax.Array | tuple[jax.Array, dict[str, jax.Array]]
-)
+SplashCustomReturnType: TypeAlias = jax.Array | tuple[jax.Array, dict[str, jax.Array]]

 SplashResidualsType = tuple[
     jax.Array,  # q
@@ -85,9 +81,7 @@ def _attention_reference_impl(
   logits = jnp.einsum("sd,td->st", q.astype(jnp.float32), k.astype(jnp.float32))

   if segment_ids is not None:
-    mask = jnp.logical_and(
-        mask, segment_ids.q[:, None] == segment_ids.kv[None, :]
-    )
+    mask = jnp.logical_and(mask, segment_ids.q[:, None] == segment_ids.kv[None, :])

   if attn_logits_soft_cap is not None:
     logits = jnp.tanh(logits / attn_logits_soft_cap)
@@ -126,9 +120,7 @@ def _attention_reference_custom_bwd(
     backward_impl: str = "vanilla",
     attn_logits_soft_cap: float | None = None,
 ) -> tuple[jax.Array, jax.Array, jax.Array, None, None, jax.Array | None]:
-  uncapped_logits = jnp.einsum(
-      "qc,kc->qk", q, k, preferred_element_type=jnp.float32
-  )
+  uncapped_logits = jnp.einsum("qc,kc->qk", q, k, preferred_element_type=jnp.float32)

   if attn_logits_soft_cap is not None:
     logits = jnp.tanh(uncapped_logits / attn_logits_soft_cap)
@@ -137,9 +129,7 @@ def _attention_reference_custom_bwd(
     logits = uncapped_logits

   if segment_ids is not None:
-    mask = jnp.logical_and(
-        mask, segment_ids.q[:, None] == segment_ids.kv[None, :]
-    )
+    mask = jnp.logical_and(mask, segment_ids.q[:, None] == segment_ids.kv[None, :])
   logits = jnp.where(mask, logits, mask_value)

   p = jnp.exp(logits - logsumexp[..., None])
@@ -165,10 +155,7 @@ def _attention_reference_custom_bwd(
   dq = jnp.einsum("st,td->sd", ds, k.astype(jnp.float32)).astype(q.dtype)
   dsinks = None
   if sinks is not None:
-    sinks_exp = -jnp.exp(
-        sinks[..., None, None].astype(jnp.float32)
-        - logsumexp[..., None].astype(jnp.float32)
-    )
+    sinks_exp = -jnp.exp(sinks[..., None, None].astype(jnp.float32) - logsumexp[..., None].astype(jnp.float32))
     dsinks = jnp.sum(sinks_exp.astype(o.dtype) * o * do, axis=(-1, -2))
   return dq, dk, dv, None, None, dsinks

@@ -229,9 +216,7 @@ def attention_reference(
   return out


-@functools.partial(
-    jax.jit, static_argnames=["is_mqa", "backward_impl", "attn_logits_soft_cap"]
-)
+@functools.partial(jax.jit, static_argnames=["is_mqa", "backward_impl", "attn_logits_soft_cap"])
 def attention_reference_vjp(
     do,
     q,
@@ -269,9 +254,7 @@ def attention_reference_vjp(
     k = jnp.repeat(k, head_multiplier, axis=0)
     v = jnp.repeat(v, head_multiplier, axis=0)

-  dq, dk, dv, _, _, dsinks = bwd(
-      do, q, k, v, mask, segment_ids, sinks, o, logsumexp
-  )
+  dq, dk, dv, _, _, dsinks = bwd(do, q, k, v, mask, segment_ids, sinks, o, logsumexp)

   if is_mqa:
     dk, dv = dk.sum(axis=0), dv.sum(axis=0)
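
The last hunk above touches the multi-query attention (MQA) path, which repeats the single K/V head across query heads before the backward pass and then sums dK/dV over the head axis. As a reading aid only (not part of the commit; shapes invented), a minimal JAX sketch of why the adjoint of that broadcast is a sum:

# Minimal sketch: a broadcast in the forward pass becomes a sum in the backward pass.
import jax
import jax.numpy as jnp

num_heads, seq_len, head_dim = 4, 8, 16
q = jnp.ones((num_heads, seq_len, head_dim))
k = jnp.ones((seq_len, head_dim))  # MQA: one shared key head

def scores_sum(k):
  # Broadcasting k over query heads is equivalent to jnp.repeat(k, num_heads, axis=0).
  return jnp.einsum("hsd,td->hst", q, k).sum()

dk = jax.grad(scores_sum)(k)  # the broadcast axis is summed out, like dk.sum(axis=0)
print(dk.shape)  # (seq_len, head_dim)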

src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py

Lines changed: 11 additions & 26 deletions
@@ -41,9 +41,7 @@
 _splash_attention_bwd = splash_kernel._splash_attention_bwd  # pylint: disable=protected-access


-def _dynamic_slice_mask_info(
-    mask_info: MaskInfo, kv_shard_idx: jax.Array, ring_size: int
-) -> MaskInfo:
+def _dynamic_slice_mask_info(mask_info: MaskInfo, kv_shard_idx: jax.Array, ring_size: int) -> MaskInfo:
   """Slices MaskInfo for the current ring step."""

   def slice_if_exists(arr: jax.Array | None):
@@ -83,9 +81,7 @@ def _ring_attention_forward(
 ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:

   if q.shape[-1] != k.shape[-1]:
-    raise NotImplementedError(
-        "Queries and keys must have the same head dimension."
-    )
+    raise NotImplementedError("Queries and keys must have the same head dimension.")

   if sinks is not None:
     raise NotImplementedError("Sinks aren't supportd yet.")
@@ -124,13 +120,11 @@ def _ring_attention_forward(
   l_init = jnp.zeros((o_shape[0], o_shape[1]), jnp.float32)
   m_init = jnp.full_like(l_init, mask_value, dtype=jnp.float32)

-  def body(carry, i: int)-> tuple[tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, SegmentIds | None], None]:
+  def body(carry, i: int) -> tuple[tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, SegmentIds | None], None]:
     m_prev, l_prev, o_prev, k_current, v_current, segment_ids_current = carry

     current_kv_shard_idx = (ring_axis_idx - i) % ring_axis_size
-    local_fwd_mask_info = _dynamic_slice_mask_info(
-        fwd_mask_info, current_kv_shard_idx, ring_axis_size
-    )
+    local_fwd_mask_info = _dynamic_slice_mask_info(fwd_mask_info, current_kv_shard_idx, ring_axis_size)
     k_next = shift(k_current)
     v_next = shift(v_current)

@@ -225,9 +219,7 @@ def body(carry, i: int):
     v_next = shift(v_current)

     current_kv_shard_idx = (ring_axis_idx - i) % ring_axis_size
-    local_dkv_mask_info = _dynamic_slice_mask_info(
-        dkv_mask_info, current_kv_shard_idx, ring_axis_size
-    )
+    local_dkv_mask_info = _dynamic_slice_mask_info(dkv_mask_info, current_kv_shard_idx, ring_axis_size)
     if segment_ids is not None and rotate_segment_ids:
       kv_segment_ids_next = shift(segment_ids_current.kv)
       segment_ids_next = SegmentIds(segment_ids.q, kv_segment_ids_next)
@@ -255,9 +247,7 @@ def body(carry, i: int):
         fwd_mask_sparsity=fwd_mask_sparsity,
         dkv_mask_sparsity=dkv_mask_sparsity,
     )
-    _, _, dq_i, dk_i, dv_i, _, dsinks, _ = attn_bwd(
-        res=residuals_for_chunk, do=do
-    )
+    _, _, dq_i, dk_i, dv_i, _, dsinks, _ = attn_bwd(res=residuals_for_chunk, do=do)
     dv_next = shift(dv_accum + dv_i.astype(dv_accum.dtype))
     dk_next = shift(dk_accum + dk_i.astype(dk_accum.dtype))
     dq_accum = dq_accum + dq_i.astype(dq_accum.dtype)
@@ -394,7 +384,7 @@ def _ring_attention_custom(
     dkv_mask_sparsity: float,
     save_residuals: bool,
     ring_axis: str,
-    rotate_segment_ids: bool ,
+    rotate_segment_ids: bool,
 ) -> SplashCustomReturnType:
   """Performs ring attention with a custom VJP.

@@ -561,7 +551,7 @@ def __init__(
       fwd_mask_info: MaskInfo,
       dkv_mask_info: MaskInfo | None,
       ring_axis: str,
-      rotate_segment_ids: bool ,
+      rotate_segment_ids: bool,
       **kwargs,
   ):
     self.fwd_mask_info = fwd_mask_info
@@ -590,6 +580,7 @@ def manual_sharding_spec(self):
     """

     spec = jax.sharding.PartitionSpec(self.ring_axis)
+
     def _resolve_spec(x):
       return spec if x is not None else None

@@ -618,11 +609,7 @@ def tree_flatten(self):
   @classmethod
   def tree_unflatten(cls, aux_data, children):
     fwd_mask_info, dkv_mask_info = children
-    dkv_mask_info = (
-        mask_info_lib.MaskInfo(*dkv_mask_info)
-        if dkv_mask_info is not None
-        else None
-    )
+    dkv_mask_info = mask_info_lib.MaskInfo(*dkv_mask_info) if dkv_mask_info is not None else None
     return cls(
         mask_info_lib.MaskInfo(*fwd_mask_info),
         dkv_mask_info,
@@ -674,9 +661,7 @@ def make_ring_attention(
     mask = mask_lib.NumpyMask(mask)

   if not isinstance(mask, (mask_lib.NumpyMask, mask_lib.FullMask)):
-    raise NotImplementedError(
-        f"Only NumpyMask and FullMask are supported, but got {type(mask)}."
-    )
+    raise NotImplementedError(f"Only NumpyMask and FullMask are supported, but got {type(mask)}.")

   if config is None:
     config = SplashConfig.get_default()
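
The hunks above keep the ring bookkeeping intact: each step shifts the K/V shards one position around the ring and slices the MaskInfo for the shard currently held. As a reading aid only (not part of the commit), a minimal sketch of the index arithmetic used in body():

# Minimal sketch: after i shifts, the device at position ring_axis_idx holds the
# K/V shard that originated at (ring_axis_idx - i) % ring_axis_size, which is the
# index passed to _dynamic_slice_mask_info for that step.
ring_axis_size = 4
for ring_axis_idx in range(ring_axis_size):
  shards_seen = [(ring_axis_idx - i) % ring_axis_size for i in range(ring_axis_size)]
  # Over a full rotation every device sees every shard exactly once.
  assert sorted(shards_seen) == list(range(ring_axis_size))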

src/maxdiffusion/kernels/splash_attention/ring_attention_kernel_test.py

Lines changed: 4 additions & 16 deletions
@@ -67,10 +67,7 @@ def test_ring_attention(
       mask_type,
   ):
     if len(jax.devices()) < ring_size:
-      self.skipTest(
-          f"This test requires {ring_size} devices, but has only"
-          f" {len(jax.devices())} devices available."
-      )
+      self.skipTest(f"This test requires {ring_size} devices, but has only" f" {len(jax.devices())} devices available.")

     # Mesh Creation and Input Generation
     ring_axis = "ring"
@@ -85,14 +82,8 @@ def test_ring_attention(
       k = random.normal(k2, (seq_len, head_dim), dtype=dtype) * scale
       v = random.normal(k3, (seq_len, head_dim), dtype=dtype) * scale
     else:
-      k = (
-          random.normal(k2, (num_heads, seq_len, head_dim), dtype=dtype)
-          * scale
-      )
-      v = (
-          random.normal(k3, (num_heads, seq_len, head_dim), dtype=dtype)
-          * scale
-      )
+      k = random.normal(k2, (num_heads, seq_len, head_dim), dtype=dtype) * scale
+      v = random.normal(k3, (num_heads, seq_len, head_dim), dtype=dtype) * scale
     do = random.normal(k4, q.shape, dtype=dtype) * scale

     if mask_type == "CAUSAL":
@@ -112,7 +103,6 @@ def test_ring_attention(
     q_spec = P(None, ring_axis, None)
     kv_spec = P(ring_axis, None) if is_mqa else q_spec

-
     splash_config = splash.SplashConfig.get_default()
     splash_config = dataclasses.replace(
         splash_config,
@@ -159,9 +149,7 @@ def ring_attn(ring_kernel, q, k, v, segment_ids):

     with self.subTest("bwd"):
       out, out_vjp = jax.vjp(ring_attn, ring_kernel, q, k, v, segment_ids)
-      out_ref, out_vjp_ref = jax.vjp(
-          ring_attn_ref, q, k, v, mask[:, :], segment_ids
-      )
+      out_ref, out_vjp_ref = jax.vjp(ring_attn_ref, q, k, v, mask[:, :], segment_ids)
       self._assert_allclose(out, out_ref, rtol=5e-3, atol=3e-3)

       _, dq, dk, dv, _ = out_vjp(do)
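
The test compares the ring kernel against a reference by pushing the same cotangent through jax.vjp on both and checking outputs and gradients against tolerances. As a reading aid only (hypothetical reference function and made-up shapes, not the test's code), a minimal sketch of that jax.vjp pattern:

import jax
import jax.numpy as jnp

def ref_attn(q, k, v):
  # Plain softmax attention as a stand-in reference.
  p = jax.nn.softmax(jnp.einsum("sd,td->st", q, k), axis=-1)
  return jnp.einsum("st,td->sd", p, v)

q = jnp.ones((8, 4)); k = jnp.ones((8, 4)); v = jnp.ones((8, 4))
do = jnp.ones((8, 4))  # cotangent applied to the output

out, vjp_fn = jax.vjp(ref_attn, q, k, v)
dq, dk, dv = vjp_fn(do)  # gradients w.r.t. q, k, v for the cotangent do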
