import flax.linen as nn
from flax import nnx
import jax
- from jax.sharding import PartitionSpec
+ from jax.sharding import PartitionSpec, NamedSharding, Mesh as JaxMesh
import jax.numpy as jnp
- from jax.experimental import shard_map
+ from jax import lax
+ from jax.experimental.shard_map import shard_map
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
from einops import rearrange
EMBED = common_types.EMBED
Quant = quantizations.AqtQuantization

+ # =========== START: USP Integration Code ===========
+
+ # --- Algorithm 2: Load Balancing for SP-Ring ---
+ def prepare_load_balance_indices(global_seq_len, ring_degree):
+   """Computes the permutation indices for load balancing in ring attention."""
+   if ring_degree == 1:
+     return jnp.arange(global_seq_len)
+   num_chunks = 2 * ring_degree
+   chunk_size = global_seq_len // num_chunks
+   if global_seq_len % num_chunks != 0:
+     raise ValueError(
+         f"Sequence length {global_seq_len} must be divisible by 2 * ring_degree {num_chunks} for load balancing."
+     )
+   chunks = jnp.arange(global_seq_len).reshape(num_chunks, chunk_size)
+   reordered_indices = []
+   for i in range(ring_degree):
+     reordered_indices.append(chunks[i])
+     reordered_indices.append(chunks[num_chunks - 1 - i])
+   return jnp.concatenate(reordered_indices).flatten()
+
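Editor's note (not part of the patch): a small worked example of the permutation. With a global sequence of 8 tokens and ring_degree=2, the sequence is cut into 2 * ring_degree = 4 chunks of 2 tokens, and each ring rank is assigned one chunk from the front and one from the back, the usual pairing used to balance per-rank work in ring attention.

# Editorial illustration, assuming the function above is in scope.
# chunks = [[0, 1], [2, 3], [4, 5], [6, 7]]
# rank 0 gets chunks 0 and 3, rank 1 gets chunks 1 and 2:
print(prepare_load_balance_indices(8, 2))  # [0 1 6 7 2 3 4 5]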

def _maybe_aqt_einsum(quant: Quant):
  return jnp.einsum if quant is None else quant.einsum()
@@ -167,7 +186,6 @@ def _tpu_flash_attention(
        block_q_dq=min(max_block_size, query.shape[2]),
        block_kv_dq=min(max_block_size, query.shape[2]),
    )
-
  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q)
  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute)
  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute)
@@ -460,6 +478,156 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:

  return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)

+ class NNXUSPAttentionOp(nnx.Module):
+   def __init__(
+       self,
+       mesh: Mesh,
+       flash_block_sizes,
+       heads,
+       dtype=jnp.bfloat16,
+   ):
+     # Ulysses parallelism runs over the 'fsdp' mesh axis, ring parallelism over 'tensor'.
+     self.ulysses_degree = mesh.shape['fsdp']
+     self.ring_degree = mesh.shape['tensor']
+     self.mesh = mesh
+     self.flash_block_sizes = flash_block_sizes
+     self.heads = heads
+     self.dtype = dtype
+   def apply_attention(self, query: Array, key: Array, value: Array):
+     flash_min_seq_length = 4096
+     can_use_flash_attention = (
+         query.shape[2] >= flash_min_seq_length
+         and key.shape[2] >= flash_min_seq_length
+         and value.shape[2] >= flash_min_seq_length
+     )
+
+     if not can_use_flash_attention:
+       # Fall back to plain dot-product attention for short sequences
+       # (note the hard-coded dtype, head count, and scale arguments).
+       return _apply_attention_dot(
+           query, key, value, jnp.bfloat16, 40, 128, 128 ** -0.5, True, False, False
+       )
+
+     num_heads_local_ulysses = self.heads // self.ulysses_degree
+     max_block_size = 1024 if self.dtype == jnp.bfloat16 else 512
+     if self.flash_block_sizes:
+       block_sizes = self.flash_block_sizes
+     else:
+       block_sizes = splash_attention_kernel.BlockSizes(
+           block_q=min(max_block_size, query.shape[2]),
+           block_kv_compute=min(max_block_size, key.shape[2]),
+           block_kv=min(max_block_size, key.shape[2]),
+           block_q_dkv=min(max_block_size, query.shape[2]),
+           block_kv_dkv=min(max_block_size, key.shape[2]),
+           block_kv_dkv_compute=min(max_block_size, query.shape[2]),
+           block_q_dq=min(max_block_size, query.shape[2]),
+           block_kv_dq=min(max_block_size, query.shape[2]),
+       )
+
+     q_len_local_unpadded = query.shape[2]
+     block_q_size = block_sizes.block_q
+     # Calculate the padded length for the local query sequence.
+     # This ensures q_len_padded is a multiple of block_q_size.
+     q_len_padded = (q_len_local_unpadded + block_q_size - 1) // block_q_size * block_q_size
+
+     k_len_global_padded = q_len_padded * self.ring_degree
+
+     # The mask shape corresponds to the local (padded) query length on each ring device
+     # and the global key length after ring communication.
+     mask = splash_attention_mask.FullMask(_shape=(q_len_padded, k_len_global_padded))
+     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=[mask] * num_heads_local_ulysses)
+
+     query, kv_size, query_seq_len_original = _reshape_data_for_flash(query, self.heads, block_sizes.block_q)
+     key, _, _ = _reshape_data_for_flash(key, self.heads, block_sizes.block_kv_compute)
+     value, _, _ = _reshape_data_for_flash(value, self.heads, block_sizes.block_kv_compute)
+
+     splash_kernel = splash_attention_kernel.make_splash_mha(
+         mask=multi_head_mask,
+         head_shards=1,
+         q_seq_shards=1,
+         block_sizes=block_sizes,
+     )
+     @functools.partial(
+         shard_map,
+         mesh=self.mesh,
+         in_specs=(
+             PartitionSpec('data', None, ('fsdp', 'tensor'), None),  # Q
+             PartitionSpec('data', None, ('fsdp', 'tensor'), None),  # K
+             PartitionSpec('data', None, ('fsdp', 'tensor'), None),  # V
+         ),
+         out_specs=PartitionSpec('data', None, ('fsdp', 'tensor'), None),
+         check_rep=False,
+     )
+     def usp_attention(q, k, v):
+       """
+       Implements Unified Sequence Parallelism (USP) attention in the standard order of operations.
+       fsdp -> Ulysses axis, tensor -> ring axis.
+       """
+       # 1. Ulysses forward: swap sequence sharding for head sharding over the 'fsdp' axis.
+       # Input shape: [B, H, S_local, D], sharded on S (axis 2) over ('fsdp', 'tensor').
+       # We split axis 2 (sequence) and concatenate axis 1 (heads).
+       q_a2a = lax.all_to_all(q, 'fsdp', split_axis=2, concat_axis=1, tiled=True)
+       k_a2a = lax.all_to_all(k, 'fsdp', split_axis=2, concat_axis=1, tiled=True)
+       v_a2a = lax.all_to_all(v, 'fsdp', split_axis=2, concat_axis=1, tiled=True)
+       # The tensors are now sharded on heads (axis 1) over 'fsdp' and on sequence (axis 2) over 'tensor'.
+       # The local shape is [B, H * fsdp_degree, S_local / fsdp_degree, D].
+
+       # 2. Ring attention: gather the full K and V for each sequence chunk over the 'tensor' axis.
+       ring_axis_size = lax.psum(1, 'tensor')
+       k_ring, v_ring = k_a2a, v_a2a
+       all_k, all_v = [k_ring], [v_ring]
+       for _ in range(ring_axis_size - 1):
+         perm = [(j, (j - 1 + ring_axis_size) % ring_axis_size) for j in range(ring_axis_size)]
+         k_ring = lax.ppermute(k_ring, 'tensor', perm=perm)
+         v_ring = lax.ppermute(v_ring, 'tensor', perm=perm)
+         all_k.append(k_ring)
+         all_v.append(v_ring)
+
+       # Concatenate along the sequence axis (2) to create the full key/value for attention.
+       full_k_ring = jnp.concatenate(list(reversed(all_k)), axis=2)
+       full_v_ring = jnp.concatenate(list(reversed(all_v)), axis=2)
+
+       # 3. Local attention calculation.
+       # The query (q_a2a) attends to the fully gathered keys/values (full_k_ring, full_v_ring).
+       attn_out_local = jax.vmap(splash_kernel)(q_a2a, full_k_ring, full_v_ring)
+       # The output shape matches the query q_a2a: [B, H * fsdp_degree, S_local / fsdp_degree, D].
+
+       # 4. Ulysses backward: swap back from head sharding to sequence sharding.
+       # This is the crucial step that restores the local head count to H.
+       # We split axis 1 (heads) and concatenate axis 2 (sequence).
+       attn_out_final = lax.all_to_all(attn_out_local, 'fsdp', split_axis=1, concat_axis=2, tiled=True)
+       # Final shape is [B, H, (S_local / fsdp_degree) * fsdp_degree, D] = [B, H, S_local, D].
+
+       return attn_out_final
+
+
+     # 1. Permute data for load balancing.
+     global_seq_len = query.shape[2]
+     lb_permutation = prepare_load_balance_indices(global_seq_len, self.ring_degree)
+
+     permuted_q = query[:, :, lb_permutation, :]
+     permuted_k = key[:, :, lb_permutation, :]
+     permuted_v = value[:, :, lb_permutation, :]
+
+     # 2. Define the sharding for the USP input.
+     # The batch dim is sharded on 'data'; the sequence dim (axis 2) is split over the
+     # 'fsdp' and 'tensor' axes. We assume the mesh is defined with ('data', 'fsdp', 'tensor') axes.
+     # The tensor shape is [B, H, S, D], so we shard S (axis 2) on ('fsdp', 'tensor').
+     usp_input_sharding = NamedSharding(self.mesh, PartitionSpec('data', None, ('fsdp', 'tensor'), None))
+
+     distributed_q = jax.device_put(permuted_q, usp_input_sharding)
+     distributed_k = jax.device_put(permuted_k, usp_input_sharding)
+     distributed_v = jax.device_put(permuted_v, usp_input_sharding)
+
+     # 3. Call the USP attention function and undo the load-balancing permutation.
+     attn_output = usp_attention(distributed_q, distributed_k, distributed_v)
+     inverse_lb_permutation = jnp.argsort(lb_permutation)
+     attn_output = attn_output[:, :, inverse_lb_permutation, :]
+     attn_output = attn_output[:, :, :query_seq_len_original, :kv_size]
+     # Reshape the output back to [B, S, H*D].
+     attn_output = _reshape_heads_to_head_dim(attn_output)
+     return attn_output
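Editor's note (not part of the patch): the two all_to_all calls in usp_attention form an exchange/restore pair. The minimal sketch below illustrates that round trip in isolation; it assumes a single hypothetical 'ulysses' mesh axis spanning all local devices rather than the ('data', 'fsdp', 'tensor') mesh the module expects, and shapes are chosen only so the splits divide evenly.

# Editorial sketch of the Ulysses all_to_all exchange, under the assumptions above.
import functools
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

devices = jax.devices()
mesh = Mesh(devices, ('ulysses',))
p = len(devices)

@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=P(None, None, 'ulysses', None),
    out_specs=P(None, None, 'ulysses', None),
    check_rep=False,
)
def ulysses_round_trip(x):
  # Forward: split the local sequence (axis 2), concatenate heads (axis 1).
  # Local shape goes from [B, H, S_local, D] to [B, H * p, S_local / p, D].
  y = lax.all_to_all(x, 'ulysses', split_axis=2, concat_axis=1, tiled=True)
  # Backward: split heads, concatenate sequence; this undoes the forward exchange.
  return lax.all_to_all(y, 'ulysses', split_axis=1, concat_axis=2, tiled=True)

x = jnp.arange(2 * 4 * (2 * p * p) * 3, dtype=jnp.float32).reshape(2, 4, 2 * p * p, 3)
assert jnp.allclose(ulysses_round_trip(x), x)  # the forward/backward pair is an identity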

class NNXAttentionOp(nnx.Module):

@@ -574,6 +742,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
    )


+
class FlaxWanAttention(nnx.Module):

  def __init__(
@@ -601,12 +770,15 @@ def __init__(
      precision: jax.lax.Precision = None,
      qkv_bias: bool = False,
      quant: Quant = None,
+     # USP parameters
+     ulysses_degree: int = 1,
+     ring_degree: int = 1,
  ):
    if attention_kernel == "cudnn_flash_te":
      raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")

    if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
-     raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
+     raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {mesh}")
    self.dim_head = dim_head
    self.heads = heads
    self.inner_dim = dim_head * heads
@@ -617,20 +789,28 @@ def __init__(
    self.value_axis_names = value_axis_names
    self.out_axis_names = out_axis_names

-   self.attention_op = NNXAttentionOp(
-       mesh=mesh,
-       attention_kernel=attention_kernel,
-       scale=scale,
-       heads=heads,
-       dim_head=dim_head,
-       use_memory_efficient_attention=use_memory_efficient_attention,
-       split_head_dim=split_head_dim,
-       float32_qk_product=False,
-       flash_min_seq_length=flash_min_seq_length,
-       flash_block_sizes=flash_block_sizes,
-       dtype=dtype,
-       quant=quant,
-   )
+   # USP setup: the degrees are taken from the mesh axis sizes, overriding the
+   # ulysses_degree / ring_degree constructor arguments.
+   ulysses_degree = mesh.shape['fsdp']
+   ring_degree = mesh.shape['tensor']
+   use_usp = ulysses_degree > 1 or ring_degree > 1
+   if use_usp:
+     self.attention_op = NNXUSPAttentionOp(mesh=mesh, heads=heads, flash_block_sizes=flash_block_sizes)
+   else:
+     # Fall back to the original attention op when USP is not enabled.
+     self.attention_op = NNXAttentionOp(
+         mesh=mesh,
+         attention_kernel=attention_kernel,
+         scale=scale,
+         heads=heads,
+         dim_head=dim_head,
+         use_memory_efficient_attention=use_memory_efficient_attention,
+         split_head_dim=split_head_dim,
+         float32_qk_product=False,
+         flash_min_seq_length=flash_min_seq_length,
+         flash_block_sizes=flash_block_sizes,
+         dtype=dtype,
+         quant=quant,
+     )

    kernel_axes = ("embed", "heads")
    qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes)
@@ -714,8 +894,11 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup
  def __call__(
      self, hidden_states: jax.Array, encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None
  ) -> jax.Array:
+
    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
-   encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+   if encoder_hidden_states is not None:
+     encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+
    dtype = hidden_states.dtype
    if encoder_hidden_states is None:
      encoder_hidden_states = hidden_states
@@ -727,19 +910,24 @@ def __call__(
    if self.qk_norm:
      query_proj = self.norm_q(query_proj)
      key_proj = self.norm_k(key_proj)
+
+   # All inputs are unflattened to [B, H, S, D].
+   query_proj = _unflatten_heads(query_proj, self.heads)
+   key_proj = _unflatten_heads(key_proj, self.heads)
+   value_proj = _unflatten_heads(value_proj, self.heads)
+
    if rotary_emb is not None:
-     query_proj = _unflatten_heads(query_proj, self.heads)
-     key_proj = _unflatten_heads(key_proj, self.heads)
-     value_proj = _unflatten_heads(value_proj, self.heads)
      query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
-   query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec("data", "tensor", None, None))
-   key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec("data", "tensor", None, None))
-   value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec("data", "tensor", None, None))
+   query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec("data", None, "fsdp", None))
+   key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec("data", None, "fsdp", None))
+   value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec("data", None, "fsdp", None))

    attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
-   attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec("data", None, None))

    attn_output = attn_output.astype(dtype=dtype)

    hidden_states = self.proj_attn(attn_output)
    return hidden_states
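Editor's note (not part of the patch): with the constructor above, the USP path is selected purely from the mesh axis sizes. The sketch below shows the kind of mesh that would enable it; the device count and axis split are illustrative assumptions, not values from the patch.

# Editorial illustration: a mesh with the axis names this module expects (assumes 8 devices).
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices()).reshape(1, 4, 2)
mesh = Mesh(devices, ('data', 'fsdp', 'tensor'))

# With fsdp=4 and tensor=2, FlaxWanAttention.__init__ derives ulysses_degree=4 and
# ring_degree=2 from the mesh, so use_usp is True and attention is routed through
# NNXUSPAttentionOp instead of NNXAttentionOp.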
@@ -1391,4 +1579,4 @@ def setup(self):
  def __call__(self, hidden_states, deterministic=True):
    hidden_states = self.proj(hidden_states)
    hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-   return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
+   return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)