Commit 46c1741

Add ring attention and unified sequence parallelism

Signed-off-by: Kunjan patel <kunjanp@google.com>
1 parent 4686b2e · commit 46c1741

2 files changed: 220 additions & 3 deletions

src/maxdiffusion/generate_wan.py
Lines changed: 2 additions & 1 deletion

@@ -15,8 +15,9 @@
 from typing import Sequence
 import jax
 import time
-from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
 from maxdiffusion import pyconfig, max_logging, max_utils
+from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
+
 from absl import app
 from maxdiffusion.utils import export_to_video

src/maxdiffusion/models/attention_flax.py
Lines changed: 218 additions & 2 deletions

@@ -18,7 +18,8 @@
 import flax.linen as nn
 from flax import nnx
 import jax
-from jax.sharding import PartitionSpec
+from jax.experimental.pjit import pjit
+from jax.sharding import NamedSharding, PartitionSpec
 import jax.numpy as jnp
 from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
@@ -153,6 +154,216 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1
   return tensor, kv_size, seq_len


+def _tpu_ring_flash_attention_v1(
+    query: jax.Array,
+    key: jax.Array,
+    value: jax.Array,
+    heads: int,
+    mesh: Mesh,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32,
+) -> jax.Array:
+  """TPU Ring Flash Attention with correct padding, transposition, and sharding."""
+  from ringattention import ringattention
+  from einops import rearrange
+
+  # --- Step 1: Initialize Block Sizes ---
+  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  blockwise_kwargs = {
+      # CRITICAL: Ensures non-causal attention to match FullMask
+      "causal_block_size": None,
+      "deterministic": True,
+      "dropout_rng": None,
+      "attn_pdrop": 0.0,
+      "policy": jax.checkpoint_policies.nothing_saveable,
+      "dtype": dtype,
+      "precision": None,
+      "prevent_cse": True,
+  }
+  if flash_block_sizes:
+    blockwise_kwargs["query_chunk_size"] = flash_block_sizes.block_q
+    blockwise_kwargs["key_chunk_size"] = flash_block_sizes.block_kv
+  else:
+    # Get seq_len from shape[2] of the original (b, h, s, d) tensor
+    blockwise_kwargs["query_chunk_size"] = min(max_block_size, query.shape[2])
+    blockwise_kwargs["key_chunk_size"] = min(max_block_size, key.shape[2])
+
+  # --- Step 2: Pad and Preprocess Tensors (CRITICAL) ---
+  num_fsdp_shards = mesh.shape["fsdp"]
+  query_padded, kv_size, query_seq_len = _reshape_data_for_flash(
+      query, heads, blockwise_kwargs["query_chunk_size"], num_fsdp_shards
+  )
+  key_padded, _, _ = _reshape_data_for_flash(
+      key, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+  value_padded, _, _ = _reshape_data_for_flash(
+      value, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+
+  # Transpose the *padded* inputs to the (b, s, h, d) format the library expects
+  query_t = rearrange(query_padded, 'b h s d -> b s h d')
+  key_t = rearrange(key_padded, 'b h s d -> b s h d')
+  value_t = rearrange(value_padded, 'b h s d -> b s h d')
+
+  # --- Step 3: Define Sharding and the Sharded Function ---
+  # The PartitionSpec must match the transposed (b, s, h, d) layout.
+  q_axis_names_physical = nn.logical_to_mesh_axes(axis_names_q)
+  transposed_spec = PartitionSpec(
+      q_axis_names_physical[0],  # BATCH  -> 'data'
+      q_axis_names_physical[2],  # LENGTH -> 'fsdp'
+      q_axis_names_physical[1],  # HEAD   -> 'tensor'
+      q_axis_names_physical[3],  # D_KV   -> None
+  )
+  ring_axis_name = q_axis_names_physical[2]
+
+  ring_attention_sharded = shard_map.shard_map(
+      functools.partial(
+          ringattention,
+          attn_bias=None,
+          segment_ids=None,
+          axis_name=ring_axis_name,
+          float32_logits=True,
+          cache_idx=None,
+          blockwise_kwargs=blockwise_kwargs,
+      ),
+      mesh=mesh,
+      in_specs=(transposed_spec, transposed_spec, transposed_spec),
+      # The library's output is (b, s, h, d), so its sharding matches the transposed input
+      out_specs=transposed_spec,
+      check_rep=False,
+  )
+
+  # --- Step 4: Execute and Post-process ---
+  attention_output = ring_attention_sharded(query_t, key_t, value_t)
+
+  # Transpose the output from (b, s, h, d) back to the original (b, h, s, d) convention
+  attention_output_t = rearrange(attention_output, 'b s h d -> b h s d')
+
+  # Unpad the output to match the original sequence and head dimensions
+  attention_output_cropped = attention_output_t[:, :, :query_seq_len, :kv_size]
+
+  # Reshape to the final (b, s, h*d) format, just like in _tpu_flash_attention
+  final_output = _reshape_heads_to_head_dim(attention_output_cropped)
+
+  return final_output
+
+def _tpu_ring_flash_attention(
+    query: jax.Array,
+    key: jax.Array,
+    value: jax.Array,
+    heads: int,
+    mesh: Mesh,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32,
+    usp_degree: Optional[int] = 1,
+) -> jax.Array:
+  """TPU Ring Flash Attention with unified sequence parallelism (usp x ring)."""
+  from ringattention import ringattention
+  from einops import rearrange
+
+  # --- Step 1: Initialize Block Sizes ---
+  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  blockwise_kwargs = {
+      # CRITICAL: Ensures non-causal attention to match FullMask
+      "causal_block_size": None,
+      "deterministic": True,
+      "dropout_rng": None,
+      "attn_pdrop": 0.0,
+      "policy": jax.checkpoint_policies.nothing_saveable,
+      "dtype": dtype,
+      "precision": None,
+      "prevent_cse": True,
+  }
+  if flash_block_sizes:
+    blockwise_kwargs["query_chunk_size"] = flash_block_sizes.block_q
+    blockwise_kwargs["key_chunk_size"] = flash_block_sizes.block_kv
+  else:
+    # Get seq_len from shape[2] of the original (b, h, s, d) tensor
+    blockwise_kwargs["query_chunk_size"] = min(max_block_size, query.shape[2])
+    blockwise_kwargs["key_chunk_size"] = min(max_block_size, key.shape[2])
+
+  # --- Step 2: Pad and Preprocess Tensors (CRITICAL) ---
+  num_fsdp_shards = mesh.shape["fsdp"]
+  query_padded, kv_size, query_seq_len = _reshape_data_for_flash(
+      query, heads, blockwise_kwargs["query_chunk_size"], num_fsdp_shards
+  )
+  key_padded, _, _ = _reshape_data_for_flash(
+      key, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+  value_padded, _, _ = _reshape_data_for_flash(
+      value, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+
+  # --- Step 3: Build the logical usp x ring mesh and the sharded kernel ---
+  if num_fsdp_shards % usp_degree != 0:
+    raise ValueError("fsdp axis size must be divisible by usp_degree")
+  ring_degree = num_fsdp_shards // usp_degree
+  logical_mesh_shape = (mesh.shape['data'], usp_degree, ring_degree, mesh.shape["tensor"])
+  reshaped_devices = mesh.devices.reshape(logical_mesh_shape)
+  logical_mesh = Mesh(reshaped_devices, ('data', 'usp', 'ring', 'tensor'))
+
+  # 3.1: Define the KERNEL containing all sharded logic
+  def _kernel(q_padded, k_padded, v_padded):
+    # Step A: All-to-All Swap (This is a NO-OP if usp_degree=1)
+    q_swapped = jax.lax.all_to_all(q_padded, 'usp', split_axis=1, concat_axis=2, tiled=True)
+    k_swapped = jax.lax.all_to_all(k_padded, 'usp', split_axis=1, concat_axis=2, tiled=True)
+    v_swapped = jax.lax.all_to_all(v_padded, 'usp', split_axis=1, concat_axis=2, tiled=True)
+
+    # Step B: Transpose the *swapped* tensors for the library
+    q_t = rearrange(q_swapped, 'b h s d -> b s h d')
+    k_t = rearrange(k_swapped, 'b h s d -> b s h d')
+    v_t = rearrange(v_swapped, 'b h s d -> b s h d')
+
+    # Step C: Ring Attention (Always communicates along the 'ring' axis)
+    attn_out_t = ringattention(
+        q_t, k_t, v_t,
+        axis_name='ring',
+        blockwise_kwargs=blockwise_kwargs,
+        attn_bias=None,
+        segment_ids=None,
+        cache_idx=None,
+        float32_logits=True,
+    )
+    return attn_out_t
+
+  # 3.2: Define sharding for the inputs to the kernel
+  initial_spec = PartitionSpec('data', 'tensor', ('usp', 'ring'), None)
+
+  # 3.3: Define sharding for the output of the kernel (which is transposed)
+  output_spec = PartitionSpec('data', 'ring', ('usp', 'tensor'), None)
+
+  # 3.4: Create and call the sharded function
+  sharded_attention_fn = shard_map.shard_map(
+      _kernel,
+      mesh=logical_mesh,
+      in_specs=(initial_spec, initial_spec, initial_spec),
+      out_specs=output_spec,
+      check_rep=False,
+  )
+
+  attention_output_t = sharded_attention_fn(query_padded, key_padded, value_padded)
+
+  # --- Step 4: Reshard back to Physical Layout and Post-process ---
+
+  # 4.1: Transpose back to (b, h, s, d)
+  attention_output = rearrange(attention_output_t, 'b s h d -> b h s d')
+
+  # 4.2: Define the target sharding using the ORIGINAL physical mesh
+  physical_spec = PartitionSpec('data', 'tensor', 'fsdp', None)
+  attention_output_resharded = jax.lax.with_sharding_constraint(
+      attention_output, NamedSharding(mesh, physical_spec)
+  )
+
+  # 4.3: Unpad and reshape
+  attention_output_cropped = attention_output_resharded[:, :, :query_seq_len, :kv_size]
+  final_output = _reshape_heads_to_head_dim(attention_output_cropped)
+
+  return final_output
+
 def _tpu_flash_attention(
     query: jax.Array,
     key: jax.Array,
@@ -372,7 +583,7 @@ def _apply_attention(
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel == "flash":
+  if attention_kernel in ["flash", "ring_flash"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -388,6 +599,11 @@ def _apply_attention(
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
     )
+  elif attention_kernel == "ring_flash":
+    max_logging.log("USING RING ATTENTION")
+    return _tpu_ring_flash_attention_v1(
+        query, key, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+    )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else:
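
Note on the diff above: the load-bearing idea in _tpu_ring_flash_attention_v1 is that the PartitionSpec entries are permuted together with the tensor dimensions, so after the (b, h, s, d) -> (b, s, h, d) transpose the sequence axis still lives on the 'fsdp' mesh axis, which ringattention then uses as its ring (axis_name). The following is a minimal standalone sketch of that permutation only, not the commit's code: the 8-device 1x4x2 mesh, the array sizes, and the BATCH->'data', HEAD->'tensor', LENGTH->'fsdp' mapping are assumptions for illustration (the mapping is what the commit's inline comments suggest).

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devs = np.array(jax.devices()).reshape(1, 4, 2)  # assumes 8 devices
mesh = Mesh(devs, ("data", "fsdp", "tensor"))

# (b, h, s, d): heads on 'tensor', sequence on 'fsdp'.
bhsd_spec = PartitionSpec("data", "tensor", "fsdp", None)
# (b, s, h, d): the same mesh axes, permuted like the dimensions (0, 2, 1, 3).
bshd_spec = PartitionSpec("data", "fsdp", "tensor", None)

q = jax.device_put(jnp.zeros((2, 8, 4096, 128), jnp.bfloat16),
                   NamedSharding(mesh, bhsd_spec))
# Transposing the array while permuting the spec keeps each dimension on
# its original mesh axis; the sequence stays sharded on 'fsdp', which is
# the axis ring attention rotates K/V blocks around.
q_t = jax.lax.with_sharding_constraint(
    jnp.transpose(q, (0, 2, 1, 3)), NamedSharding(mesh, bshd_spec))
assert q_t.shape == (2, 4096, 8, 128)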
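The unified-sequence-parallel variant, _tpu_ring_flash_attention, instead splits the fsdp devices into usp x ring sub-axes and uses a tiled all_to_all (Step A of its kernel) to trade head sharding for sequence sharding before ring attention runs, so only the 'ring' axis carries ring communication. Below is a minimal sketch of that swap in isolation, under assumptions not taken from the commit: 4 devices (e.g. CPU with XLA_FLAGS=--xla_force_host_platform_device_count=4), usp_degree=2, toy shapes, and an output spec whose head ordering ('tensor', 'usp') simply follows the split order of the all_to_all.

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec

usp_degree, ring_degree = 2, 2  # assumed; 4 devices total
devs = np.array(jax.devices()).reshape(1, usp_degree, ring_degree, 1)
mesh = Mesh(devs, ("data", "usp", "ring", "tensor"))

def swap(q):
  # Per shard: (b, h, s/(usp*ring), d) -> (b, h/usp, s/ring, d).
  # Local heads are split across 'usp' while sequence chunks are gathered
  # back, leaving the sequence sharded only on 'ring'.
  return jax.lax.all_to_all(q, "usp", split_axis=1, concat_axis=2, tiled=True)

in_spec = PartitionSpec("data", "tensor", ("usp", "ring"), None)
out_spec = PartitionSpec("data", ("tensor", "usp"), "ring", None)

q = jnp.zeros((2, 8, 1024, 64))
q_swapped = shard_map.shard_map(
    swap, mesh=mesh, in_specs=(in_spec,), out_specs=out_spec, check_rep=False
)(q)
assert q_swapped.shape == (2, 8, 1024, 64)  # global shape unchanged; only the layout moved

With usp_degree=1 the 'usp' axis has size 1 and the all_to_all is an identity, recovering the pure ring sharding of _tpu_ring_flash_attention_v1.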
