Skip to content

Commit 1da141d

Browse files
committed
Fix JAX convolution dimension numbers
1 parent 4628c02 commit 1da141d

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __call__(self, x: Array) -> Array:
100100
filter_expanded,
101101
window_strides=(self.ratio,),
102102
padding="VALID",
103-
dimension_numbers=('NLC', 'WIO', 'NLC'),
103+
dimension_numbers=('NLC', 'LIO', 'NLC'),
104104
feature_group_count=num_channels,
105105
)
106106
return x_filtered
@@ -156,7 +156,7 @@ def __call__(self, x: Array) -> Array:
156156
window_strides=(1,),
157157
padding="VALID",
158158
lhs_dilation=(self.ratio,),
159-
dimension_numbers=('NLC', 'WIO', 'NLC'),
159+
dimension_numbers=('NLC', 'LIO', 'NLC'),
160160
feature_group_count=num_channels,
161161
)
162162

@@ -486,7 +486,7 @@ def __call__(self, waveform: Array) -> tuple[Array, Array]:
486486
self.forward_basis.value,
487487
window_strides=(self.hop_length,),
488488
padding="VALID",
489-
dimension_numbers=('NLC', 'WIO', 'NLC'),
489+
dimension_numbers=('NLC', 'LIO', 'NLC'),
490490
)
491491

492492
n_freqs = spec.shape[-1] // 2

0 commit comments

Comments
 (0)