@@ -322,6 +322,177 @@ def chunk_scanner(chunk_idx, _):
 
   return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
 
+class FlaxFluxAttention(nn.Module):
+  query_dim: int
+  heads: int = 8
+  dim_head: int = 64
+  dropout: float = 0.0
+  use_memory_efficient_attention: bool = False
+  split_head_dim: bool = False
+  attention_kernel: str = "dot_product"
+  flash_min_seq_length: int = 4096
+  flash_block_sizes: BlockSizes = None
+  mesh: jax.sharding.Mesh = None
+  dtype: jnp.dtype = jnp.float32
+  weights_dtype: jnp.dtype = jnp.float32
+  query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  out_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
+  precision: jax.lax.Precision = None
+  qkv_bias: bool = False
+
+  def setup(self):
+    if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None:
+      raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
+    inner_dim = self.dim_head * self.heads
+    scale = self.dim_head**-0.5
+
+    self.attention_op = AttentionOp(
+        mesh=self.mesh,
+        attention_kernel=self.attention_kernel,
+        scale=scale,
+        heads=self.heads,
+        dim_head=self.dim_head,
+        flash_min_seq_length=self.flash_min_seq_length,
+        use_memory_efficient_attention=self.use_memory_efficient_attention,
+        split_head_dim=self.split_head_dim,
+        flash_block_sizes=self.flash_block_sizes,
+        dtype=self.dtype,
+        float32_qk_product=False,
+    )
+
+    kernel_axes = ("embed", "heads")
+    qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes)
+
+    self.qkv = nn.Dense(
+        inner_dim * 3,
+        kernel_init=qkv_init_kernel,
+        use_bias=self.qkv_bias,
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="i_qkv",
+        precision=self.precision,
+    )
+
+    self.encoder_qkv = nn.Dense(
+        inner_dim * 3,
+        kernel_init=qkv_init_kernel,
+        use_bias=self.qkv_bias,
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="e_qkv",
+        precision=self.precision,
+    )
+
+    self.proj_attn = nn.Dense(
+        self.query_dim,
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
+        use_bias=True,
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="i_proj",
+        precision=self.precision,
+    )
+
+    self.encoder_proj_attn = nn.Dense(
+        self.query_dim,
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
+        use_bias=True,
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="e_proj",
+        precision=self.precision,
+    )
+
+    self.query_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+    self.key_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+
+    self.encoder_query_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+    self.encoder_key_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+
+  def apply_rope(self, xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
+    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
+
+    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+
+    return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
+
+  def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
+
+    qkv_proj = self.qkv(hidden_states)
+    B, L = hidden_states.shape[:2]
+    H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3
+    qkv_proj = qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4)
+    query_proj, key_proj, value_proj = qkv_proj
+
+    query_proj = self.query_norm(query_proj)
+
+    key_proj = self.key_norm(key_proj)
+
+    if encoder_hidden_states is not None:
+
+      encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states)
+      B, L = encoder_hidden_states.shape[:2]
+      H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3
+      encoder_qkv_proj = encoder_qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4)
+      encoder_query_proj, encoder_key_proj, encoder_value_proj = encoder_qkv_proj
+
+      encoder_query_proj = self.encoder_query_norm(encoder_query_proj)
+
+      encoder_key_proj = self.encoder_key_norm(encoder_key_proj)
+
+      query_proj = jnp.concatenate((encoder_query_proj, query_proj), axis=2)
+      key_proj = jnp.concatenate((encoder_key_proj, key_proj), axis=2)
+      value_proj = jnp.concatenate((encoder_value_proj, value_proj), axis=2)
+
+    query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
+    key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
+    value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
+
+    image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)
+    query_proj, key_proj = self.apply_rope(query_proj, key_proj, image_rotary_emb)
+
+    query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1)
+    key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1)
+    value_proj = value_proj.transpose(0, 2, 1, 3).reshape(value_proj.shape[0], value_proj.shape[2], -1)
+
+    attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+    context_attn_output = None
+
+    if encoder_hidden_states is not None:
+      context_attn_output, attn_output = (
+          attn_output[:, : encoder_hidden_states.shape[1]],
+          attn_output[:, encoder_hidden_states.shape[1] :],
+      )
+
+      attn_output = self.proj_attn(attn_output)
+
+      context_attn_output = self.encoder_proj_attn(context_attn_output)
+
+    return attn_output, context_attn_output
 
 class FlaxFluxAttention(nn.Module):
   query_dim: int
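
For reference, a minimal standalone sketch of the rotary-embedding convention that apply_rope assumes (the helper name make_rope_freqs and the concrete shapes are illustrative, not part of this change). After the rearrange in __call__, image_rotary_emb carries one 2x2 rotation matrix per position and per feature pair, and apply_rope rotates the (even, odd) channel pairs of each query and key head with it:

import jax.numpy as jnp

def make_rope_freqs(seq_len: int, head_dim: int, theta: float = 10000.0) -> jnp.ndarray:
  # freqs_cis[n, d] is the 2x2 rotation matrix [[cos, -sin], [sin, cos]] for position n
  # and feature pair d, i.e. the "n d i j" layout produced by the rearrange in __call__.
  freqs = 1.0 / (theta ** (jnp.arange(0, head_dim, 2) / head_dim))   # (head_dim // 2,)
  angles = jnp.outer(jnp.arange(seq_len), freqs)                     # (seq_len, head_dim // 2)
  cos, sin = jnp.cos(angles), jnp.sin(angles)
  row0 = jnp.stack([cos, -sin], axis=-1)
  row1 = jnp.stack([sin, cos], axis=-1)
  return jnp.stack([row0, row1], axis=-2)                            # (seq_len, head_dim // 2, 2, 2)

batch, heads, seq_len, head_dim = 1, 8, 16, 64
q = jnp.ones((batch, heads, seq_len, head_dim))
freqs_cis = make_rope_freqs(seq_len, head_dim)

# Same arithmetic as apply_rope: view the last axis as (head_dim // 2, 1, 2) pairs,
# rotate every pair, then flatten back to (batch, heads, seq_len, head_dim).
q_pairs = q.reshape(*q.shape[:-1], -1, 1, 2)
q_rot = freqs_cis[..., 0] * q_pairs[..., 0] + freqs_cis[..., 1] * q_pairs[..., 1]
q_rot = q_rot.reshape(*q.shape)
assert q_rot.shape == q.shape

The fused QKV projection in __call__ is split into per-head query, key, and value tensors and the heads are merged back into the feature axis before apply_attention; a quick shape check of that reshape/transpose sequence (dimensions chosen arbitrarily):

B, L, H, D = 2, 16, 8, 64
fused = jnp.ones((B, L, 3 * H * D))                              # output of the fused qkv Dense
q, k, v = fused.reshape(B, L, 3, H, D).transpose(2, 0, 3, 1, 4)  # each (B, H, L, D)
q_merged = q.transpose(0, 2, 1, 3).reshape(B, L, -1)             # heads folded back before apply_attention
assert q.shape == (B, H, L, D) and q_merged.shape == (B, L, H * D)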