
Commit 82edc64

add flash attention

1 parent b84fc34 commit 82edc64
1 file changed: src/maxdiffusion/models/attention_flax.py
Lines changed: 76 additions & 2 deletions
@@ -24,6 +24,8 @@
 from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
+from jax.experimental.pallas.ops.tpu import flash_attention
+from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention
 from einops import rearrange
 from .. import common_types, max_logging
 
@@ -103,7 +105,7 @@ def _unflatten_heads(tensor, heads):
   tensor = jnp.transpose(tensor, (0, 2, 1, 3))
   return tensor
 
-def _reshape_data_for_flash(tensor, heads, flash_block_size):
+def _reshape_data_for_flash(tensor, heads, flash_block_size, pad=True):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
   """
@@ -127,12 +129,28 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size):
   # pad to the closest multiplier of flash_block_size
   seq_len_pad = (mul + 1) * flash_block_size - seq_len
 
-  if kv_size < 128 or rem != 0:
+  if kv_size < 128 or rem != 0 and pad:
     npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
     tensor = jnp.pad(tensor, npad)
 
   return tensor, kv_size, seq_len
 
+def default_block_sizes(query: jax.Array, key: jax.Array, dtype: jnp.dtype):
+  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  return flash_attention.BlockSizes(
+      block_q=min(max_block_size, query.shape[-2]),
+      block_k_major=min(max_block_size, key.shape[-2]),
+      block_k=min(max_block_size, key.shape[-2]),
+      block_b=min(1, query.shape[0]),
+      block_q_major_dkv=min(max_block_size, query.shape[-2]),
+      block_k_major_dkv=min(max_block_size, key.shape[-2]),
+      block_q_dkv=min(max_block_size, query.shape[-2]),
+      block_k_dkv=min(max_block_size, key.shape[-2]),
+      block_q_dq=min(max_block_size, query.shape[-2]),
+      block_k_dq=min(512, key.shape[-2]),
+      block_k_major_dq=min(max_block_size, key.shape[-2]),
+  )
+
 def _tpu_flash_attention(
     query: jax.Array,
     key: jax.Array,
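
A note on the padding arithmetic in this hunk: mul and rem are presumably the quotient and remainder of seq_len divided by flash_block_size, computed earlier in the function (outside this hunk). A standalone sketch of the same computation, with made-up numbers:

    # Hypothetical numbers, re-deriving the pad computation shown above;
    # mul/rem are assumed to come from divmod(seq_len, flash_block_size).
    flash_block_size = 512
    seq_len = 600
    mul, rem = divmod(seq_len, flash_block_size)          # mul=1, rem=88
    seq_len_pad = (mul + 1) * flash_block_size - seq_len  # 1024 - 600 = 424
    # The padded length lands exactly on the next block boundary:
    assert seq_len + seq_len_pad == 2 * flash_block_size

When rem is already 0 (and kv_size >= 128), the guard above skips jnp.pad entirely, so no spurious extra block is added.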
@@ -144,6 +162,59 @@
     dtype: jnp.dtype = jnp.float32) -> jax.Array:
   """TPU Flash Attention"""
 
+  if flash_block_sizes is None:
+    block_sizes = default_block_sizes(query, key, dtype)
+
+  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, pad=True)
+  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_k, pad=True)
+  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_k, pad=True)
+
+  axis_names = nn.logical_to_mesh_axes(flash_axis_names)
+  @functools.partial(
+      shard_map.shard_map,
+      mesh=mesh,
+      in_specs=(
+          axis_names,
+          axis_names,
+          axis_names,
+      ),
+      out_specs=axis_names,
+      check_rep=False,
+  )
+  def wrap_flash_attention(query, key, value):
+    output = jax_flash_attention(
+        query,
+        key,
+        value,
+        block_sizes=block_sizes
+    )
+    return output
+
+  devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
+  # This warning might show up when doing model eval for example, when calculating model flops
+  # and that is expected.
+  if not (query.shape[0] / devices_in_data_fsdp).is_integer():
+    max_logging.log(
+        "Warning, batch dimension should be shardable among the devices in data and fsdp"
+        f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
+    )
+  x = wrap_flash_attention(query, key, value)
+  x = x[:, :, :query_seq_len, :kv_size]
+  x = _reshape_heads_to_head_dim(x)
+  return x
+
+
+def _tpu_splash_attention(
+    query: jax.Array,
+    key: jax.Array,
+    value: jax.Array,
+    heads: int,
+    mesh: Mesh,
+    flash_axis_names: AxisNames,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32) -> jax.Array:
+  """TPU Flash Attention"""
+
   max_block_size = 1024 if dtype == jnp.bfloat16 else 512
   if flash_block_sizes:
     block_sizes = flash_block_sizes
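
For orientation, a sketch of what the new default path yields when flash_block_sizes is None. The shapes below are hypothetical; default_block_sizes only reads shape[-2], shape[0], and the dtype:

    import jax.numpy as jnp

    # Hypothetical pre-reshape activations; shape[-2] is the sequence axis.
    q = jnp.zeros((4, 4096, 2048), dtype=jnp.bfloat16)
    k = jnp.zeros((4, 77, 2048), dtype=jnp.bfloat16)  # e.g. short cross-attention keys

    sizes = default_block_sizes(q, k, jnp.bfloat16)
    # bfloat16 caps blocks at 1024 (float32 caps at 512), clamped to the
    # actual sequence lengths:
    #   sizes.block_q == min(1024, 4096) == 1024
    #   sizes.block_k == min(1024, 77)   == 77
    #   sizes.block_b == min(1, 4)       == 1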
@@ -182,6 +253,7 @@ def wrap_flash_attention(query, key, value):
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes
     )
+    #return splash_kernel(query, key, value)
     return jax.vmap(splash_kernel)(query, key, value)
 
   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
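
The commented-out direct call documents the shape contract here: the kernel returned by make_splash_mha appears to operate on one example at a time (assuming the usual (heads, seq_len, head_dim) per-example layout), so jax.vmap maps it across the leading batch axis. A toy stand-in showing the same vmap pattern, not the real splash kernel:

    import jax
    import jax.numpy as jnp

    # Stand-in per-example attention: (h, s, d) x 3 -> (h, s, d).
    def per_example_kernel(q, k, v):
      weights = jax.nn.softmax(jnp.einsum("hqd,hkd->hqk", q, k), axis=-1)
      return jnp.einsum("hqk,hkd->hqd", weights, v)

    q = k = v = jnp.ones((2, 8, 16, 4))          # (batch, heads, seq, head_dim)
    out = jax.vmap(per_example_kernel)(q, k, v)  # maps over the batch axis
    assert out.shape == (2, 8, 16, 4)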
@@ -344,6 +416,8 @@ def _apply_attention(
     return _apply_attention_dot(query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention)
   elif attention_kernel == "flash":
     return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype)
+  elif attention_kernel == "splash":
+    return _tpu_splash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype)
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else:
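
One detail at the dispatch site: both TPU branches receive key * scale rather than applying the softmax scale inside the kernel. By linearity of the dot product, scaling the keys first is equivalent to scaling the QK^T logits afterwards; a quick numeric check, illustrative only:

    import jax.numpy as jnp

    q = jnp.arange(8.0).reshape(2, 4)
    k = jnp.arange(12.0).reshape(3, 4)
    scale = 0.125

    # Folding the scale into k equals scaling the logits after the product.
    assert jnp.allclose(q @ (k * scale).T, (q @ k.T) * scale)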
