
Commit 77bd6dd

feat(ltx2): use cleaner approach for connector attention sharding

1 parent: 7abca9a

2 files changed: 5 additions & 20 deletions


src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 5 additions & 19 deletions

```diff
@@ -364,33 +364,19 @@ def __init__(
     else:
       specs = sharding_specs
 
-    # Use getattr with fallback to default specs if attribute is missing
-    def get_spec(attr_name):
-      val = getattr(specs, attr_name, None)
-      if val is None:
-        default_specs = get_sharding_specs("default", "ltx2_dit")
-        return getattr(default_specs, attr_name)
-      return val
-
-    qkv_kernel = get_spec("qkv_kernel")
-    qkv_bias = get_spec("qkv_bias")
-    out_kernel = get_spec("out_kernel")
-    out_bias = get_spec("out_bias")
-    norm_scale = get_spec("norm_scale")
-
     # 1. Define Partitioned Initializers (Logical Axes)
     # Q, K, V kernels: [in_features (embed), out_features (heads)]
-    qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_kernel)
+    qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.qkv_kernel)
     # Q, K, V biases: [out_features (heads)]
-    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), qkv_bias)
+    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.qkv_bias)
 
     # Out kernel: [in_features (heads), out_features (embed)]
-    out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), out_kernel)
+    out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.out_kernel)
     # Out bias: [out_features (embed)]
-    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), out_bias)
+    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.out_bias)
 
     # Norm scales
-    norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), norm_scale)
+    norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), specs.norm_scale)
 
     # 2. Projections
     self.to_q = nnx.Linear(
```
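The removed `get_spec` helper patched missing attributes one at a time; the new code instead assumes `specs` is already complete, which only works if the branch above the hunk resolves a full default object whenever no `sharding_specs` is supplied. A minimal runnable sketch of that pattern follows, with stand-in types: only the `else:` of that branch is visible in the diff, so its exact form, and the spec values below, are assumptions.

```python
# Hedged sketch of the "cleaner approach": resolve one complete spec object
# up front instead of patching missing attributes one at a time. The types
# and spec values are illustrative stand-ins, not the repo's actual ones.
from dataclasses import dataclass
from typing import Optional, Tuple

from flax import nnx


@dataclass(frozen=True)
class ShardingSpecs:  # stand-in for the repo's spec type
  qkv_kernel: Tuple[str, ...] = ("embed", "heads")
  qkv_bias: Tuple[str, ...] = ("heads",)


def get_sharding_specs(profile: str, model: str) -> ShardingSpecs:
  """Stand-in for the repo helper; always returns a complete spec object."""
  return ShardingSpecs()


def resolve_specs(sharding_specs: Optional[ShardingSpecs]) -> ShardingSpecs:
  # Assumed shape of the branch above the hunk (only its `else:` is visible):
  if sharding_specs is None:
    return get_sharding_specs("default", "ltx2_dit")
  return sharding_specs  # caller-provided object is assumed complete


specs = resolve_specs(None)
# Every attribute access now succeeds, so the initializer can take the spec
# directly, as on the `+` lines of the diff:
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.qkv_kernel)
```

Resolving the whole object once keeps a single source of defaults, so every `specs.<attr>` access afterwards is an ordinary attribute lookup with no fallback logic.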

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -49,7 +49,6 @@ def __init__(
         attention_kernel=attention_kernel,
         mesh=mesh,
         rngs=rngs,
-        sharding_specs=sharding_specs,
     )
     self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim, activation_fn="gelu_tanh", sharding_specs=sharding_specs)
     self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
```
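Consistent with that, the connector stops threading `sharding_specs` into its attention submodule, presumably because the attention module now falls back to the default specs on its own. The submodule's signature is not shown in this diff, so the sketch below is a hypothetical illustration of that call-site pattern, not the repo's actual class.

```python
# Hypothetical illustration of why the call site can omit the kwarg: an
# attention module whose __init__ treats sharding_specs as optional and
# resolves defaults itself. Class and helper names are stand-ins.
from typing import Optional

from flax import nnx


class _Specs:  # minimal stand-in spec object
  qkv_kernel = ("embed", "heads")


class ConnectorAttention(nnx.Module):  # hypothetical name, not the repo's class
  def __init__(self, dim: int, *, rngs: nnx.Rngs, sharding_specs: Optional[_Specs] = None):
    # Omitting the kwarg at the call site lands here and picks up defaults.
    specs = sharding_specs if sharding_specs is not None else _Specs()
    kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.qkv_kernel)
    self.to_q = nnx.Linear(dim, dim, kernel_init=kernel_init, rngs=rngs)


attn = ConnectorAttention(128, rngs=nnx.Rngs(0))  # no sharding_specs threaded through
```

`NNXSimpleFeedForward`, by contrast, still takes `sharding_specs` explicitly, as the unchanged line in the hunk shows.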
