Commit c9d4109
Fix LTX2 sharding: NNXSimpleFeedForward kernel axes and LTX2Attention bias axes
1. NNXSimpleFeedForward (used by LTX2 transformer blocks):
   - net_0 (up-projection): kernel sharding fixed from ('embed', None) to
     ('embed', 'mlp'). The output dim should be sharded across the tensor
     axis to parallelize the computation. The previous sharding left the
     output dim fully replicated, causing unnecessary all-gathers.
   - net_2 (down-projection): kernel sharding fixed from ('embed', 'mlp')
     to ('mlp', 'embed'). The input dim must match net_0's output sharding,
     and the output dim should use embed sharding. The previous sharding
     had the axes reversed, creating resharding overhead.
   - Bias axes updated to match their respective output dimensions.
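The axis-matching argument in item 1 can be checked with a small toy model. This is only a sketch in plain Python, not the repository's code: `count_mismatches` and the list-of-tuples representation are hypothetical, and in a real JAX program the actual resharding is decided by the mesh rules and the XLA partitioner. The sketch assumes the activation entering the feed-forward block is sharded along 'embed'.

```python
from typing import List, Optional, Tuple

# Logical axes of a dense kernel: (input dim, output dim).
KernelAxes = Tuple[Optional[str], Optional[str]]


def count_mismatches(
    input_axis: Optional[str], kernels: List[KernelAxes]
) -> Tuple[int, Optional[str]]:
    """Walk a chain of dense layers and count axis mismatches.

    A mismatch is a layer whose kernel input axis differs from the
    feature axis of the activation it receives. Returns the mismatch
    count and the feature axis of the final output.
    """
    mismatches = 0
    current = input_axis
    for in_axis, out_axis in kernels:
        if in_axis != current:
            mismatches += 1
        # The activation leaving a layer carries the kernel's output axis.
        current = out_axis
    return mismatches, current


# Kernel axes before and after the fix, as stated in the commit message.
old_ffn = [("embed", None), ("embed", "mlp")]   # net_0, net_2 (previous)
new_ffn = [("embed", "mlp"), ("mlp", "embed")]  # net_0, net_2 (fixed)

old_result = count_mismatches("embed", old_ffn)
new_result = count_mismatches("embed", new_ffn)
```

Under these assumptions the previous annotations give one mismatch and leave the block output on 'mlp', while the fixed annotations give no mismatch and return the output to 'embed'.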
2. LTX2Attention:
   - QKV bias: fixed from ('embed',) to ('heads',) to match the sharding
     of the QKV kernel's output dimension.
   - Output projection bias: fixed from ('heads',) to ('embed',) to match
     the sharding of the output kernel's output dimension.

1 parent ceca471 commit c9d4109
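The bias rule applied to LTX2Attention (a bias is sharded like its kernel's output dimension) can be sketched as follows. The helper `bias_axes_for` is hypothetical, and the kernel input axes shown here ('embed' for QKV, 'heads' for the output projection) are assumptions for illustration; the commit message only states the output-dimension axes.

```python
from typing import Optional, Tuple

KernelAxes = Tuple[Optional[str], ...]


def bias_axes_for(kernel_axes: KernelAxes) -> Tuple[Optional[str]]:
    """Return bias axes that follow the kernel's last (output) dimension."""
    return (kernel_axes[-1],)


def bias_is_consistent(kernel_axes: KernelAxes, bias_axes: Tuple) -> bool:
    """True when the bias is annotated like the kernel's output dimension."""
    return tuple(bias_axes) == bias_axes_for(kernel_axes)


# Assumed kernel annotations; only the output axes come from the commit text.
qkv_kernel = ("embed", "heads")
out_kernel = ("heads", "embed")

qkv_bias_old, qkv_bias_new = ("embed",), ("heads",)
out_bias_old, out_bias_new = ("heads",), ("embed",)

checks = {
    "qkv_old": bias_is_consistent(qkv_kernel, qkv_bias_old),
    "qkv_new": bias_is_consistent(qkv_kernel, qkv_bias_new),
    "out_old": bias_is_consistent(out_kernel, out_bias_old),
    "out_new": bias_is_consistent(out_kernel, out_bias_new),
}
```

With these annotations only the two fixed biases pass the check, which is the consistency the commit restores.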
2 files changed
Lines changed: 8 additions & 8 deletions
File 1: lines 720-721 and 732-733 replaced (4 additions, 4 deletions).
File 2: lines 362-363 and 367-368 replaced (4 additions, 4 deletions).
(File names and diff contents not preserved.)