Skip to content

Commit af5b89e

Browse files
committed
comment explaining hardware specific sharding
1 parent 752ce79 commit af5b89e

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,9 @@ def __init__(
364 364
tpu_type = get_tpu_type()
365 365
is_ironwood = tpu_type == TpuType.TPU_7X
366 366

367+
# Hardware-aware sharding: Ironwood (v7x) uses 1D sharding along the heads dimension (leaving the embedding dimension replicated)
368+
# to minimize cross-device communication, while other hardware defaults to 2D sharding along both heads and embed dimensions.
369+
# This has currently only been tested on Trillium (v6e) and Ironwood (v7x).
367 370
if qkv_sharding_spec is None:
368 371
qkv_sharding_spec = (None, "heads") if is_ironwood else ("embed", "heads")
369 372
if out_sharding_spec is None:

0 commit comments

Comments (0)