Skip to content

Commit acaf58d

Browse files
committed
Move filter generation to __call__ to force logs
1 parent 9b7cf8f commit acaf58d

1 file changed

Lines changed: 25 additions & 21 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,13 @@ def __init__(
100100
self.use_padding = use_padding
101101
self.padding_mode = padding_mode
102102

103-
cutoff = 0.5 / ratio
104-
half_width = 0.6 / ratio
105-
low_pass_filter = kaiser_sinc_filter1d(cutoff, half_width, self.kernel_size)
106-
print(f"DownSample1d filter - min: {low_pass_filter.min()}, max: {low_pass_filter.max()}")
107-
self.filter = jnp.expand_dims(low_pass_filter, axis=(1, 2))
108-
109103
def __call__(self, x: Array) -> Array:
110104
num_channels = x.shape[-1]
105+
106+
cutoff = 0.5 / self.ratio
107+
half_width = 0.6 / self.ratio
108+
low_pass_filter = kaiser_sinc_filter1d(cutoff, half_width, self.kernel_size)
109+
filter = jnp.expand_dims(low_pass_filter, axis=(1, 2))
111110
if self.use_padding:
112111
x = jnp.pad(x, ((0, 0), (self.pad_left, self.pad_right), (0, 0)), mode='edge')
113112

@@ -137,6 +136,8 @@ def __init__(
137136
self.ratio = ratio
138137
self.padding_mode = padding_mode
139138

139+
self.window_type = window_type
140+
140141
if window_type == "hann":
141142
rolloff = 0.99
142143
lowpass_filter_width = 6
@@ -145,30 +146,33 @@ def __init__(
145146
self.pad = width
146147
self.pad_left = 2 * width * ratio
147148
self.pad_right = self.kernel_size - ratio
148-
149-
time_axis = (jnp.arange(self.kernel_size) / ratio - width) * rolloff
150-
time_clamped = jnp.clip(time_axis, -lowpass_filter_width, lowpass_filter_width)
151-
window = jnp.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
152-
sinc_filter = jnp.sinc(time_axis) * window * rolloff / ratio
153-
self.filter = sinc_filter.reshape(-1, 1, 1)
154149
else:
155150
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
156151
self.pad = self.kernel_size // ratio - 1
157152
self.pad_left = self.pad * self.ratio + (self.kernel_size - self.ratio) // 2
158153
self.pad_right = self.pad * self.ratio + (self.kernel_size - self.ratio + 1) // 2
159154

160-
sinc_filter = kaiser_sinc_filter1d(
161-
cutoff=0.5 / ratio,
162-
half_width=0.6 / ratio,
163-
kernel_size=self.kernel_size,
164-
)
165-
print(f"UpSample1d filter - min: {sinc_filter.min()}, max: {sinc_filter.max()}")
166-
self.filter = sinc_filter.reshape(-1, 1, 1)
167-
168155
def __call__(self, x: Array) -> Array:
169156
num_channels = x.shape[-1]
170157
batch, length, channels = x.shape
171158

159+
if self.window_type == "hann":
160+
rolloff = 0.99
161+
lowpass_filter_width = 6
162+
width = math.ceil(lowpass_filter_width / rolloff)
163+
time_axis = (jnp.arange(self.kernel_size) / self.ratio - width) * rolloff
164+
time_clamped = jnp.clip(time_axis, -lowpass_filter_width, lowpass_filter_width)
165+
window = jnp.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
166+
sinc_filter = jnp.sinc(time_axis) * window * rolloff / self.ratio
167+
filter = sinc_filter.reshape(-1, 1, 1)
168+
else:
169+
sinc_filter = kaiser_sinc_filter1d(
170+
cutoff=0.5 / self.ratio,
171+
half_width=0.6 / self.ratio,
172+
kernel_size=self.kernel_size,
173+
)
174+
filter = sinc_filter.reshape(-1, 1, 1)
175+
172176
# Interleave zeros (manual upsampling)
173177
x_expanded = jnp.zeros((batch, length * self.ratio, channels), dtype=x.dtype)
174178
x_expanded = x_expanded.at[:, ::self.ratio, :].set(x)
@@ -177,7 +181,7 @@ def __call__(self, x: Array) -> Array:
177181
pad_len = self.pad * self.ratio
178182
x_padded = jnp.pad(x_expanded, ((0, 0), (pad_len, pad_len), (0, 0)), mode='edge')
179183

180-
filter_expanded = jnp.repeat(self.filter, num_channels, axis=2)
184+
filter_expanded = jnp.repeat(filter, num_channels, axis=2)
181185
filter_expanded = filter_expanded.astype(x.dtype)
182186

183187
x_upsampled = jax.lax.conv_general_dilated(

0 commit comments

Comments
 (0)