-
Notifications
You must be signed in to change notification settings - Fork 70
Fix bug in assigning block_q_dq when block sizes are not specified #282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from 14 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
8ea69b8
Test on attention type and automatically modify flash block sizes obj…
coolkp a79a49c
Test on attention type and automatically modify flash block sizes obj…
coolkp 436e7d1
Test on attention type and automatically modify flash block sizes obj…
coolkp 721b4c6
Parametrize test
coolkp 1c36022
Parametrize test
coolkp becfc88
Test on attention type and automatically modify flash block sizes obj…
coolkp 8d8d1a0
Test on attention type and automatically modify flash block sizes obj…
coolkp a71eb74
Test on attention type and automatically modify flash block sizes obj…
coolkp 7238105
Test on attention type and automatically modify flash block sizes obj…
coolkp c379559
Test on attention type and automatically modify flash block sizes obj…
coolkp 3386550
Test on attention type and automatically modify flash block sizes obj…
coolkp 0dda0cf
Test on attention type and automatically modify flash block sizes obj…
coolkp 50249c4
Document attention behavior
coolkp e221517
Remove comment
coolkp 77c0ea7
Fix typo
coolkp File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| # Attention block sizes | ||
|
|
||
| ## Description | ||
| - "block_q": Block sizes (HBM TO VMEM and VREG) to tile along Q sequence in forward pass | ||
| - "block_kv_compute" : Sub Block size (VMEM to VREG) of "block_kv" where compute is performed in forward pass. It must be factor or same as "block_kv" | ||
| - "block_kv" : Block sizes (HBM TO VMEM) to tile along KV sequence in forward pass | ||
| - "block_q_dkv" : Block sizes along Q sequence in backward pass with fused kernel to compute gradient of q, k , v. It must be factor or same as block_q | ||
| - "block_kv_dkv" : Block sizes along KV sequence in backward pass. It must be factor or same as block_kv | ||
| - "block_kv_dkv_compute" : Sub Block Sizes of block_kv_dkv, must be factor or same as "block_kv_dkv" | ||
| - "block_q_dq" : Block sizes along Q sequence in backward pass with unfused kernel to compute gradient of just q. it must be factor or same as "block_q" | ||
| - "block_kv_dq" : Block sizes along KV to tiline on KV sequence in backward pass with unfused kernel to compute gradient of just q. it must be factor or same as "block_kv" | ||
| - "use_fused_bwd_kernel" : This means fused bwd kernel is used where DQ, DK, DV are computed in single kernel. It usually more perfomant but comes with slight HBM memory overhead. | ||
|
|
||
| ## Flowchart | ||
|
|
||
| Maxdiffusion automatically adheres to this flowchart to ensure working, and there is a log that will inform you on the modifications that maxdiffusion makes to the specified block sizes. | ||
|
|
||
|  | ||
|
|
||
| > "tokamax_flash" uses the splash attention implementation in [tokamax-repo](https://github.com/openxla/tokamax/blob/main/tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_kernel.py) This kernel only supports fused backward pass where gradients for q,k,v are computed in a single kernel so "block_q_dq" and "block_kv_dq" are not used | ||
|
|
||
| ## How block sizes matter for perfomance and accuracy | ||
|
|
||
| Block sizes key to saturating HBM bandwidth and ensuring maximum possible overlap of computation on cores with HBM use and VMEM to VREG. It is highly reccomended to tune them. | ||
|
|
||
| Block sizes also have an effect on the sequence length. Sequence length is multiple of resolution and number of frames (video), along with VAE scale down factors and patchifying ratios. This sequence length or shard of this sequence length needs to be multiple of the block sizes specified. Therefore maxdiffusion pads the sequence lengths to the nearest multiple of the block sizes. It is advisable to choose block sizes which are factor of sequence length, atleast for the Q block sizes. | ||
|
|
||
| > In cross attention Image or Video tokens are attending to text tokens sequence length of text tokens is really small and potentially smaller than specified block size so KV block sizes are overwritten to safe values. | ||
|
|
||
| > KV block sizes must be multiple of 128 since the size of register is 8x128 and in attention KV sequence dim lies on 128 for the multiplications as K is transposed. | ||
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit - mispelled recommended.