Skip to content

Commit d84ff96

Browse files
committed
testing with sharding changes
1 parent a86a81f commit d84ff96

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,13 +359,13 @@ def __init__(
359359
# 1. Define Partitioned Initializers (Logical Axes)
360360
# Q, K, V kernels: [in_features (embed), out_features (heads)]
361361
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
362-
# Q, K, V biases: [out_features (embed)]
363-
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
362+
# Q, K, V biases: [out_features (heads)]
363+
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
364364

365365
# Out kernel: [in_features (heads), out_features (embed)]
366366
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
367-
# Out bias: [out_features (heads)]
368-
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
367+
# Out bias: [out_features (embed)]
368+
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
369369

370370
# Norm scales
371371
norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))

0 commit comments

Comments (0)