Commit f2f7d2f

wip - attempting to implement usp
1 parent 4587ff8 commit f2f7d2f

2 files changed

Lines changed: 229 additions & 41 deletions

src/maxdiffusion/models/attention_flax.py

Lines changed: 215 additions & 27 deletions
@@ -18,9 +18,10 @@
import flax.linen as nn
from flax import nnx
import jax
-from jax.sharding import PartitionSpec
+from jax.sharding import PartitionSpec, NamedSharding, Mesh as JaxMesh
import jax.numpy as jnp
-from jax.experimental import shard_map
+from jax import lax
+from jax.experimental.shard_map import shard_map
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
from einops import rearrange
@@ -43,6 +44,24 @@
EMBED = common_types.EMBED
Quant = quantizations.AqtQuantization

+# =========== START: USP Integration Code ===========
+
+# --- Algorithm 2: Load Balancing for SP-Ring ---
+def prepare_load_balance_indices(global_seq_len, ring_degree):
+  """Computes the permutation indices for load balancing in ring attention."""
+  if ring_degree == 1:
+    return jnp.arange(global_seq_len)
+  num_chunks = 2 * ring_degree
+  chunk_size = global_seq_len // num_chunks
+  if global_seq_len % num_chunks != 0:
+    raise ValueError(f"Sequence length {global_seq_len} must be divisible by 2 * ring_degree {2*ring_degree} for load balancing.")
+  chunks = jnp.arange(global_seq_len).reshape(num_chunks, chunk_size)
+  reordered_indices = []
+  for i in range(ring_degree):
+    reordered_indices.append(chunks[i])
+    reordered_indices.append(chunks[num_chunks - 1 - i])
+  return jnp.concatenate(reordered_indices).flatten()
+

def _maybe_aqt_einsum(quant: Quant):
  return jnp.einsum if quant is None else quant.einsum()
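
As a quick illustration of what the new prepare_load_balance_indices helper produces, the following standalone sketch (not part of the commit; the small sizes are made up for the example) shows how the permutation pairs chunk i with chunk 2*ring_degree - 1 - i so each ring rank gets one early and one late chunk, and how jnp.argsort recovers the original order, which is exactly the inverse applied after attention:

import jax.numpy as jnp

seq_len, ring_degree = 8, 2          # toy sizes; seq_len must be divisible by 2 * ring_degree
num_chunks = 2 * ring_degree
chunks = jnp.arange(seq_len).reshape(num_chunks, seq_len // num_chunks)
# Pair chunk i with chunk (num_chunks - 1 - i), mirroring prepare_load_balance_indices.
perm = jnp.concatenate([chunks[i] for p in range(ring_degree) for i in (p, num_chunks - 1 - p)])
print(perm)                          # [0 1 6 7 2 3 4 5]
inverse = jnp.argsort(perm)          # the same inverse permutation used on the attention output
print(perm[inverse])                 # [0 1 2 3 4 5 6 7]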
@@ -167,7 +186,6 @@ def _tpu_flash_attention(
      block_q_dq=min(max_block_size, query.shape[2]),
      block_kv_dq=min(max_block_size, query.shape[2]),
    )
-
  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q)
  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute)
  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute)
@@ -460,6 +478,156 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:

  return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)

+class NNXUSPAttentionOp(nnx.Module):
+  def __init__(
+    self,
+    mesh: Mesh,
+    flash_block_sizes,
+    heads,
+    dtype = jnp.bfloat16,
+  ):
+    self.ulysses_degree = mesh.shape['fsdp']
+    self.ring_degree = mesh.shape['tensor']
+    self.mesh = mesh
+    self.flash_block_sizes = flash_block_sizes
+    self.heads = heads
+    self.dtype = dtype
+
+  def apply_attention(self, query: Array, key: Array, value: Array):
+    flash_min_seq_length = 4096
+    #breakpoint()
+    can_use_flash_attention = (
+      query.shape[2] >= flash_min_seq_length
+      and key.shape[2] >= flash_min_seq_length
+      and value.shape[2] >= flash_min_seq_length
+    )
+
+    if not can_use_flash_attention:
+      return _apply_attention_dot(
+        query, key, value, jnp.bfloat16, 40, 128, 128**-0.5, True, False, False
+      )
+
+    num_heads_local_ulysses = self.heads // self.ulysses_degree
+    # The mask shape should correspond to the local sequence length on each ring device
+    # and the global sequence length after ring communication
+    max_block_size = 1024 if self.dtype == jnp.bfloat16 else 512
+    if self.flash_block_sizes:
+      block_sizes = self.flash_block_sizes
+    else:
+      block_sizes = splash_attention_kernel.BlockSizes(
+        block_q=min(max_block_size, query.shape[2]),
+        block_kv_compute=min(max_block_size, key.shape[2]),
+        block_kv=min(max_block_size, key.shape[2]),
+        block_q_dkv=min(max_block_size, query.shape[2]),
+        block_kv_dkv=min(max_block_size, key.shape[2]),
+        block_kv_dkv_compute=min(max_block_size, query.shape[2]),
+        block_q_dq=min(max_block_size, query.shape[2]),
+        block_kv_dq=min(max_block_size, query.shape[2]),
+      )
+
+    q_len_local_unpadded = query.shape[2]
+    block_q_size = block_sizes.block_q
+    # Calculate the padded length for the local query sequence.
+    # This ensures q_len_padded is a multiple of block_q_size.
+    q_len_padded = (q_len_local_unpadded + block_q_size - 1) // block_q_size * block_q_size
+
+    k_len_global_padded = q_len_padded * self.ring_degree
+
+    mask = splash_attention_mask.FullMask(_shape=(q_len_padded, k_len_global_padded))
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=[mask] * num_heads_local_ulysses)
+
+    query, kv_size, query_seq_len_original = _reshape_data_for_flash(query, self.heads, block_sizes.block_q)
+    key, _, _ = _reshape_data_for_flash(key, self.heads, block_sizes.block_kv_compute)
+    value, _, _ = _reshape_data_for_flash(value, self.heads, block_sizes.block_kv_compute)
+    #breakpoint()
+
+    splash_kernel = splash_attention_kernel.make_splash_mha(
+      mask=multi_head_mask,
+      head_shards=1,
+      q_seq_shards=1,
+      block_sizes=block_sizes
+    )
+
+    @functools.partial(shard_map,
+      mesh=self.mesh,
+      in_specs=(
+        PartitionSpec('data', None, ('fsdp', 'tensor'), None), # Q
+        PartitionSpec('data', None, ('fsdp', 'tensor'), None), # K
+        PartitionSpec('data', None, ('fsdp', 'tensor'), None), # V
+      ),
+      out_specs=PartitionSpec('data', None, ('fsdp', 'tensor'), None),
+      check_rep=False
+    )
+    def usp_attention(q, k, v):
+      """
+      Implements the Unified Sequence Parallelism attention following the standard order of operations.
+      fsdp -> ulysses axis, tensor -> ring axis.
+      """
+      # 1. Ulysses Forward: Swap sequence sharding for head sharding over the 'fsdp' axis.
+      # Input shape: [B, H, S_local, D], sharded on S (axis 2) over ('fsdp', 'tensor').
+      # We split axis 2 (Sequence) and concatenate axis 1 (Heads).
+      q_a2a = lax.all_to_all(q, 'fsdp', split_axis=2, concat_axis=1, tiled=True)
+      k_a2a = lax.all_to_all(k, 'fsdp', split_axis=2, concat_axis=1, tiled=True)
+      v_a2a = lax.all_to_all(v, 'fsdp', split_axis=2, concat_axis=1, tiled=True)
+      # Now, tensors are sharded on Heads (axis 1) over 'fsdp' and Sequence (axis 2) over 'tensor'.
+      # Shape is now [B, H * fsdp_degree, S_local / fsdp_degree, D].
+
+      # 2. Ring Attention: Gather the full K and V for each sequence chunk over the 'tensor' axis.
+      ring_axis_size = lax.psum(1, 'tensor')
+      k_ring, v_ring = k_a2a, v_a2a
+      all_k, all_v = [k_ring], [v_ring]
+      for _ in range(ring_axis_size - 1):
+        perm = [(j, (j - 1 + ring_axis_size) % ring_axis_size) for j in range(ring_axis_size)]
+        k_ring = lax.ppermute(k_ring, 'tensor', perm=perm)
+        v_ring = lax.ppermute(v_ring, 'tensor', perm=perm)
+        all_k.append(k_ring)
+        all_v.append(v_ring)
+
+      # Concatenate along the sequence axis (2) to create the full key/value for attention.
+      full_k_ring = jnp.concatenate(list(reversed(all_k)), axis=2)
+      full_v_ring = jnp.concatenate(list(reversed(all_v)), axis=2)
+
+      # 3. Local Attention Calculation
+      # The query (q_a2a) attends to the fully-gathered keys/values (full_k_ring).
+      attn_out_local = jax.vmap(splash_kernel)(q_a2a, full_k_ring, full_v_ring)
+      # The output shape is the same as the query q_a2a: [B, H * fsdp_degree, S_local / fsdp_degree, D].
+
+      # 4. Ulysses Backward: Swap back from head sharding to sequence sharding.
+      # This is the crucial step that reduces the head dimension.
+      # We split axis 1 (Heads) and concatenate axis 2 (Sequence).
+      attn_out_final = lax.all_to_all(attn_out_local, 'fsdp', split_axis=1, concat_axis=2, tiled=True)
+      # Final shape is [B, H, (S_local / fsdp_degree) * fsdp_degree, D] = [B, H, S_local, D].
+
+      return attn_out_final
+
+
+    # 1. Permute data for load balancing
+    global_seq_len = query.shape[2]
+    lb_permutation = prepare_load_balance_indices(global_seq_len, self.ring_degree)
+
+    permuted_q = query[:, :, lb_permutation, :]
+    permuted_k = key[:, :, lb_permutation, :]
+    permuted_v = value[:, :, lb_permutation, :]
+
+    # 2. Define sharding for USP input
+    # Input data is sharded across 'data' and 'fsdp' axes.
+    # The sequence dim (axis 2) is split for the 'fsdp' dimension.
+    # We assume the mesh is defined with ('data', 'fsdp', 'tensor') axes
+    # The tensor shape is [B, H, S, D], so we shard S (axis 2) on ('fsdp', 'tensor')
+    usp_input_sharding = NamedSharding(self.mesh, PartitionSpec('data', None, ('fsdp', 'tensor'), None))
+
+    distributed_q = jax.device_put(permuted_q, usp_input_sharding)
+    distributed_k = jax.device_put(permuted_k, usp_input_sharding)
+    distributed_v = jax.device_put(permuted_v, usp_input_sharding)
+
+    # 3. Call the USP attention function
+    attn_output = usp_attention(distributed_q, distributed_k, distributed_v)
+    inverse_lb_permutation = jnp.argsort(lb_permutation)
+    attn_output = attn_output[:, :, inverse_lb_permutation, :]
+    attn_output = attn_output[:, :, :query_seq_len_original, :kv_size]
+    # Reshape output back to [B, S, H*D]
+    attn_output = _reshape_heads_to_head_dim(attn_output)
+    return attn_output

class NNXAttentionOp(nnx.Module):

@@ -574,6 +742,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
    )


+
class FlaxWanAttention(nnx.Module):

  def __init__(
@@ -601,12 +770,15 @@ def __init__(
    precision: jax.lax.Precision = None,
    qkv_bias: bool = False,
    quant: Quant = None,
+    # USP parameters
+    ulysses_degree: int = 1,
+    ring_degree: int = 1,
  ):
    if attention_kernel == "cudnn_flash_te":
      raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")

    if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
-      raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
+      raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {mesh}")
    self.dim_head = dim_head
    self.heads = heads
    self.inner_dim = dim_head * heads
@@ -617,20 +789,28 @@ def __init__(
    self.value_axis_names = value_axis_names
    self.out_axis_names = out_axis_names

-    self.attention_op = NNXAttentionOp(
-      mesh=mesh,
-      attention_kernel=attention_kernel,
-      scale=scale,
-      heads=heads,
-      dim_head=dim_head,
-      use_memory_efficient_attention=use_memory_efficient_attention,
-      split_head_dim=split_head_dim,
-      float32_qk_product=False,
-      flash_min_seq_length=flash_min_seq_length,
-      flash_block_sizes=flash_block_sizes,
-      dtype=dtype,
-      quant=quant,
-    )
+    # Store USP parameters
+    ulysses_degree = mesh.shape['fsdp']
+    ring_degree = mesh.shape['tensor']
+    use_usp = ulysses_degree > 1 or ring_degree > 1
+    if use_usp:
+      self.attention_op = NNXUSPAttentionOp(mesh=mesh,heads=heads,flash_block_sizes=flash_block_sizes)
+    else:
+      # Fallback to original attention op if not using USP
+      self.attention_op = NNXAttentionOp(
+        mesh=mesh,
+        attention_kernel=attention_kernel,
+        scale=scale,
+        heads=heads,
+        dim_head=dim_head,
+        use_memory_efficient_attention=use_memory_efficient_attention,
+        split_head_dim=split_head_dim,
+        float32_qk_product=False,
+        flash_min_seq_length=flash_min_seq_length,
+        flash_block_sizes=flash_block_sizes,
+        dtype=dtype,
+        quant=quant,
+      )

    kernel_axes = ("embed", "heads")
    qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes)
@@ -714,8 +894,11 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup
  def __call__(
    self, hidden_states: jax.Array, encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None
  ) -> jax.Array:
+
    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
-    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+    if encoder_hidden_states is not None:
+      encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+
    dtype = hidden_states.dtype
    if encoder_hidden_states is None:
      encoder_hidden_states = hidden_states
@@ -727,19 +910,24 @@ def __call__(
    if self.qk_norm:
      query_proj = self.norm_q(query_proj)
      key_proj = self.norm_k(key_proj)
+
+    # All inputs are unflattened to [B, H, S, D]
+    query_proj = _unflatten_heads(query_proj, self.heads)
+    key_proj = _unflatten_heads(key_proj, self.heads)
+    value_proj = _unflatten_heads(value_proj, self.heads)
+
    if rotary_emb is not None:
-      query_proj = _unflatten_heads(query_proj, self.heads)
-      key_proj = _unflatten_heads(key_proj, self.heads)
-      value_proj = _unflatten_heads(value_proj, self.heads)
      query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
-      query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec("data", "tensor", None, None))
-      key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec("data", "tensor", None, None))
-      value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec("data", "tensor", None, None))
+    query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec("data", None, "fsdp", None))
+    key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec("data", None, "fsdp", None))
+    value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec("data", None, "fsdp", None))

    attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
-    attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec("data", None, None))
+    #breakpoint()
+    #attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec("data", None, "fsdp", None))

    attn_output = attn_output.astype(dtype=dtype)
+    #breakpoint()

    hidden_states = self.proj_attn(attn_output)
    return hidden_states
@@ -1391,4 +1579,4 @@ def setup(self):
  def __call__(self, hidden_states, deterministic=True):
    hidden_states = self.proj(hidden_states)
    hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
+    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
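
The all_to_all calls inside usp_attention are the easiest part of the new op to misread, so here is a standalone sketch of that step in isolation. It assumes two local devices (for example CPU devices forced with XLA_FLAGS=--xla_force_host_platform_device_count=2) and a toy one-axis 'fsdp' mesh; the shapes in the comments follow from those toy sizes, not from the commit:

import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]), axis_names=("fsdp",))
B, H, S, D = 1, 4, 8, 2
x = jnp.arange(B * H * S * D, dtype=jnp.float32).reshape(B, H, S, D)
x = jax.device_put(x, NamedSharding(mesh, P(None, None, "fsdp", None)))  # shard the sequence axis

@functools.partial(shard_map, mesh=mesh,
                   in_specs=P(None, None, "fsdp", None),
                   out_specs=P(None, None, "fsdp", None),
                   check_rep=False)
def ulysses_round_trip(block):
  # Each device's local block is [1, 4, 4, 2]: all heads, half the sequence.
  fwd = lax.all_to_all(block, "fsdp", split_axis=2, concat_axis=1, tiled=True)
  # fwd is [1, 8, 2, 2] locally: the head axis grows by the axis size while the
  # local sequence shrinks by the same factor, exactly the "Ulysses forward" swap.
  back = lax.all_to_all(fwd, "fsdp", split_axis=1, concat_axis=2, tiled=True)
  return back  # the mirrored call restores the original block

np.testing.assert_array_equal(ulysses_round_trip(x), x)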

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 14 additions & 14 deletions
@@ -65,11 +65,6 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl

# For some reason, jitting this function increases the memory significantly, so instead manually move weights to device.
def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
-
-  def create_model(rngs: nnx.Rngs, wan_config: dict):
-    wan_transformer = WanModel(**wan_config, rngs=rngs)
-    return wan_transformer
-
  # 1. Load config.
  wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")
  wan_config["mesh"] = mesh
7974
wan_config["precision"] = get_precision(config)
8075
wan_config["flash_block_sizes"] = get_flash_block_sizes(config)
8176

82-
# 2. eval_shape - will not use flops or create weights on device
83-
# thus not using HBM memory.
84-
p_model_factory = partial(create_model, wan_config=wan_config)
85-
wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs)
86-
graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...)
77+
# === START: MODIFIED CODE ===
78+
# 2. Instantiate the model on CPU to get its structure without using `eval_shape`.
79+
# This avoids all tracer leak issues by not performing a JAX transform during initialization.
80+
with jax.default_device(jax.devices('cpu')[0]):
81+
cpu_model = WanModel(**wan_config, rngs=rngs)
82+
83+
# 3. Split the CPU model to get the GraphDef and the State structure.
84+
graphdef, state, rest_of_state = nnx.split(cpu_model, nnx.Param, ...)
85+
86+
# Explicitly delete the CPU model to free up host memory.
87+
del cpu_model
88+
# === END: MODIFIED CODE ===
8789

88-
# 3. retrieve the state shardings, mapping logical names to mesh axis names.
90+
# 4. retrieve the state shardings, mapping logical names to mesh axis names.
8991
logical_state_spec = nnx.get_partition_spec(state)
9092
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
9193
logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
9294
params = state.to_pure_dict()
9395
state = dict(nnx.to_flat_state(state))
9496

95-
# 4. Load pretrained weights and move them to device using the state shardings from (3) above.
96-
# This helps with loading sharded weights directly into the accelerators without fist copying them
97-
# all to one device and then distributing them, thus using low HBM memory.
97+
# 5. Load pretrained weights and move them to device using the state shardings from (4) above.
9898
params = load_wan_transformer(config.wan_transformer_pretrained_model_name_or_path, params, "cpu")
9999
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
100100
for path, val in flax.traverse_util.flatten_dict(params).items():
