@@ -612,6 +612,8 @@ def __init__(
612612 qk_norm : str = "rms_norm_across_heads" ,
613613 flash_block_sizes : BlockSizes = None ,
614614 flash_min_seq_length : int = 4096 ,
615+ gated_attn : bool = False ,
616+ cross_attn_mod : bool = False ,
615617 ** kwargs ,
616618 ):
617619 self .in_channels = in_channels
@@ -658,6 +660,8 @@ def __init__(
658660 self .names_which_can_be_offloaded = names_which_can_be_offloaded
659661 self .scan_layers = scan_layers
660662 self .attention_kernel = attention_kernel
663+ self .gated_attn = gated_attn
664+ self .cross_attn_mod = cross_attn_mod
661665 self .a2v_attention_kernel = a2v_attention_kernel
662666 self .v2a_attention_kernel = v2a_attention_kernel
663667 self .flash_min_seq_length = flash_min_seq_length
@@ -845,6 +849,8 @@ def init_block(rngs):
845849 norm_elementwise_affine = self .norm_elementwise_affine ,
846850 norm_eps = self .norm_eps ,
847851 rope_type = self .rope_type ,
852+ gated_attn = self .gated_attn ,
853+ cross_attn_mod = self .cross_attn_mod ,
848854 dtype = self .dtype ,
849855 weights_dtype = self .weights_dtype ,
850856 mesh = self .mesh ,
0 commit comments