Skip to content

Commit 79e2b86

Browse files
committed
QKV sharding based on device type
1 parent d5a8130 commit 79e2b86

1 file changed

Lines changed: 17 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -349,23 +349,37 @@ def __init__(
349349
rope_type: str = "interleaved",
350350
flash_block_sizes: BlockSizes = None,
351351
flash_min_seq_length: int = 4096,
352+
qkv_sharding_spec: Optional[tuple] = None,
353+
out_sharding_spec: Optional[tuple] = None,
354+
out_bias_sharding_spec: Optional[tuple] = None,
352355
):
353356
self.heads = heads
354357
self.rope_type = rope_type
355358
self.dim_head = dim_head
356359
self.inner_dim = dim_head * heads
357360
self.dropout_rate = dropout
358361

362+
# Auto-detect hardware for sharding specs if not overridden
363+
device_kind = jax.devices()[0].device_kind
364+
is_ironwood = "7x" in device_kind
365+
366+
if qkv_sharding_spec is None:
367+
qkv_sharding_spec = (None, "heads") if is_ironwood else ("embed", "heads")
368+
if out_sharding_spec is None:
369+
out_sharding_spec = ("heads", None) if is_ironwood else ("heads", "embed")
370+
if out_bias_sharding_spec is None:
371+
out_bias_sharding_spec = (None,) if is_ironwood else ("embed",)
372+
359373
# 1. Define Partitioned Initializers (Logical Axes)
360374
# Q, K, V kernels: [in_features (embed), out_features (heads)]
361-
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads"))
375+
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_sharding_spec)
362376
# Q, K, V biases: [out_features (heads)]
363377
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
364378

365379
# Out kernel: [in_features (heads), out_features (embed)]
366-
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", None))
380+
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), out_sharding_spec)
367381
# Out bias: [out_features (embed)]
368-
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), (None,))
382+
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), out_bias_sharding_spec)
369383

370384
# Norm scales
371385
norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))

0 commit comments

Comments
 (0)