Skip to content

Commit 70beed7

Browse files
authored
Changing the return type
1 parent d8cdbd5 commit 70beed7

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, m
395395
check_rep=False,
396396
)
397397
def wrap_flash_attention(query, key, value):
398-
return jax.vmap(dpa_layer)(query, key, value, mask=None)
398+
return dpa_layer(query, key, value, mask=None)
399399

400400
out = wrap_flash_attention(query, key, value)
401401
return _reshape_data_from_cudnn_flash(out)

0 commit comments

Comments
 (0)