Commit 912d60c

Add WAN sharding notes and projection axis fixes

1 parent 229c7d9

5 files changed: 769 additions, 17 deletions

File tree

docs/potential_optimizations.md

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
# Potential Optimizations

## 1. Fix QKV Projection Sharding (Column-Parallel)

**Location:** `src/maxdiffusion/models/attention_flax.py` → `NNXAttentionBlock.__init__`

**Problem:**
The last dim of `hidden_states` is constrained to `activation_heads` → the `tensor` axis.
Q/K/V kernels use `kernel_axes = ("embed", "heads")`, meaning the contracting dim maps to `embed` → `[context, fsdp]`.

This is a **sharding mismatch on the contracting dimension**: XLA must reshard before the matmul, adding unnecessary communication.
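
To make the mismatch concrete, here is a minimal standalone sketch. The one-axis `tensor` mesh, toy shapes, and kernel sharding are illustrative assumptions, not the repo's actual mesh config:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative single-axis mesh; the real mesh also has context/fsdp axes.
mesh = Mesh(np.array(jax.devices()), axis_names=("tensor",))

# Activation: last (embed) dim sharded over "tensor", as the
# activation_heads constraint currently forces.
x = jax.device_put(jnp.ones((8, 512)), NamedSharding(mesh, P(None, "tensor")))

# QKV kernel: its contracting (embed) dim carries a different sharding
# (replicated here; [context, fsdp] in the repo).
w = jax.device_put(jnp.ones((512, 1024)), NamedSharding(mesh, P(None, "tensor")))

# The contracting dims disagree (x: sharded, w: replicated), so XLA has to
# reshard or all-gather x before it can contract.
y = jnp.matmul(x, w)
```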

**Fix: Column-parallel QKV**
Shard the weight's *output* dim over `tensor` (heads), not the contracting dim. The output then lands naturally sharded over heads, which is exactly what attention needs, with no all-reduce in the forward pass.

```python
# Current (broken: contracting dim mismatch)
kernel_axes = ("embed", "heads")

# Fix: column-parallel; shard the output over heads, leave the contracting
# dim unconstrained
kernel_axes = (None, "heads")
# or, if fsdp sharding of the contracting dim is desired:
kernel_axes = ("fsdp", "heads")
```
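
Continuing the sketch above, the column-parallel layout can be checked directly: with the embed dim of `x` left unsharded and only the kernel's output dim sharded, the matmul needs no collective and the result lands sharded over `tensor`:

```python
# Column-parallel layout from the fix: x replicated on embed, w sharded on
# its output (heads) dim only.
x = jax.device_put(jnp.ones((8, 512)), NamedSharding(mesh, P(None, None)))
w = jax.device_put(jnp.ones((512, 1024)), NamedSharding(mesh, P(None, "tensor")))

# Each device contracts the full embed dim against its own slice of w's
# columns; the output comes out partitioned over "tensor" with no all-reduce.
y = jnp.matmul(x, w)
print(y.sharding)  # expect the last dim partitioned over "tensor"
```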

Also ensure `hidden_states` is not constrained to `activation_heads` on the embed dim before the QKV matmuls, or all-gather it first.
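
A hypothetical sketch of that constraint change (the call site, tensor rank, and the `activation_batch`/`activation_length` axis names are assumptions; the NNX code path may also use a different helper than `flax.linen.with_logical_constraint`):

```python
import flax.linen as nn

def constrain_before_qkv(hidden_states):
    # Current: embed dim pinned to "activation_heads" (-> tensor), which
    # creates the contracting-dim mismatch with ("embed", "heads") kernels.
    # hidden_states = nn.with_logical_constraint(
    #     hidden_states,
    #     ("activation_batch", "activation_length", "activation_heads"))

    # Fix: leave the embed dim unconstrained so the column-parallel kernels
    # drive the layout (or all-gather explicitly here instead).
    return nn.with_logical_constraint(
        hidden_states, ("activation_batch", "activation_length", None))
```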

**Expected gain:** Eliminates a reshard/all-gather inserted by XLA on the contracting dimension for every Q, K, V projection in every attention block.
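
One way to sanity-check this is to compile the projection matmul and search the compiled HLO for collectives (a sketch reusing `x` and `w` from the examples above):

```python
# Expect an all-gather/reshard in the HLO under the mismatched layout, and
# none under the column-parallel layout.
compiled = jax.jit(jnp.matmul).lower(x, w).compile()
print("all-gather" in compiled.as_text())
```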
