We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 752ce79 commit af5b89eCopy full SHA for af5b89e
1 file changed
src/maxdiffusion/models/ltx2/attention_ltx2.py
@@ -364,6 +364,9 @@ def __init__(
364
tpu_type = get_tpu_type()
365
is_ironwood = tpu_type == TpuType.TPU_7X
366
367
+ # Hardware-aware sharding: Ironwood (v7x) uses 1D sharding along the heads dimension (leaving the embedding dimension replicated)
368
+ # to minimize cross-device communication, while other hardware defaults to 2D sharding along both heads and embed dimensions.
369
+ # This has currently only been tested on Trillium (v6e) and Ironwood (v7x).
370
if qkv_sharding_spec is None:
371
qkv_sharding_spec = (None, "heads") if is_ironwood else ("embed", "heads")
372
if out_sharding_spec is None:
0 commit comments