Skip to content

Commit b6fe087

Browse files
committed
transformer and vocoder keys
1 parent 9a66fbe commit b6fe087

2 files changed

Lines changed: 20 additions & 6 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,19 @@ def rename_for_ltx2_transformer(key):
9393
if "to_out_0" in key:
9494
key = key.replace("to_out_0", "to_out")
9595

96+
# Add missing mappings
97+
key = key.replace("av_ca_video_scale_shift_adaln_single", "av_cross_attn_video_scale_shift")
98+
key = key.replace("av_ca_a2v_gate_adaln_single", "av_cross_attn_video_a2v_gate")
99+
key = key.replace("av_ca_audio_scale_shift_adaln_single", "av_cross_attn_audio_scale_shift")
100+
key = key.replace("av_ca_v2a_gate_adaln_single", "av_cross_attn_audio_v2a_gate")
101+
key = key.replace("scale_shift_table_a2v_ca_video", "video_a2v_cross_attn_scale_shift_table")
102+
key = key.replace("scale_shift_table_a2v_ca_audio", "audio_a2v_cross_attn_scale_shift_table")
103+
96104
# LTX-2.3 specific mappings
105+
# Handle substrings before they are replaced by shorter patterns below
106+
key = key.replace("audio_prompt_adaln_single", "audio_prompt_adaln")
107+
key = key.replace("prompt_adaln_single", "prompt_adaln")
108+
97109
if "prompt_adaln" in key:
98110
key = key.replace("prompt_adaln", "caption_projection")
99111
if "audio_prompt_adaln" in key:
@@ -337,6 +349,8 @@ def rename_for_ltx2_vocoder(key):
337349
key = key.replace("ups.", "upsamplers.")
338350
key = key.replace("resblocks", "resnets")
339351
key = key.replace("conv_post", "conv_out")
352+
key = key.replace("conv_pre", "conv_in")
353+
key = key.replace("act_post", "act_out")
340354

341355
# LTX-2.3 specific mappings for Vocoder
342356
if "downsample" in key and "lowpass" not in key:

src/maxdiffusion/models/ltx2/vocoder_bwe_ltx2.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -384,9 +384,9 @@ def __init__(
384384
dtype=self.dtype,
385385
)
386386

387-
self.ups = nnx.List()
387+
self.upsamplers = nnx.List()
388388
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
389-
self.ups.append(
389+
self.upsamplers.append(
390390
nnx.ConvTranspose(
391391
in_features=upsample_initial_channel // (2**i),
392392
out_features=upsample_initial_channel // (2 ** (i + 1)),
@@ -398,11 +398,11 @@ def __init__(
398398
)
399399
)
400400

401-
self.resblocks = nnx.List()
401+
self.resnets = nnx.List()
402402
for i in range(len(upsample_rates)):
403403
ch = upsample_initial_channel // (2 ** (i + 1))
404404
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes):
405-
self.resblocks.append(AMPBlock1(ch, kernel_size, dilations, activation=activation, rngs=rngs))
405+
self.resnets.append(AMPBlock1(ch, kernel_size, dilations, activation=activation, rngs=rngs))
406406

407407
final_channels = upsample_initial_channel // (2 ** len(upsample_rates))
408408
self.act_out = Activation1d(final_channels, SnakeBeta(final_channels, rngs=rngs), rngs=rngs)
@@ -429,14 +429,14 @@ def __call__(self, hidden_states: Array, time_last: bool = False) -> Array:
429429
hidden_states = self.conv_in(hidden_states)
430430

431431
for i in range(self.num_upsamples):
432-
hidden_states = self.ups[i](hidden_states)
432+
hidden_states = self.upsamplers[i](hidden_states)
433433

434434
start = i * self.num_kernels
435435
end = (i + 1) * self.num_kernels
436436

437437
res_sum = 0.0
438438
for j in range(start, end):
439-
res_sum = res_sum + self.resblocks[j](hidden_states)
439+
res_sum = res_sum + self.resnets[j](hidden_states)
440440

441441
hidden_states = res_sum / self.num_kernels
442442

0 commit comments

Comments
 (0)