Commit 7abca9a

fix(ltx2): add robust fallback for missing sharding specs in attention
1 parent 5175f20 commit 7abca9a

1 file changed: 19 additions & 5 deletions

src/maxdiffusion/models/ltx2/attention_ltx2.py
@@ -364,19 +364,33 @@ def __init__(
     else:
       specs = sharding_specs
 
+    # Use getattr with fallback to default specs if attribute is missing
+    def get_spec(attr_name):
+      val = getattr(specs, attr_name, None)
+      if val is None:
+        default_specs = get_sharding_specs("default", "ltx2_dit")
+        return getattr(default_specs, attr_name)
+      return val
+
+    qkv_kernel = get_spec("qkv_kernel")
+    qkv_bias = get_spec("qkv_bias")
+    out_kernel = get_spec("out_kernel")
+    out_bias = get_spec("out_bias")
+    norm_scale = get_spec("norm_scale")
+
     # 1. Define Partitioned Initializers (Logical Axes)
     # Q, K, V kernels: [in_features (embed), out_features (heads)]
-    qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.qkv_kernel)
+    qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_kernel)
     # Q, K, V biases: [out_features (heads)]
-    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.qkv_bias)
+    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), qkv_bias)
 
     # Out kernel: [in_features (heads), out_features (embed)]
-    out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.out_kernel)
+    out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), out_kernel)
     # Out bias: [out_features (embed)]
-    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.out_bias)
+    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), out_bias)
 
     # Norm scales
-    norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), specs.norm_scale)
+    norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), norm_scale)
 
     # 2. Projections
     self.to_q = nnx.Linear(
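The fallback can be exercised in isolation. Below is a minimal, self-contained sketch of the pattern; DEFAULT_SPECS and the SimpleNamespace objects are hypothetical stand-ins for maxdiffusion's sharding-spec containers and for get_sharding_specs("default", "ltx2_dit"), which this sketch does not import. Only the getattr-with-fallback logic mirrors the diff.

# Hypothetical stand-in for get_sharding_specs("default", "ltx2_dit");
# the axis names here are illustrative, not maxdiffusion's real defaults.
from types import SimpleNamespace

DEFAULT_SPECS = SimpleNamespace(
    qkv_kernel=("embed", "heads"),
    qkv_bias=("heads",),
    out_kernel=("heads", "embed"),
    out_bias=("embed",),
    norm_scale=("heads",),
)

def get_spec(specs, attr_name):
    # Fall back to the defaults when the attribute is absent or set to None.
    val = getattr(specs, attr_name, None)
    if val is None:
        return getattr(DEFAULT_SPECS, attr_name)
    return val

# A partial spec object, e.g. user-supplied sharding_specs that predate the
# norm_scale attribute.
partial = SimpleNamespace(qkv_kernel=("embed", "tensor"))

print(get_spec(partial, "qkv_kernel"))  # ('embed', 'tensor') -- user value kept
print(get_spec(partial, "norm_scale"))  # ('heads',) -- default fallback used

Note that the committed version looks the defaults up lazily inside get_spec, so the default spec object is only constructed when a fallback is actually needed.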

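For context on where the resolved specs end up: nnx.with_partitioning wraps an initializer so that the created parameter carries the logical axis names as sharding metadata, which stock Flax NNX can later turn into PartitionSpecs. A small sketch under that assumption; the layer sizes and axis names are illustrative, not maxdiffusion's.

# Sketch: attaching a resolved spec to a parameter via nnx.with_partitioning.
from flax import nnx

qkv_kernel = ("embed", "heads")  # e.g. the value resolved by get_spec

layer = nnx.Linear(
    8,
    16,
    kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_kernel),
    rngs=nnx.Rngs(0),
)

# The logical axis names travel with the parameter as metadata and can be
# recovered as PartitionSpecs for sharded initialization or checkpointing.
print(nnx.get_partition_spec(nnx.state(layer)))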