@@ -383,6 +383,139 @@ def chunk_scanner(chunk_idx, _):
 
   return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
 
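+# Rotary position embedding (RoPE) helper, hoisted to module scope so that both
+# FlaxFluxAttention and the new FlaxWanAttention can share it. It views the last
+# feature dimension as pairs and applies the 2x2 rotation carried in the trailing
+# axes of freqs_cis.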
+def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
+  xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
+  xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
+
+  xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+  xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+
+  return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
+
+class FlaxWanAttention(nn.Module):
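+  """Attention layer for the Wan model with optional rotary position embeddings (RoPE)."""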
+  query_dim: int
+  heads: int = 8
+  dim_head: int = 64
+  dropout: float = 0.0
+  use_memory_efficient_attention: bool = False
+  split_head_dim: bool = False
+  attention_kernel: str = "dot_product"
+  flash_min_seq_length: int = 4096
+  flash_block_sizes: BlockSizes = None
+  mesh: jax.sharding.Mesh = None
+  dtype: jnp.dtype = jnp.float32
+  weights_dtype: jnp.dtype = jnp.float32
+  query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  out_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
+  precision: jax.lax.Precision = None
+  qkv_bias: bool = False
+
+  def setup(self):
+    if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None:
+      raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
+    inner_dim = self.dim_head * self.heads
+    scale = self.dim_head**-0.5
+
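+    # The shared AttentionOp helper dispatches to the selected kernel and
+    # applies the 1/sqrt(dim_head) scaling computed above.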
+    self.attention_op = AttentionOp(
+        mesh=self.mesh,
+        attention_kernel=self.attention_kernel,
+        scale=scale,
+        heads=self.heads,
+        dim_head=self.dim_head,
+        flash_min_seq_length=self.flash_min_seq_length,
+        use_memory_efficient_attention=self.use_memory_efficient_attention,
+        split_head_dim=self.split_head_dim,
+        flash_block_sizes=self.flash_block_sizes,
+        dtype=self.dtype,
+        float32_qk_product=False,
+    )
+
+    kernel_axes = ("embed", "heads")
+    qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes)
+
+    self.query = nn.Dense(
+        inner_dim,
+        kernel_init=qkv_init_kernel,
+        use_bias=self.qkv_bias,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="to_q",
+        precision=self.precision,
+    )
+
+    self.key = nn.Dense(
+        inner_dim,
+        kernel_init=qkv_init_kernel,
+        use_bias=self.qkv_bias,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="to_k",
+        precision=self.precision,
+    )
+
+    self.value = nn.Dense(
+        inner_dim,
+        kernel_init=qkv_init_kernel,
+        use_bias=self.qkv_bias,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="to_v",
+        precision=self.precision,
+    )
+
+    self.query_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+    self.key_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+
+    self.proj_attn = nn.Dense(
+        self.query_dim,
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("heads", "embed")),
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="to_out_0",
+        precision=self.precision,
+    )
+    self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+  def __call__(
+      self,
+      hidden_states: Array,
+      encoder_hidden_states: Optional[Array],
+      rotary_emb: Optional[Array],
+      deterministic: bool = True,
+  ):
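+    # With encoder_hidden_states=None this is self-attention; otherwise keys and
+    # values come from the encoder sequence (cross-attention).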
+    encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
+
+    query_proj = self.query(hidden_states)
+    key_proj = self.key(encoder_hidden_states)
+    value_proj = self.value(encoder_hidden_states)
+
+    query_proj = self.query_norm(query_proj)
+    key_proj = self.key_norm(key_proj)
+
+    if rotary_emb is not None:
+      query_proj, key_proj = apply_rope(query_proj, key_proj, rotary_emb)
+
+    query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
+    key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
+    value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
+
+    hidden_states = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+
+    hidden_states = self.proj_attn(hidden_states)
+    hidden_states = nn.with_logical_constraint(hidden_states, self.out_axis_names)
+    return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 class FlaxFluxAttention(nn.Module):
   query_dim: int
@@ -493,15 +626,6 @@ def setup(self):
         param_dtype=self.weights_dtype,
     )
 
-  def apply_rope(self, xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
-    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
-
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-
-    return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
-
   def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
 
     qkv_proj = self.qkv(hidden_states)
@@ -535,7 +659,7 @@ def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=Non
     value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
 
     image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)
-    query_proj, key_proj = self.apply_rope(query_proj, key_proj, image_rotary_emb)
+    query_proj, key_proj = apply_rope(query_proj, key_proj, image_rotary_emb)
 
     query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1)
     key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1)
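
For reference, a minimal smoke-test sketch of the new layer (not part of the diff). It assumes FlaxWanAttention is imported from this module and uses the default "dot_product" kernel, so no mesh or flash block sizes are needed; the batch/sequence/feature sizes and variable names below are illustrative only.

import jax
import jax.numpy as jnp

# Hypothetical example: self-attention, no rotary embeddings, default kernel.
attn = FlaxWanAttention(query_dim=512, heads=8, dim_head=64, attention_kernel="dot_product")
x = jnp.ones((1, 128, 512), dtype=jnp.float32)             # (batch, tokens, query_dim)
params = attn.init(jax.random.PRNGKey(0), x, None, None)   # encoder_hidden_states=None, rotary_emb=None
out = attn.apply(params, x, None, None)
assert out.shape == (1, 128, 512)                           # to_out_0 projects back to query_dim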