Skip to content

Commit c9af5a5

Browse files
committed
Add shape matching workaround in ResBlock
1 parent b069e8f commit c9af5a5

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,12 @@ def __call__(self, x: Array) -> Array:
314314
xt = conv1(xt)
315315
xt = act2(xt)
316316
xt = conv2(xt)
317+
318+
if xt.shape[1] < x.shape[1]:
319+
xt = jnp.pad(xt, ((0, 0), (0, x.shape[1] - xt.shape[1]), (0, 0)), mode='edge')
320+
elif xt.shape[1] > x.shape[1]:
321+
xt = xt[:, :x.shape[1], :]
322+
317323
x = x + xt
318324
return x
319325

0 commit comments

Comments
 (0)