Skip to content

Commit db0c16f

Browse files
committed
added rope_type param to attention calls
1 parent df8a5fc commit db0c16f

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def __init__(
116116
dtype=dtype,
117117
mesh=mesh,
118118
attention_kernel=self.attention_kernel,
119+
rope_type=rope_type,
119120
)
120121

121122
self.audio_norm1 = nnx.RMSNorm(
@@ -138,6 +139,7 @@ def __init__(
138139
dtype=dtype,
139140
mesh=mesh,
140141
attention_kernel=self.attention_kernel,
142+
rope_type=rope_type,
141143
)
142144

143145
# 2. Prompt Cross-Attention
@@ -162,6 +164,7 @@ def __init__(
162164
dtype=dtype,
163165
mesh=mesh,
164166
attention_kernel=self.attention_kernel,
167+
rope_type=rope_type,
165168
)
166169

167170
self.audio_norm2 = nnx.RMSNorm(
@@ -185,6 +188,7 @@ def __init__(
185188
dtype=dtype,
186189
mesh=mesh,
187190
attention_kernel=self.attention_kernel,
191+
rope_type=rope_type,
188192
)
189193

190194
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -204,6 +208,7 @@ def __init__(
204208
dtype=dtype,
205209
mesh=mesh,
206210
attention_kernel=self.attention_kernel,
211+
rope_type=rope_type,
207212
)
208213

209214
self.video_to_audio_norm = nnx.RMSNorm(
@@ -227,6 +232,7 @@ def __init__(
227232
dtype=dtype,
228233
mesh=mesh,
229234
attention_kernel=self.attention_kernel,
235+
rope_type=rope_type,
230236
)
231237

232238
# 4. Feed Forward

0 commit comments

Comments
 (0)