@@ -364,19 +364,33 @@ def __init__(
364364 else :
365365 specs = sharding_specs
366366
367+ # Use getattr with fallback to default specs if attribute is missing
368+ def get_spec (attr_name ):
369+ val = getattr (specs , attr_name , None )
370+ if val is None :
371+ default_specs = get_sharding_specs ("default" , "ltx2_dit" )
372+ return getattr (default_specs , attr_name )
373+ return val
374+
375+ qkv_kernel = get_spec ("qkv_kernel" )
376+ qkv_bias = get_spec ("qkv_bias" )
377+ out_kernel = get_spec ("out_kernel" )
378+ out_bias = get_spec ("out_bias" )
379+ norm_scale = get_spec ("norm_scale" )
380+
367381 # 1. Define Partitioned Initializers (Logical Axes)
368382 # Q, K, V kernels: [in_features (embed), out_features (heads)]
369- qkv_kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), specs . qkv_kernel )
383+ qkv_kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), qkv_kernel )
370384 # Q, K, V biases: [out_features (heads)]
371- qkv_bias_init = nnx .with_partitioning (nnx .initializers .zeros_init (), specs . qkv_bias )
385+ qkv_bias_init = nnx .with_partitioning (nnx .initializers .zeros_init (), qkv_bias )
372386
373387 # Out kernel: [in_features (heads), out_features (embed)]
374- out_kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), specs . out_kernel )
388+ out_kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), out_kernel )
375389 # Out bias: [out_features (embed)]
376- out_bias_init = nnx .with_partitioning (nnx .initializers .zeros_init (), specs . out_bias )
390+ out_bias_init = nnx .with_partitioning (nnx .initializers .zeros_init (), out_bias )
377391
378392 # Norm scales
379- norm_scale_init = nnx .with_partitioning (nnx .initializers .ones_init (), specs . norm_scale )
393+ norm_scale_init = nnx .with_partitioning (nnx .initializers .ones_init (), norm_scale )
380394
381395 # 2. Projections
382396 self .to_q = nnx .Linear (
0 commit comments