41 | 41 | _splash_attention_bwd = splash_kernel._splash_attention_bwd # pylint: disable=protected-access |
42 | 42 |
43 | 43 |
44 | | -def _dynamic_slice_mask_info( |
45 | | - mask_info: MaskInfo, kv_shard_idx: jax.Array, ring_size: int |
46 | | -) -> MaskInfo: |
| 44 | +def _dynamic_slice_mask_info(mask_info: MaskInfo, kv_shard_idx: jax.Array, ring_size: int) -> MaskInfo: |
47 | 45 | """Slices MaskInfo for the current ring step.""" |
48 | 46 |
49 | 47 | def slice_if_exists(arr: jax.Array | None): |
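As background for reviewers: a minimal sketch of what a per-step mask slice can look like, assuming the KV-block axis is the last one and splits evenly across the ring. The helper name, axis choice, and chunking below are illustrative, not this module's API.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-step slicing: pick the block of mask metadata that
# matches the KV shard currently held. dynamic_slice is required because
# kv_shard_idx is a traced value inside the ring loop.
def slice_for_step(arr: jax.Array | None, kv_shard_idx: jax.Array, ring_size: int):
  if arr is None:
    return None
  chunk = arr.shape[-1] // ring_size  # assumes the last axis splits evenly
  return jax.lax.dynamic_slice_in_dim(arr, kv_shard_idx * chunk, chunk, axis=-1)

blocks = slice_for_step(jnp.arange(12).reshape(3, 4), jnp.int32(1), 2)  # columns 2:4
```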
@@ -83,9 +81,7 @@ def _ring_attention_forward( |
83 | 81 | ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: |
84 | 82 |
85 | 83 | if q.shape[-1] != k.shape[-1]: |
86 | | - raise NotImplementedError( |
87 | | - "Queries and keys must have the same head dimension." |
88 | | - ) |
| 84 | + raise NotImplementedError("Queries and keys must have the same head dimension.") |
89 | 85 |
90 | 86 | if sinks is not None: |
91 | 87 | raise NotImplementedError("Sinks aren't supportd yet.") |
@@ -124,13 +120,11 @@ def _ring_attention_forward( |
124 | 120 | l_init = jnp.zeros((o_shape[0], o_shape[1]), jnp.float32) |
125 | 121 | m_init = jnp.full_like(l_init, mask_value, dtype=jnp.float32) |
126 | 122 |
127 | | - def body(carry, i: int)-> tuple[tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, SegmentIds | None], None]: |
| 123 | + def body(carry, i: int) -> tuple[tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, SegmentIds | None], None]: |
128 | 124 | m_prev, l_prev, o_prev, k_current, v_current, segment_ids_current = carry |
129 | 125 |
130 | 126 | current_kv_shard_idx = (ring_axis_idx - i) % ring_axis_size |
131 | | - local_fwd_mask_info = _dynamic_slice_mask_info( |
132 | | - fwd_mask_info, current_kv_shard_idx, ring_axis_size |
133 | | - ) |
| 127 | + local_fwd_mask_info = _dynamic_slice_mask_info(fwd_mask_info, current_kv_shard_idx, ring_axis_size) |
134 | 128 | k_next = shift(k_current) |
135 | 129 | v_next = shift(v_current) |
136 | 130 |
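For context, a runnable sketch of the ring schedule this loop implements: `shift` is assumed to be a one-hop `ppermute` along the ring axis, and the `(ring_axis_idx - i) % ring_axis_size` arithmetic tracks which global KV shard each device holds at step `i`. Mesh setup and names below are illustrative, not this module's API.

```python
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# 1D mesh over all local devices; "ring" plays the role of `ring_axis`.
mesh = Mesh(jax.devices(), axis_names=("ring",))
ring_size = len(jax.devices())

def ring_schedule(kv):
  ring_idx = jax.lax.axis_index("ring")
  # shift(): pass the local KV block one hop around the ring.
  perm = [(j, (j + 1) % ring_size) for j in range(ring_size)]

  def body(kv_current, i):
    # Same modular arithmetic as `current_kv_shard_idx` above: after i
    # hops, device r holds the shard that started on device r - i.
    current_kv_shard_idx = (ring_idx - i) % ring_size
    kv_next = jax.lax.ppermute(kv_current, "ring", perm)
    return kv_next, current_kv_shard_idx

  _, visited = jax.lax.scan(body, kv, jnp.arange(ring_size))
  return visited  # per device: the shard indices it saw, in step order

kv = jnp.arange(ring_size, dtype=jnp.float32)  # one "shard" per device
visited = jax.jit(
    shard_map(ring_schedule, mesh=mesh, in_specs=P("ring"), out_specs=P("ring"))
)(kv)
```

After `ring_size` steps every device has seen every KV shard exactly once, which is why the loop can accumulate full attention statistics locally.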
@@ -225,9 +219,7 @@ def body(carry, i: int): |
225 | 219 | v_next = shift(v_current) |
226 | 220 |
227 | 221 | current_kv_shard_idx = (ring_axis_idx - i) % ring_axis_size |
228 | | - local_dkv_mask_info = _dynamic_slice_mask_info( |
229 | | - dkv_mask_info, current_kv_shard_idx, ring_axis_size |
230 | | - ) |
| 222 | + local_dkv_mask_info = _dynamic_slice_mask_info(dkv_mask_info, current_kv_shard_idx, ring_axis_size) |
231 | 223 | if segment_ids is not None and rotate_segment_ids: |
232 | 224 | kv_segment_ids_next = shift(segment_ids_current.kv) |
233 | 225 | segment_ids_next = SegmentIds(segment_ids.q, kv_segment_ids_next) |
@@ -255,9 +247,7 @@ def body(carry, i: int): |
255 | 247 | fwd_mask_sparsity=fwd_mask_sparsity, |
256 | 248 | dkv_mask_sparsity=dkv_mask_sparsity, |
257 | 249 | ) |
258 | | - _, _, dq_i, dk_i, dv_i, _, dsinks, _ = attn_bwd( |
259 | | - res=residuals_for_chunk, do=do |
260 | | - ) |
| 250 | + _, _, dq_i, dk_i, dv_i, _, dsinks, _ = attn_bwd(res=residuals_for_chunk, do=do) |
261 | 251 | dv_next = shift(dv_accum + dv_i.astype(dv_accum.dtype)) |
262 | 252 | dk_next = shift(dk_accum + dk_i.astype(dk_accum.dtype)) |
263 | 253 | dq_accum = dq_accum + dq_i.astype(dq_accum.dtype) |
@@ -394,7 +384,7 @@ def _ring_attention_custom( |
394 | 384 | dkv_mask_sparsity: float, |
395 | 385 | save_residuals: bool, |
396 | 386 | ring_axis: str, |
397 | | - rotate_segment_ids: bool , |
| 387 | + rotate_segment_ids: bool, |
398 | 388 | ) -> SplashCustomReturnType: |
399 | 389 | """Performs ring attention with a custom VJP. |
400 | 390 |
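The docstring mentions a custom VJP; as background, here is a minimal `jax.custom_vjp` skeleton of the same shape, with a toy attention-like function standing in for the ring loops. The forward saves residuals and the backward consumes them via `res`, just as `attn_bwd(res=..., do=...)` does above.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def toy_attn(q, k):
  return jnp.tanh(q @ k.T)

def toy_attn_fwd(q, k):
  out = jnp.tanh(q @ k.T)
  return out, (q, k, out)      # residuals saved for the backward pass

def toy_attn_bwd(res, do):
  q, k, out = res
  g = do * (1.0 - out * out)   # d/dx tanh(x) = 1 - tanh(x)^2
  return g @ k, g.T @ q        # cotangents for (q, k)

toy_attn.defvjp(toy_attn_fwd, toy_attn_bwd)

# jax.grad now routes through toy_attn_bwd instead of tracing the forward.
dq = jax.grad(lambda q, k: toy_attn(q, k).sum())(
    jnp.ones((4, 8)), jnp.ones((6, 8))
)
```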
@@ -561,7 +551,7 @@ def __init__( |
561 | 551 | fwd_mask_info: MaskInfo, |
562 | 552 | dkv_mask_info: MaskInfo | None, |
563 | 553 | ring_axis: str, |
564 | | - rotate_segment_ids: bool , |
| 554 | + rotate_segment_ids: bool, |
565 | 555 | **kwargs, |
566 | 556 | ): |
567 | 557 | self.fwd_mask_info = fwd_mask_info |
@@ -590,6 +580,7 @@ def manual_sharding_spec(self): |
590 | 580 | """ |
591 | 581 |
592 | 582 | spec = jax.sharding.PartitionSpec(self.ring_axis) |
| 583 | + |
593 | 584 | def _resolve_spec(x): |
594 | 585 | return spec if x is not None else None |
595 | 586 |
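A small illustration of why `_resolve_spec` exists: manual-sharding specs must line up leaf-for-leaf with the pytree being sharded, so optional fields that are `None` need a `None` spec. The tuple below is a stand-in for a MaskInfo-like structure.

```python
import jax
from jax.sharding import PartitionSpec

spec = PartitionSpec("ring")

def _resolve_spec(x):
  return spec if x is not None else None

# Treat None as a leaf so optional fields map to a None spec.
mask_info_like = (1, None, 3)
specs = jax.tree_util.tree_map(
    _resolve_spec, mask_info_like, is_leaf=lambda x: x is None
)
# specs == (PartitionSpec('ring'), None, PartitionSpec('ring'))
```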
@@ -618,11 +609,7 @@ def tree_flatten(self): |
618 | 609 | @classmethod |
619 | 610 | def tree_unflatten(cls, aux_data, children): |
620 | 611 | fwd_mask_info, dkv_mask_info = children |
621 | | - dkv_mask_info = ( |
622 | | - mask_info_lib.MaskInfo(*dkv_mask_info) |
623 | | - if dkv_mask_info is not None |
624 | | - else None |
625 | | - ) |
| 612 | + dkv_mask_info = mask_info_lib.MaskInfo(*dkv_mask_info) if dkv_mask_info is not None else None |
626 | 613 | return cls( |
627 | 614 | mask_info_lib.MaskInfo(*fwd_mask_info), |
628 | 615 | dkv_mask_info, |
@@ -674,9 +661,7 @@ def make_ring_attention( |
674 | 661 | mask = mask_lib.NumpyMask(mask) |
675 | 662 |
676 | 663 | if not isinstance(mask, (mask_lib.NumpyMask, mask_lib.FullMask)): |
677 | | - raise NotImplementedError( |
678 | | - f"Only NumpyMask and FullMask are supported, but got {type(mask)}." |
679 | | - ) |
| 664 | + raise NotImplementedError(f"Only NumpyMask and FullMask are supported, but got {type(mask)}.") |
680 | 665 |
681 | 666 | if config is None: |
682 | 667 | config = SplashConfig.get_default() |
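Finally, since the `tree_flatten`/`tree_unflatten` hunk above is only a fragment, here is a self-contained sketch of the pytree-registration pattern it belongs to. The class name and fields are illustrative; the real class re-wraps children into `MaskInfo`, while here they stay plain tuples.

```python
import jax

# Mask-info tuples travel as children (visible to JAX transformations);
# static config such as the ring axis rides along as aux_data.
@jax.tree_util.register_pytree_node_class
class RingKernelSketch:
  def __init__(self, fwd_mask_info, dkv_mask_info, ring_axis):
    self.fwd_mask_info = fwd_mask_info
    self.dkv_mask_info = dkv_mask_info
    self.ring_axis = ring_axis

  def tree_flatten(self):
    children = (self.fwd_mask_info, self.dkv_mask_info)
    aux_data = (self.ring_axis,)
    return children, aux_data

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    # dkv_mask_info may be None (no separate backward mask), mirroring
    # the None-guard in the diff above.
    fwd_mask_info, dkv_mask_info = children
    return cls(fwd_mask_info, dkv_mask_info, *aux_data)

kernel = RingKernelSketch((1, 2), None, "ring")
leaves, treedef = jax.tree_util.tree_flatten(kernel)
kernel2 = jax.tree_util.tree_unflatten(treedef, leaves)
```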