@@ -108,7 +108,8 @@ def __init__(
108108 bias = attention_bias ,
109109 out_bias = attention_out_bias ,
110110 eps = norm_eps ,
111- dtype = dtype
111+ dtype = dtype ,
112+ mesh = mesh
112113 )
113114
114115 self .audio_norm1 = nnx .RMSNorm (audio_dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype )
@@ -121,7 +122,8 @@ def __init__(
121122 bias = attention_bias ,
122123 out_bias = attention_out_bias ,
123124 eps = norm_eps ,
124- dtype = dtype
125+ dtype = dtype ,
126+ mesh = mesh
125127 )
126128
127129 # 2. Prompt Cross-Attention
@@ -136,7 +138,8 @@ def __init__(
136138 bias = attention_bias ,
137139 out_bias = attention_out_bias ,
138140 eps = norm_eps ,
139- dtype = dtype
141+ dtype = dtype ,
142+ mesh = mesh
140143 )
141144
142145 self .audio_norm2 = nnx .RMSNorm (audio_dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype )
@@ -150,7 +153,8 @@ def __init__(
150153 bias = attention_bias ,
151154 out_bias = attention_out_bias ,
152155 eps = norm_eps ,
153- dtype = dtype
156+ dtype = dtype ,
157+ mesh = mesh
154158 )
155159
156160 # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -165,7 +169,8 @@ def __init__(
165169 bias = attention_bias ,
166170 out_bias = attention_out_bias ,
167171 eps = norm_eps ,
168- dtype = dtype
172+ dtype = dtype ,
173+ mesh = mesh
169174 )
170175
171176 self .video_to_audio_norm = nnx .RMSNorm (audio_dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype )
@@ -179,7 +184,8 @@ def __init__(
179184 bias = attention_bias ,
180185 out_bias = attention_out_bias ,
181186 eps = norm_eps ,
182- dtype = dtype
187+ dtype = dtype ,
188+ mesh = mesh
183189 )
184190
185191 # 4. Feed Forward
0 commit comments