Skip to content

Commit 6b72f95

Browse files
committed
upsampler fix
1 parent 46cae70 commit 6b72f95

2 files changed

Lines changed: 6 additions & 6 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,13 +359,13 @@ def __init__(
359359
# 1. Define Partitioned Initializers (Logical Axes)
360360
# Q, K, V kernels: [in_features (embed), out_features (heads)]
361361
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
362-
# Q, K, V biases: [out_features (embed)]
363-
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
362+
# Q, K, V biases: [out_features (heads)]
363+
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
364364

365365
# Out kernel: [in_features (heads), out_features (embed)]
366366
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
367-
# Out bias: [out_features (heads)]
368-
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
367+
# Out bias: [out_features (embed)]
368+
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
369369

370370
# Norm scales
371371
norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))

src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,12 @@ def __init__(self, in_channels: int, mid_channels: int = 1024, scale: float = 2.
165165
in_channels, (num**2) * self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), rngs=rngs
166166
)
167167
self.pixel_shuffle = PixelShuffleND(dims=2, upscale_factors=(num, num))
168-
self.blur = BlurDownsample(dims=2, stride=den)
168+
self.blur_down = BlurDownsample(dims=2, stride=den)
169169

170170
def __call__(self, x: jax.Array) -> jax.Array:
171171
x = self.conv(x)
172172
x = self.pixel_shuffle(x)
173-
x = self.blur(x)
173+
x = self.blur_down(x)
174174
return x
175175

176176

0 commit comments

Comments
 (0)