
Commit 691863b (1 parent: 0a36e73)

Changed the partitioning spec for the to_q, to_k, and to_v kernel initializers from ("embed", "heads") to (None, "heads").

1 file changed: src/maxdiffusion/models/ltx2/attention_ltx2.py (1 addition, 1 deletion)
@@ -358,7 +358,7 @@ def __init__(
 
     # 1. Define Partitioned Initializers (Logical Axes)
     # Q, K, V kernels: [in_features (embed), out_features (heads)]
-    qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
+    qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads"))
     # Q, K, V biases: [out_features (heads)]
     qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
 
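
For context, a minimal sketch (not from the repository) of what the new spec means: nnx.with_partitioning wraps an initializer so the resulting kernel param carries logical-axis sharding metadata. With (None, "heads"), the in_features axis is left unannotated (replicated) and out_features is annotated with the logical "heads" axis. The small nnx.Linear layer below is hypothetical, used only to materialize a kernel with this annotation.

    # Minimal sketch, assuming Flax's NNX API; the Linear layer here is
    # hypothetical and not part of attention_ltx2.py.
    from flax import nnx

    # Replicate in_features (None), shard out_features along the logical
    # "heads" axis -- the new qkv_kernel_init spec from this commit.
    qkv_kernel_init = nnx.with_partitioning(
        nnx.initializers.lecun_normal(), (None, "heads")
    )

    # Materialize a kernel carrying the annotation.
    layer = nnx.Linear(8, 16, kernel_init=qkv_kernel_init, use_bias=False,
                       rngs=nnx.Rngs(0))

    # Recover the annotation as a PartitionSpec, usable with
    # jax.lax.with_sharding_constraint under an appropriate mesh.
    pspecs = nnx.get_partition_spec(nnx.state(layer))
    print(pspecs)  # kernel -> PartitionSpec(None, 'heads')

With the previous ("embed", "heads") spec, the in_features axis was also annotated, so the kernel's first dimension would be sharded along the logical "embed" axis; the new spec replicates it instead.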
