
Commit 3aed945

Add LTX2 Vocoder

1 parent cddbf6a · commit 3aed945
2 files changed · 609 additions & 0 deletions

New file · 237 additions & 0 deletions
@@ -0,0 +1,237 @@
"""
Copyright 2026 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import math
from typing import Sequence

import jax
import jax.numpy as jnp
from flax import nnx

from ... import common_types
from maxdiffusion.configuration_utils import ConfigMixin, register_to_config
from maxdiffusion.models.modeling_flax_utils import FlaxModelMixin

Array = common_types.Array
DType = common_types.DType


class ResBlock(nnx.Module):
  """
  Residual block for the LTX-2 vocoder.
  """

  def __init__(
      self,
      channels: int,
      kernel_size: int = 3,
      stride: int = 1,
      dilations: Sequence[int] = (1, 3, 5),
      leaky_relu_negative_slope: float = 0.1,
      *,
      rngs: nnx.Rngs,
      dtype: DType = jnp.float32,
  ):
    self.dilations = dilations
    self.negative_slope = leaky_relu_negative_slope

    self.convs1 = nnx.List(
        [
            nnx.Conv(
                in_features=channels,
                out_features=channels,
                kernel_size=(kernel_size,),
                strides=(stride,),
                kernel_dilation=(dilation,),
                padding="SAME",
                rngs=rngs,
                dtype=dtype,
            )
            for dilation in dilations
        ]
    )

    self.convs2 = nnx.List(
        [
            nnx.Conv(
                in_features=channels,
                out_features=channels,
                kernel_size=(kernel_size,),
                strides=(stride,),
                kernel_dilation=(1,),
                padding="SAME",
                rngs=rngs,
                dtype=dtype,
            )
            for _ in range(len(dilations))
        ]
    )

  def __call__(self, x: Array) -> Array:
    for conv1, conv2 in zip(self.convs1, self.convs2):
      xt = jax.nn.leaky_relu(x, negative_slope=self.negative_slope)
      xt = conv1(xt)
      xt = jax.nn.leaky_relu(xt, negative_slope=self.negative_slope)
      xt = conv2(xt)
      x = x + xt
    return x

class LTX2Vocoder(nnx.Module, FlaxModelMixin, ConfigMixin):
  """
  LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
  """

  @register_to_config
  def __init__(
      self,
      in_channels: int = 128,
      hidden_channels: int = 1024,
      out_channels: int = 2,
      upsample_kernel_sizes: Sequence[int] = (16, 15, 8, 4, 4),
      upsample_factors: Sequence[int] = (6, 5, 2, 2, 2),
      resnet_kernel_sizes: Sequence[int] = (3, 7, 11),
      resnet_dilations: Sequence[Sequence[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
      leaky_relu_negative_slope: float = 0.1,
      # output_sampling_rate is unused in the model structure but kept for config compatibility.
      output_sampling_rate: int = 24000,
      *,
      rngs: nnx.Rngs,
      dtype: DType = jnp.float32,
  ):
    self.num_upsample_layers = len(upsample_kernel_sizes)
    self.resnets_per_upsample = len(resnet_kernel_sizes)
    self.out_channels = out_channels
    self.total_upsample_factor = math.prod(upsample_factors)
    self.negative_slope = leaky_relu_negative_slope
    self.dtype = dtype

    if self.num_upsample_layers != len(upsample_factors):
      raise ValueError(
          "`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
          f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
      )

    if self.resnets_per_upsample != len(resnet_dilations):
      raise ValueError(
          "`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
          f" {self.resnets_per_upsample} and {len(resnet_dilations)}, respectively."
      )

    # PyTorch Conv1d expects (Batch, Channels, Length); we use (Batch, Length, Channels).
    # The in_channels/out_channels arguments are therefore standard, but the data layout
    # is transposed in __call__.
    self.conv_in = nnx.Conv(
        in_features=in_channels,
        out_features=hidden_channels,
        kernel_size=(7,),
        strides=(1,),
        padding="SAME",
        rngs=rngs,
        dtype=self.dtype,
    )

    self.upsamplers = nnx.List()
    self.resnets = nnx.List()
    input_channels = hidden_channels

    for stride, kernel_size in zip(upsample_factors, upsample_kernel_sizes):
      output_channels = input_channels // 2

      # ConvTranspose with padding="SAME" matches PyTorch's specific padding logic
      # for these standard HiFi-GAN upsampling configurations.
      self.upsamplers.append(
          nnx.ConvTranspose(
              in_features=input_channels,
              out_features=output_channels,
              kernel_size=(kernel_size,),
              strides=(stride,),
              padding="SAME",
              rngs=rngs,
              dtype=self.dtype,
          )
      )

      for res_kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
        self.resnets.append(
            ResBlock(
                channels=output_channels,
                kernel_size=res_kernel_size,
                dilations=dilations,
                leaky_relu_negative_slope=leaky_relu_negative_slope,
                rngs=rngs,
                dtype=self.dtype,
            )
        )
      input_channels = output_channels

    self.conv_out = nnx.Conv(
        in_features=input_channels,
        out_features=out_channels,
        kernel_size=(7,),
        strides=(1,),
        padding="SAME",
        rngs=rngs,
        dtype=self.dtype,
    )

  def __call__(self, hidden_states: Array, time_last: bool = False) -> Array:
    """
    Forward pass of the vocoder.

    Args:
      hidden_states: Input mel spectrogram tensor of shape `(B, C, T, F)`,
        or `(B, C, F, T)` when `time_last` is True.
      time_last: Legacy flag indicating that the input layout already has time
        as the last axis.

    Returns:
      Audio waveform of shape `(B, OutChannels, AudioLength)`.
    """
    # Ensure layout: (Batch, Channels, MelBins, Time).
    if not time_last:
      hidden_states = jnp.transpose(hidden_states, (0, 1, 3, 2))

    # Flatten Channels and MelBins -> (Batch, Features, Time).
    batch, channels, mel_bins, time = hidden_states.shape
    hidden_states = hidden_states.reshape(batch, channels * mel_bins, time)

    # Transpose to (Batch, Time, Features) for Flax NWC convolutions.
    hidden_states = jnp.transpose(hidden_states, (0, 2, 1))

    hidden_states = self.conv_in(hidden_states)

    for i in range(self.num_upsample_layers):
      hidden_states = jax.nn.leaky_relu(hidden_states, negative_slope=self.negative_slope)
      hidden_states = self.upsamplers[i](hidden_states)

      # Accumulate ResNet outputs (memory optimization).
      start = i * self.resnets_per_upsample
      end = (i + 1) * self.resnets_per_upsample

      res_sum = 0.0
      for j in range(start, end):
        res_sum = res_sum + self.resnets[j](hidden_states)

      # Average the outputs (matches PyTorch mean(stack)).
      hidden_states = res_sum / self.resnets_per_upsample

    # Final post-processing.
    # Note: a 0.01 slope is used here specifically (matches a Diffusers implementation quirk).
    hidden_states = jax.nn.leaky_relu(hidden_states, negative_slope=0.01)
    hidden_states = self.conv_out(hidden_states)
    hidden_states = jnp.tanh(hidden_states)

    # Transpose back to (Batch, Channels, Time) to match the PyTorch/Diffusers output format.
    hidden_states = jnp.transpose(hidden_states, (0, 2, 1))

    return hidden_states
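
For reference, a minimal usage sketch (not part of this commit). The import path `maxdiffusion.models.ltx2_vocoder` is a guess, since the diff does not show where the file lives; everything else follows the defaults above. With the default `upsample_factors` of (6, 5, 2, 2, 2), the total upsampling factor is 6 * 5 * 2 * 2 * 2 = 240, so T mel frames map to 240 * T audio samples (10 ms per frame at the 24 kHz `output_sampling_rate`).

import jax.numpy as jnp
from flax import nnx

# Hypothetical import path; the diff does not show the file's location in the repo.
from maxdiffusion.models.ltx2_vocoder import LTX2Vocoder

# All architecture arguments use their defaults (128 mel features, stereo output).
vocoder = LTX2Vocoder(rngs=nnx.Rngs(0))

# Dummy mel spectrogram in the default (B, C, T, F) layout:
# batch 1, 1 latent channel, 50 time frames, 128 mel bins.
mel = jnp.zeros((1, 1, 50, 128))

audio = vocoder(mel)  # time_last=False is the default for (B, C, T, F) input
print(audio.shape)  # expected: (1, 2, 12000) = (batch, out_channels, 50 frames * 240 samples)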
