Skip to content

Commit b069e8f

Browse files
committed
Fix string padding for transposed conv in JAX
1 parent 1da141d commit b069e8f

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __call__(self, x: Array) -> Array:
9999
x,
100100
filter_expanded,
101101
window_strides=(self.ratio,),
102-
padding="VALID",
102+
padding=((0, 0),),
103103
dimension_numbers=('NLC', 'LIO', 'NLC'),
104104
feature_group_count=num_channels,
105105
)
@@ -154,7 +154,7 @@ def __call__(self, x: Array) -> Array:
154154
x,
155155
filter_expanded,
156156
window_strides=(1,),
157-
padding="VALID",
157+
padding=((0, 0),),
158158
lhs_dilation=(self.ratio,),
159159
dimension_numbers=('NLC', 'LIO', 'NLC'),
160160
feature_group_count=num_channels,
@@ -485,7 +485,7 @@ def __call__(self, waveform: Array) -> tuple[Array, Array]:
485485
waveform,
486486
self.forward_basis.value,
487487
window_strides=(self.hop_length,),
488-
padding="VALID",
488+
padding=((0, 0),),
489489
dimension_numbers=('NLC', 'LIO', 'NLC'),
490490
)
491491

0 commit comments

Comments
 (0)