File tree Expand file tree Collapse file tree
src/maxdiffusion/models/ltx2 Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -159,7 +159,7 @@ def __init__(
159159 eps = norm_eps ,
160160 dtype = dtype ,
161161 mesh = mesh ,
162- attention_kernel = "dot_product" ,
162+ attention_kernel = self . attention_kernel ,
163163 rope_type = rope_type ,
164164 flash_block_sizes = flash_block_sizes ,
165165 )
@@ -212,7 +212,7 @@ def __init__(
212212 eps = norm_eps ,
213213 dtype = dtype ,
214214 mesh = mesh ,
215- attention_kernel = "dot_product" ,
215+ attention_kernel = self . attention_kernel ,
216216 rope_type = rope_type ,
217217 flash_block_sizes = flash_block_sizes ,
218218 )
@@ -239,7 +239,7 @@ def __init__(
239239 eps = norm_eps ,
240240 dtype = dtype ,
241241 mesh = mesh ,
242- attention_kernel = "dot_product" ,
242+ attention_kernel = self . attention_kernel ,
243243 rope_type = rope_type ,
244244 flash_block_sizes = flash_block_sizes ,
245245 )
@@ -265,7 +265,7 @@ def __init__(
265265 eps = norm_eps ,
266266 dtype = dtype ,
267267 mesh = mesh ,
268- attention_kernel = "dot_product" ,
268+ attention_kernel = self . attention_kernel ,
269269 rope_type = rope_type ,
270270 flash_block_sizes = flash_block_sizes ,
271271 )
You can’t perform that action at this time.
0 commit comments