
Commit c5d9018

Add ring attention and unified sequence parallelism
Signed-off-by: Kunjan patel <kunjanp@google.com>
1 parent 4686b2e commit c5d9018

2 files changed

Lines changed: 202 additions & 3 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 2 additions & 1 deletion
@@ -15,8 +15,9 @@
 from typing import Sequence
 import jax
 import time
-from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
 from maxdiffusion import pyconfig, max_logging, max_utils
+from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
+
 from absl import app
 from maxdiffusion.utils import export_to_video

src/maxdiffusion/models/attention_flax.py

Lines changed: 200 additions & 2 deletions
@@ -18,7 +18,8 @@
 import flax.linen as nn
 from flax import nnx
 import jax
-from jax.sharding import PartitionSpec
+from jax.experimental.pjit import pjit
+from jax.sharding import NamedSharding, PartitionSpec
 import jax.numpy as jnp
 from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
@@ -153,6 +154,198 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1
   return tensor, kv_size, seq_len


+def _tpu_ring_flash_attention_v1(
+    query: jax.Array,
+    key: jax.Array,
+    value: jax.Array,
+    heads: int,
+    mesh: Mesh,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32,
+) -> jax.Array:
+  """TPU Ring Flash Attention with correct padding, transposition, and sharding."""
+  from ringattention import ringattention
+  from einops import rearrange
+
+  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  blockwise_kwargs = {
+      # FullMask
+      "causal_block_size": None,
+      "deterministic": True,
+      "dropout_rng": None,
+      "attn_pdrop": 0.0,
+      "policy": jax.checkpoint_policies.nothing_saveable,
+      "dtype": dtype,
+      "precision": None,
+      "prevent_cse": True,
+  }
+  if flash_block_sizes:
+    blockwise_kwargs["query_chunk_size"] = flash_block_sizes.block_q
+    blockwise_kwargs["key_chunk_size"] = flash_block_sizes.block_kv
+  else:
+    blockwise_kwargs["query_chunk_size"] = min(max_block_size, query.shape[2])
+    blockwise_kwargs["key_chunk_size"] = min(max_block_size, key.shape[2])
+
+  num_fsdp_shards = mesh.shape["fsdp"]
+  query_padded, kv_size, query_seq_len = _reshape_data_for_flash(
+      query, heads, blockwise_kwargs["query_chunk_size"], num_fsdp_shards
+  )
+  key_padded, _, _ = _reshape_data_for_flash(
+      key, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+  value_padded, _, _ = _reshape_data_for_flash(
+      value, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+
+  # (b, s, h, d) shape expected
+  query_t = rearrange(query_padded, 'b h s d -> b s h d')
+  key_t = rearrange(key_padded, 'b h s d -> b s h d')
+  value_t = rearrange(value_padded, 'b h s d -> b s h d')
+
+  q_axis_names_physical = nn.logical_to_mesh_axes(axis_names_q)
+  transposed_spec = PartitionSpec(
+      q_axis_names_physical[0],  # BATCH -> 'data'
+      q_axis_names_physical[2],  # LENGTH -> 'fsdp'
+      q_axis_names_physical[1],  # HEAD -> 'tensor'
+      q_axis_names_physical[3],  # D_KV -> None
+  )
+  ring_axis_name = q_axis_names_physical[2]
+
+  ring_attention_sharded = shard_map.shard_map(
+      functools.partial(
+          ringattention,
+          attn_bias=None,
+          segment_ids=None,
+          axis_name=ring_axis_name,
+          float32_logits=True,
+          cache_idx=None,
+          blockwise_kwargs=blockwise_kwargs,
+      ),
+      mesh=mesh,
+      in_specs=(transposed_spec, transposed_spec, transposed_spec),
+      # output is (b, s, h, d)
+      out_specs=transposed_spec,
+      check_rep=False,
+  )
+
+  attention_output = ring_attention_sharded(query_t, key_t, value_t)
+
+  attention_output_t = rearrange(attention_output, 'b s h d -> b h s d')
+
+  # Unpad the output to get back the original sequence
+  attention_output_cropped = attention_output_t[:, :, :query_seq_len, :kv_size]
+
+  # Reshape to (b, s, h*d)
+  final_output = _reshape_heads_to_head_dim(attention_output_cropped)
+
+  return final_output
+
+def _tpu_ring_flash_attention(
+    query: jax.Array,
+    key: jax.Array,
+    value: jax.Array,
+    heads: int,
+    mesh: Mesh,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32,
+    usp_degree: Optional[int] = 1,
+) -> jax.Array:
+  """TPU Ring/USP Flash Attention with correct padding, transposition, and sharding."""
+  from ringattention import ringattention
+  from einops import rearrange
+
+  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  blockwise_kwargs = {
+      # FullMask
+      "causal_block_size": None,
+      "deterministic": True,
+      "dropout_rng": None,
+      "attn_pdrop": 0.0,
+      "policy": jax.checkpoint_policies.nothing_saveable,
+      "dtype": dtype,
+      "precision": None,
+      "prevent_cse": True,
+  }
+  if flash_block_sizes:
+    blockwise_kwargs["query_chunk_size"] = flash_block_sizes.block_q
+    blockwise_kwargs["key_chunk_size"] = flash_block_sizes.block_kv
+  else:
+    # Get seq_len from shape[2] of the original (b, h, s, d) tensor
+    blockwise_kwargs["query_chunk_size"] = min(max_block_size, query.shape[2])
+    blockwise_kwargs["key_chunk_size"] = min(max_block_size, key.shape[2])
+
+  # Pad sequence to be divisible by block size
+  num_fsdp_shards = mesh.shape["fsdp"]
+  query_padded, kv_size, query_seq_len = _reshape_data_for_flash(
+      query, heads, blockwise_kwargs["query_chunk_size"], num_fsdp_shards
+  )
+  key_padded, _, _ = _reshape_data_for_flash(
+      key, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+  value_padded, _, _ = _reshape_data_for_flash(
+      value, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+
+  num_fsdp_devices = mesh.shape["fsdp"]
+  if num_fsdp_devices % usp_degree != 0:
+    raise ValueError("fsdp axis size must be divisible by usp_degree")
+  ring_degree = num_fsdp_devices // usp_degree
+  logical_mesh_shape = (mesh.shape['data'], usp_degree, ring_degree, mesh.shape["tensor"])
+  reshaped_devices = mesh.devices.reshape(logical_mesh_shape)
+  logical_mesh = Mesh(reshaped_devices, ('data', 'usp', 'ring', 'tensor'))
+
+  def _kernel(q_padded, k_padded, v_padded):
+    # All-to-all swap of heads and sequence (a no-op if usp_degree == 1)
+    q_swapped = jax.lax.all_to_all(q_padded, 'usp', split_axis=1, concat_axis=2, tiled=True)
+    k_swapped = jax.lax.all_to_all(k_padded, 'usp', split_axis=1, concat_axis=2, tiled=True)
+    v_swapped = jax.lax.all_to_all(v_padded, 'usp', split_axis=1, concat_axis=2, tiled=True)
+
+    q_t = rearrange(q_swapped, 'b h s d -> b s h d')
+    k_t = rearrange(k_swapped, 'b h s d -> b s h d')
+    v_t = rearrange(v_swapped, 'b h s d -> b s h d')
+
+    attn_out_t = ringattention(
+        q_t, k_t, v_t,
+        axis_name='ring',
+        blockwise_kwargs=blockwise_kwargs,
+        attn_bias=None,
+        segment_ids=None,
+        cache_idx=None,
+        float32_logits=True,
+    )
+    return attn_out_t
+
+  initial_spec = PartitionSpec('data', 'tensor', ('usp', 'ring'), None)
+  output_spec = PartitionSpec('data', 'ring', ('usp', 'tensor'), None)
+
+  sharded_attention_fn = shard_map.shard_map(
+      _kernel,
+      mesh=logical_mesh,
+      in_specs=(initial_spec, initial_spec, initial_spec),
+      out_specs=output_spec,
+      check_rep=False,
+  )
+
+  attention_output_t = sharded_attention_fn(query_padded, key_padded, value_padded)
+
+  attention_output = rearrange(attention_output_t, 'b s h d -> b h s d')
+
+  # Back to the original sharding using the original physical mesh
+  physical_spec = PartitionSpec('data', 'tensor', 'fsdp', None)
+  attention_output_resharded = jax.lax.with_sharding_constraint(
+      attention_output,
+      NamedSharding(mesh, physical_spec),
+  )
+
+  # Unpad the sequence and reshape to (b, s, heads * head_dim)
+  attention_output_cropped = attention_output_resharded[:, :, :query_seq_len, :kv_size]
+  final_output = _reshape_heads_to_head_dim(attention_output_cropped)
+
+  return final_output
+
 def _tpu_flash_attention(
     query: jax.Array,
     key: jax.Array,
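
For context, the ring pattern that the ringattention call implements can be summarized in a few lines: each device keeps its local query shard, the key/value shards travel around the ring one hop per step, and a streaming softmax folds every visiting block into the running result. The sketch below is an illustrative toy under assumed shapes and a single 'ring' mesh axis; it is not this commit's kernel or the ringattention library's internals.

import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec as P

def ring_attention_local(q, k, v, axis_name, axis_size):
  # q, k, v are per-device shards of shape (local_seq, head_dim).
  perm = [(i, (i + 1) % axis_size) for i in range(axis_size)]
  scale = q.shape[-1] ** -0.5
  m = jnp.full(q.shape[:-1], -jnp.inf)  # running max of the logits
  l = jnp.zeros(q.shape[:-1])           # running softmax denominator
  o = jnp.zeros_like(q)                 # running un-normalized output
  for _ in range(axis_size):
    s = jnp.einsum("qd,kd->qk", q, k) * scale
    m_new = jnp.maximum(m, s.max(axis=-1))
    p = jnp.exp(s - m_new[..., None])
    corr = jnp.exp(m - m_new)
    l = l * corr + p.sum(axis=-1)
    o = o * corr[..., None] + p @ v
    m = m_new
    # pass this device's K/V block to the next device in the ring
    k = jax.lax.ppermute(k, axis_name, perm)
    v = jax.lax.ppermute(v, axis_name, perm)
  return o / l[..., None]

devices = np.array(jax.devices())
mesh = Mesh(devices, ("ring",))
seq, d = 8 * devices.size, 16
q = k = v = jax.random.normal(jax.random.PRNGKey(0), (seq, d))

ring_attn = jax.jit(shard_map.shard_map(
    functools.partial(ring_attention_local, axis_name="ring", axis_size=devices.size),
    mesh=mesh,
    in_specs=(P("ring"), P("ring"), P("ring")),  # shard the sequence dimension
    out_specs=P("ring"),
    check_rep=False,
))
out = ring_attn(q, k, v)  # agrees with softmax(q @ k.T / sqrt(d)) @ v up to numerics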
@@ -372,7 +565,7 @@ def _apply_attention(
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel == "flash":
+  if attention_kernel in ["flash", "ring_flash"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -388,6 +581,11 @@ def _apply_attention(
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
     )
+  elif attention_kernel == "ring_flash":
+    max_logging.log("USING RING ATTENTION")
+    return _tpu_ring_flash_attention_v1(
+        query, key, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+    )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else:
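
For context, _tpu_ring_flash_attention layers a Ulysses-style all-to-all on top of the ring: over the 'usp' sub-axis each device trades head shards for sequence shards before the ring loop runs over the 'ring' sub-axis. A minimal sketch of just that head/sequence swap follows; the single 'usp' axis and the shapes are assumptions for illustration, not this commit's mesh layout.

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, ("usp",))
b, h, s, d = 2, 8, 16, 4  # assumes h and s are divisible by the 'usp' axis size

def head_seq_swap(x):
  # local x: (b, h, s / usp, d) -> (b, h / usp, s, d)
  return jax.lax.all_to_all(x, "usp", split_axis=1, concat_axis=2, tiled=True)

swap = jax.jit(shard_map.shard_map(
    head_seq_swap,
    mesh=mesh,
    in_specs=P(None, None, "usp", None),   # sequence-sharded in
    out_specs=P(None, "usp", None, None),  # head-sharded out
    check_rep=False,
))
y = swap(jnp.zeros((b, h, s, d)))
print(y.shape)  # (2, 8, 16, 4) globally; only which dimension is sharded has changed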

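How the new kernel would be selected at generation time, as a hedged sketch: the config file name below is a placeholder and the override key is assumed to be the existing maxdiffusion option that feeds attention_kernel; check the repo's base config before relying on it.

# Hypothetical invocation; substitute the actual Wan config shipped with the repo.
python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/<wan_config>.yml attention=ring_flash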