Skip to content

Commit 0dd14bb

Browse files
committed
Parameter names changed in transformer file
1 parent 3e8589f commit 0dd14bb

1 file changed

Lines changed: 16 additions & 34 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,8 @@ def __init__(
107107
dropout=0.0,
108108
bias=attention_bias,
109109
out_bias=attention_out_bias,
110-
qk_norm=qk_norm,
111-
norm_eps=norm_eps,
112-
rope_type=rope_type,
113-
dtype=dtype,
114-
param_dtype=weights_dtype
110+
eps=norm_eps,
111+
dtype=dtype
115112
)
116113

117114
self.audio_norm1 = nnx.RMSNorm(audio_dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype)
@@ -123,81 +120,66 @@ def __init__(
123120
dropout=0.0,
124121
bias=attention_bias,
125122
out_bias=attention_out_bias,
126-
qk_norm=qk_norm,
127-
norm_eps=norm_eps,
128-
rope_type=rope_type,
129-
dtype=dtype,
130-
param_dtype=weights_dtype
123+
eps=norm_eps,
124+
dtype=dtype
131125
)
132126

133127
# 2. Prompt Cross-Attention
134128
self.norm2 = nnx.RMSNorm(self.dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype)
135129
self.attn2 = LTX2Attention(
136130
rngs=rngs,
137131
query_dim=dim,
138-
cross_attention_dim=cross_attention_dim,
132+
context_dim=cross_attention_dim,
139133
heads=num_attention_heads,
140134
dim_head=attention_head_dim,
141135
dropout=0.0,
142136
bias=attention_bias,
143137
out_bias=attention_out_bias,
144-
qk_norm=qk_norm,
145-
norm_eps=norm_eps,
146-
rope_type=rope_type,
147-
dtype=dtype,
148-
param_dtype=weights_dtype
138+
eps=norm_eps,
139+
dtype=dtype
149140
)
150141

151142
self.audio_norm2 = nnx.RMSNorm(audio_dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype)
152143
self.audio_attn2 = LTX2Attention(
153144
rngs=rngs,
154145
query_dim=audio_dim,
155-
cross_attention_dim=audio_cross_attention_dim,
146+
context_dim=audio_cross_attention_dim,
156147
heads=audio_num_attention_heads,
157148
dim_head=audio_attention_head_dim,
158149
dropout=0.0,
159150
bias=attention_bias,
160151
out_bias=attention_out_bias,
161-
qk_norm=qk_norm,
162-
norm_eps=norm_eps,
163-
rope_type=rope_type,
164-
dtype=dtype,
165-
param_dtype=weights_dtype
152+
eps=norm_eps,
153+
dtype=dtype
166154
)
167155

168156
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
169157
self.audio_to_video_norm = nnx.RMSNorm(dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype)
170158
self.audio_to_video_attn = LTX2Attention(
171159
rngs=rngs,
172160
query_dim=dim,
173-
cross_attention_dim=audio_dim,
161+
context_dim=audio_dim,
174162
heads=audio_num_attention_heads,
175163
dim_head=audio_attention_head_dim,
176164
dropout=0.0,
177165
bias=attention_bias,
178166
out_bias=attention_out_bias,
179-
qk_norm=qk_norm,
180-
norm_eps=norm_eps,
181-
rope_type=rope_type,
182-
dtype=dtype,
183-
param_dtype=weights_dtype
167+
eps=norm_eps,
168+
dtype=dtype
184169
)
185170

186171
self.video_to_audio_norm = nnx.RMSNorm(audio_dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype)
187172
self.video_to_audio_attn = LTX2Attention(
188173
rngs=rngs,
189174
query_dim=audio_dim,
190-
cross_attention_dim=dim,
175+
context_dim=dim,
191176
heads=audio_num_attention_heads,
192177
dim_head=audio_attention_head_dim,
193178
dropout=0.0,
194179
bias=attention_bias,
195180
out_bias=attention_out_bias,
196-
qk_norm=qk_norm,
197-
norm_eps=norm_eps,
198-
rope_type=rope_type,
199-
dtype=dtype,
200-
param_dtype=weights_dtype
181+
eps=norm_eps,
182+
dtype=dtype
201183
)
202184

203185
# 4. Feed Forward

0 commit comments

Comments
 (0)