
Commit 5f32283
Fix LTX2 sharding: NNXSimpleFeedForward kernel axes and LTX2Attention bias axes
1. NNXSimpleFeedForward (used by LTX2 transformer blocks):
   - net_0 (up-projection): kernel sharding fixed from ('embed', None) to ('embed', 'mlp'). The output dimension should be sharded across the tensor axis to parallelize the computation; the previous sharding left the output fully replicated, causing unnecessary all-gathers.
   - net_2 (down-projection): kernel sharding fixed from ('embed', 'mlp') to ('mlp', 'embed'). The input dimension must match net_0's output sharding, and the output dimension should use the embed sharding; the previous sharding had the axes reversed, creating resharding overhead.
   - Bias axes updated to match their respective output dimensions.
2. LTX2Attention:
   - QKV bias: fixed from ('embed',) to ('heads',) to match the QKV kernel's output-dimension sharding.
   - Output projection bias: fixed from ('heads',) to ('embed',) to match the output kernel's output-dimension sharding.

In both cases the rule is the same: shard each bias along the logical axis of its kernel's output dimension (a sketch follows below).
1 parent ceca471 commit 5f32283

2 files changed

Lines changed: 8 additions & 8 deletions
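
The rule this commit applies is mechanical: a bias shares the logical axis of its kernel's output dimension, and consecutive layers chain so one layer's output axis is the next layer's input axis. A minimal, hypothetical check of the annotation (the dimensions, variable name, and the nnx.get_partition_spec inspection are illustrative, not part of this change):

from flax import nnx

# Illustrative only: a single up-projection annotated the way this commit does it.
# The kernel maps embed -> mlp, so the bias (shaped like the output) goes on 'mlp'.
layer = nnx.Linear(
    2048, 8192,  # hypothetical dims
    kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
    bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
    rngs=nnx.Rngs(0),
)

# Reads the logical axes back off the params; the kernel should resolve to
# PartitionSpec('embed', 'mlp') and the bias to PartitionSpec('mlp').
print(nnx.get_partition_spec(nnx.state(layer)))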


src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 4 deletions
@@ -717,8 +717,8 @@ def __init__(
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
-       kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", None)),
-       bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
+       kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
+       bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
    )
    self.act = get_activation(activation_fn)
    self.net_2 = nnx.Linear(
@@ -729,8 +729,8 @@ def __init__(
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
-       kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
-       bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+       kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")),
+       bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
    )

  def __call__(self, hidden_states: Array) -> Array:
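
For context, a trimmed-down sketch of the feed-forward wiring this diff ends up with. The class name, dimensions, and the hard-coded GELU are placeholders (the real NNXSimpleFeedForward takes activation_fn, dtype, param_dtype, precision, etc.); only the partitioning annotations mirror the change above.

import jax
from flax import nnx

class SimpleFeedForwardSketch(nnx.Module):
  def __init__(self, dim: int, inner_dim: int, *, rngs: nnx.Rngs):
    # Up-projection: the output ('mlp') dimension is sharded across the tensor axis.
    self.net_0 = nnx.Linear(
        dim, inner_dim,
        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
        rngs=rngs,
    )
    # Down-projection: input axis matches net_0's 'mlp' output, output returns to 'embed',
    # so no resharding is needed between the two layers.
    self.net_2 = nnx.Linear(
        inner_dim, dim,
        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
        rngs=rngs,
    )

  def __call__(self, hidden_states):
    return self.net_2(jax.nn.gelu(self.net_0(hidden_states)))  # placeholder activation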

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 4 additions & 4 deletions
@@ -359,13 +359,13 @@ def __init__(
     # 1. Define Partitioned Initializers (Logical Axes)
     # Q, K, V kernels: [in_features (embed), out_features (heads)]
     qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
-    # Q, K, V biases: [out_features (embed)]
-    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
+    # Q, K, V biases: [out_features (heads)]
+    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))

     # Out kernel: [in_features (heads), out_features (embed)]
     out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
-    # Out bias: [out_features (heads)]
-    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
+    # Out bias: [out_features (embed)]
+    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))

     # Norm scales
     norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))
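
Likewise, a hypothetical reduction of the LTX2Attention projections after this change. The attribute names (to_q/to_k/to_v/to_out) and dimensions are assumptions; only the kernel/bias axis pairs mirror the diff above.

from flax import nnx

class LTX2ProjectionsSketch(nnx.Module):
  def __init__(self, dim: int, inner_dim: int, *, rngs: nnx.Rngs):
    # QKV kernels map embed -> heads, so their biases sit on 'heads'.
    qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
    # The output kernel maps heads -> embed, so its bias sits on 'embed'.
    out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))

    self.to_q = nnx.Linear(dim, inner_dim, kernel_init=qkv_kernel_init, bias_init=qkv_bias_init, rngs=rngs)
    self.to_k = nnx.Linear(dim, inner_dim, kernel_init=qkv_kernel_init, bias_init=qkv_bias_init, rngs=rngs)
    self.to_v = nnx.Linear(dim, inner_dim, kernel_init=qkv_kernel_init, bias_init=qkv_bias_init, rngs=rngs)
    self.to_out = nnx.Linear(inner_dim, dim, kernel_init=out_kernel_init, bias_init=out_bias_init, rngs=rngs)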
