Integrate torchax custom attention kernel into ulysses#392

Open
eltsai wants to merge 1 commit into main from torchax_attention

Conversation

@eltsai (Collaborator) commented Apr 27, 2026

Adds the torchax path's custom attention kernel to ulysses (triggered when attention=ulysses_custom).

Inference time:

==================================================
  TIMING SUMMARY
==================================================
  Load (checkpoint):      80.6s
  Compile:               186.9s
  ────────────────────────────────────────
  Inference:             167.2s
==================================================
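The description says the custom kernel is triggered when attention=ulysses_custom. As an illustration only, the backend selection could be sketched as a config-keyed dispatch; the key and value names (attention, ulysses_custom) come from the PR description, but the function, the other backend names, and the dispatch logic here are hypothetical, not the actual maxdiffusion code.

```python
def select_attention_fn(config: dict) -> str:
    """Pick an attention implementation from the `attention` config key.

    Hypothetical sketch: only the `ulysses_custom` trigger value is taken
    from the PR description; the fallback names are illustrative.
    """
    kind = config.get("attention", "dot_product")
    if kind == "ulysses_custom":
        # torchax path: custom splash-attention kernel (this PR)
        return "custom_splash_attention"
    if kind == "ulysses":
        return "ulysses_splash_attention"
    return "dot_product_attention"

print(select_attention_fn({"attention": "ulysses_custom"}))  # prints "custom_splash_attention"
```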

@eltsai eltsai requested a review from entrpn as a code owner April 27, 2026 05:21
@github-actions

Comment on lines +689 to +693
bq = 2048
bkv = 2048
bkv_compute = 1024
bkv_compute_in = 256
heads_per_tile = 1
Collaborator

Updating this to

bq = 4864
bkv = 1024
bkv_compute = 1024
bkv_compute_in = 1024
heads_per_tile = 1

and using this command gave me the following latency

  Load (checkpoint):     297.0s
  Compile:               219.8s
  ───────────────────────────────
  Inference:             147.4s
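The knobs being tuned above can be collected in a small container for experimentation. This is a sketch only: the field names mirror the variables in the diff (bq, bkv, bkv_compute, bkv_compute_in, heads_per_tile) and the defaults are the reviewer's suggested values, but the divisibility checks are an assumed invariant, not the kernel's actual validation.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class BlockSizes:
    """Splash-attention block-size knobs (illustrative container, not the real API)."""
    bq: int = 4864              # query block size (reviewer's suggested value)
    bkv: int = 1024             # key/value block size
    bkv_compute: int = 1024     # kv sub-block processed per compute step
    bkv_compute_in: int = 1024  # inner kv tile within a compute step
    heads_per_tile: int = 1     # attention heads handled per tile

    def __post_init__(self):
        # Assumed invariant: each level of kv blocking evenly divides the one above.
        assert self.bkv % self.bkv_compute == 0, "bkv_compute must divide bkv"
        assert self.bkv_compute % self.bkv_compute_in == 0, "bkv_compute_in must divide bkv_compute"

# Reviewer's suggested configuration vs. the original values from the diff:
tuned = BlockSizes()
original = BlockSizes(bq=2048, bkv=2048, bkv_compute=1024, bkv_compute_in=256)
```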

Comment thread src/maxdiffusion/models/custom_splash_attention.py
@github-actions

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.

Comment thread src/maxdiffusion/models/custom_splash_attention.py
Comment thread src/maxdiffusion/models/custom_splash_attention.py
Comment thread src/maxdiffusion/models/attention_flax.py Outdated
@eltsai eltsai force-pushed the torchax_attention branch from 56c76b8 to daf4a31 Compare April 27, 2026 20:21
@eltsai (Collaborator, Author) commented Apr 27, 2026

Updated stats:

Accelerator | Sharding         | E2E time | Log | Video
v7x-8       | dp2-context4-tp1 | 139.3s   | log | Video
v7x-16      | dp2-context8-tp1 | 70.2s    | log | Video
