@@ -107,11 +107,8 @@ def __init__(
107107 dropout = 0.0 ,
108108 bias = attention_bias ,
109109 out_bias = attention_out_bias ,
110- qk_norm = qk_norm ,
111- norm_eps = norm_eps ,
112- rope_type = rope_type ,
113- dtype = dtype ,
114- param_dtype = weights_dtype
110+ eps = norm_eps ,
111+ dtype = dtype
115112 )
116113
117114 self .audio_norm1 = nnx .RMSNorm (audio_dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype )
@@ -123,81 +120,66 @@ def __init__(
123120 dropout = 0.0 ,
124121 bias = attention_bias ,
125122 out_bias = attention_out_bias ,
126- qk_norm = qk_norm ,
127- norm_eps = norm_eps ,
128- rope_type = rope_type ,
129- dtype = dtype ,
130- param_dtype = weights_dtype
123+ eps = norm_eps ,
124+ dtype = dtype
131125 )
132126
133127 # 2. Prompt Cross-Attention
134128 self .norm2 = nnx .RMSNorm (self .dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype )
135129 self .attn2 = LTX2Attention (
136130 rngs = rngs ,
137131 query_dim = dim ,
138- cross_attention_dim = cross_attention_dim ,
132+ context_dim = cross_attention_dim ,
139133 heads = num_attention_heads ,
140134 dim_head = attention_head_dim ,
141135 dropout = 0.0 ,
142136 bias = attention_bias ,
143137 out_bias = attention_out_bias ,
144- qk_norm = qk_norm ,
145- norm_eps = norm_eps ,
146- rope_type = rope_type ,
147- dtype = dtype ,
148- param_dtype = weights_dtype
138+ eps = norm_eps ,
139+ dtype = dtype
149140 )
150141
151142 self .audio_norm2 = nnx .RMSNorm (audio_dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype )
152143 self .audio_attn2 = LTX2Attention (
153144 rngs = rngs ,
154145 query_dim = audio_dim ,
155- cross_attention_dim = audio_cross_attention_dim ,
146+ context_dim = audio_cross_attention_dim ,
156147 heads = audio_num_attention_heads ,
157148 dim_head = audio_attention_head_dim ,
158149 dropout = 0.0 ,
159150 bias = attention_bias ,
160151 out_bias = attention_out_bias ,
161- qk_norm = qk_norm ,
162- norm_eps = norm_eps ,
163- rope_type = rope_type ,
164- dtype = dtype ,
165- param_dtype = weights_dtype
152+ eps = norm_eps ,
153+ dtype = dtype
166154 )
167155
168156 # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
169157 self .audio_to_video_norm = nnx .RMSNorm (dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype )
170158 self .audio_to_video_attn = LTX2Attention (
171159 rngs = rngs ,
172160 query_dim = dim ,
173- cross_attention_dim = audio_dim ,
161+ context_dim = audio_dim ,
174162 heads = audio_num_attention_heads ,
175163 dim_head = audio_attention_head_dim ,
176164 dropout = 0.0 ,
177165 bias = attention_bias ,
178166 out_bias = attention_out_bias ,
179- qk_norm = qk_norm ,
180- norm_eps = norm_eps ,
181- rope_type = rope_type ,
182- dtype = dtype ,
183- param_dtype = weights_dtype
167+ eps = norm_eps ,
168+ dtype = dtype
184169 )
185170
186171 self .video_to_audio_norm = nnx .RMSNorm (audio_dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype )
187172 self .video_to_audio_attn = LTX2Attention (
188173 rngs = rngs ,
189174 query_dim = audio_dim ,
190- cross_attention_dim = dim ,
175+ context_dim = dim ,
191176 heads = audio_num_attention_heads ,
192177 dim_head = audio_attention_head_dim ,
193178 dropout = 0.0 ,
194179 bias = attention_bias ,
195180 out_bias = attention_out_bias ,
196- qk_norm = qk_norm ,
197- norm_eps = norm_eps ,
198- rope_type = rope_type ,
199- dtype = dtype ,
200- param_dtype = weights_dtype
181+ eps = norm_eps ,
182+ dtype = dtype
201183 )
202184
203185 # 4. Feed Forward
0 commit comments