diff --git a/src/maxdiffusion/models/ltx2/attention_ltx2.py b/src/maxdiffusion/models/ltx2/attention_ltx2.py index 7441a2038..398b0f473 100644 --- a/src/maxdiffusion/models/ltx2/attention_ltx2.py +++ b/src/maxdiffusion/models/ltx2/attention_ltx2.py @@ -359,13 +359,13 @@ def __init__( # 1. Define Partitioned Initializers (Logical Axes) # Q, K, V kernels: [in_features (embed), out_features (heads)] qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")) - # Q, K, V biases: [out_features (embed)] - qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",)) + # Q, K, V biases: [out_features (heads)] + qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",)) # Out kernel: [in_features (heads), out_features (embed)] out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")) - # Out bias: [out_features (heads)] - out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",)) + # Out bias: [out_features (embed)] + out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",)) # Norm scales norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",)) diff --git a/src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py b/src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py index 1b43457af..20436f42f 100644 --- a/src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py +++ b/src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py @@ -165,12 +165,12 @@ def __init__(self, in_channels: int, mid_channels: int = 1024, scale: float = 2. in_channels, (num**2) * self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), rngs=rngs ) self.pixel_shuffle = PixelShuffleND(dims=2, upscale_factors=(num, num)) - self.blur = BlurDownsample(dims=2, stride=den) + self.blur_down = BlurDownsample(dims=2, stride=den) def __call__(self, x: jax.Array) -> jax.Array: x = self.conv(x) x = self.pixel_shuffle(x) - x = self.blur(x) + x = self.blur_down(x) return x