@@ -364,33 +364,19 @@ def __init__(
         else:
             specs = sharding_specs
 
-        # Use getattr with fallback to default specs if attribute is missing
-        def get_spec(attr_name):
-            val = getattr(specs, attr_name, None)
-            if val is None:
-                default_specs = get_sharding_specs("default", "ltx2_dit")
-                return getattr(default_specs, attr_name)
-            return val
-
-        qkv_kernel = get_spec("qkv_kernel")
-        qkv_bias = get_spec("qkv_bias")
-        out_kernel = get_spec("out_kernel")
-        out_bias = get_spec("out_bias")
-        norm_scale = get_spec("norm_scale")
-
         # 1. Define Partitioned Initializers (Logical Axes)
         # Q, K, V kernels: [in_features (embed), out_features (heads)]
-        qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_kernel)
+        qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.qkv_kernel)
         # Q, K, V biases: [out_features (heads)]
-        qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), qkv_bias)
+        qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.qkv_bias)
 
         # Out kernel: [in_features (heads), out_features (embed)]
-        out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), out_kernel)
+        out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.out_kernel)
         # Out bias: [out_features (embed)]
-        out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), out_bias)
+        out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.out_bias)
 
         # Norm scales
-        norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), norm_scale)
+        norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), specs.norm_scale)
 
         # 2. Projections
         self.to_q = nnx.Linear(
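For context, a minimal sketch (not part of this commit) of the pattern the new lines rely on: nnx.with_partitioning wraps an initializer so the parameter it creates carries logical sharding axes, which can later be read back as PartitionSpecs via nnx.get_partition_spec. The dimensions and axis names below are illustrative, not the ones used in the LTX-2 DiT specs.

# Illustrative Flax NNX sketch; sizes and axis names are made up for the example.
from flax import nnx

# Wrap initializers so each created parameter records its logical sharding axes:
# the kernel's rows stay replicated (None), its columns are split over "model".
kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "model"))
bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("model",))

linear = nnx.Linear(
    in_features=128,
    out_features=256,
    kernel_init=kernel_init,
    bias_init=bias_init,
    rngs=nnx.Rngs(0),
)

# The annotations travel with the parameters and can be extracted as PartitionSpecs,
# e.g. to build a NamedSharding or a with_sharding_constraint over a device mesh.
pspecs = nnx.get_partition_spec(nnx.state(linear))
print(pspecs)  # kernel -> PartitionSpec(None, 'model'), bias -> PartitionSpec('model',)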