@@ -92,10 +92,12 @@ def __init__(
9292 precision : jax .lax .Precision = None ,
9393 names_which_can_be_saved : list = [],
9494 names_which_can_be_offloaded : list = [],
95+ attention_kernel : str = "flash" ,
9596 ):
9697 self .dim = dim
9798 self .norm_eps = norm_eps
9899 self .norm_elementwise_affine = norm_elementwise_affine
100+ self .attention_kernel = attention_kernel
99101
100102 # 1. Self-Attention (video and audio)
101103 self .norm1 = nnx .RMSNorm (self .dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype )
@@ -109,7 +111,8 @@ def __init__(
109111 out_bias = attention_out_bias ,
110112 eps = norm_eps ,
111113 dtype = dtype ,
112- mesh = mesh
114+ mesh = mesh ,
115+ attention_kernel = self .attention_kernel
113116 )
114117
115118 self .audio_norm1 = nnx .RMSNorm (audio_dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype )
@@ -123,7 +126,8 @@ def __init__(
123126 out_bias = attention_out_bias ,
124127 eps = norm_eps ,
125128 dtype = dtype ,
126- mesh = mesh
129+ mesh = mesh ,
130+ attention_kernel = self .attention_kernel
127131 )
128132
129133 # 2. Prompt Cross-Attention
@@ -139,7 +143,8 @@ def __init__(
139143 out_bias = attention_out_bias ,
140144 eps = norm_eps ,
141145 dtype = dtype ,
142- mesh = mesh
146+ mesh = mesh ,
147+ attention_kernel = self .attention_kernel
143148 )
144149
145150 self .audio_norm2 = nnx .RMSNorm (audio_dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype )
@@ -154,7 +159,8 @@ def __init__(
154159 out_bias = attention_out_bias ,
155160 eps = norm_eps ,
156161 dtype = dtype ,
157- mesh = mesh
162+ mesh = mesh ,
163+ attention_kernel = self .attention_kernel
158164 )
159165
160166 # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -170,7 +176,8 @@ def __init__(
170176 out_bias = attention_out_bias ,
171177 eps = norm_eps ,
172178 dtype = dtype ,
173- mesh = mesh
179+ mesh = mesh ,
180+ attention_kernel = self .attention_kernel
174181 )
175182
176183 self .video_to_audio_norm = nnx .RMSNorm (audio_dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype )
@@ -185,7 +192,8 @@ def __init__(
185192 out_bias = attention_out_bias ,
186193 eps = norm_eps ,
187194 dtype = dtype ,
188- mesh = mesh
195+ mesh = mesh ,
196+ attention_kernel = self .attention_kernel
189197 )
190198
191199 # 4. Feed Forward
@@ -523,6 +531,7 @@ def __init__(
523531 names_which_can_be_saved : list = [],
524532 names_which_can_be_offloaded : list = [],
525533 scan_layers : bool = True ,
534+ attention_kernel : str = "flash" ,
526535 ):
527536 self .in_channels = in_channels
528537 self .out_channels = out_channels
@@ -568,6 +577,7 @@ def __init__(
568577 self .names_which_can_be_saved = names_which_can_be_saved
569578 self .names_which_can_be_offloaded = names_which_can_be_offloaded
570579 self .scan_layers = scan_layers
580+ self .attention_kernel = attention_kernel
571581
572582 _out_channels = self .out_channels or self .in_channels
573583 _audio_out_channels = self .audio_out_channels or self .audio_in_channels
@@ -723,6 +733,7 @@ def init_block(rngs):
723733 precision = self .precision ,
724734 names_which_can_be_saved = self .names_which_can_be_saved ,
725735 names_which_can_be_offloaded = self .names_which_can_be_offloaded ,
736+ attention_kernel = self .attention_kernel ,
726737 )
727738
728739 if self .scan_layers :
@@ -754,6 +765,7 @@ def init_block(rngs):
754765 precision = self .precision ,
755766 names_which_can_be_saved = self .names_which_can_be_saved ,
756767 names_which_can_be_offloaded = self .names_which_can_be_offloaded ,
768+ attention_kernel = self .attention_kernel ,
757769 )
758770 blocks .append (block )
759771 self .transformer_blocks = nnx .List (blocks )
0 commit comments