
Commit b11767b

improve attention for GPUs by using shard_map.

1 parent 6f536d8

1 file changed: src/maxdiffusion/models/attention_flax.py

Lines changed: 15 additions & 1 deletion
@@ -180,7 +180,21 @@ def cudnn_flash_attention(
     key = nn.with_logical_constraint(key, axis_names)
     value = nn.with_logical_constraint(value, axis_names)
 
-    out = self.dpa_layer(query, key, value, mask=None)
+    @functools.partial(
+        shard_map.shard_map,
+        mesh=self.mesh,
+        in_specs=(
+            axis_names,
+            axis_names,
+            axis_names,
+        ),
+        out_specs=axis_names,
+        check_rep=False,
+    )
+    def wrap_flash_attention(query, key, value):
+      return jax.vmap(self.dpa_layer)(query, key, value, mask=None)
+
+    out = wrap_flash_attention(query, key, value)
     return self.reshape_data_from_cudnn_flash(out)
 
   def apply_attention_dot(self, query: Array, key: Array, value: Array):
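
The change wraps the cuDNN flash-attention call (self.dpa_layer) in shard_map, so the batch is partitioned across devices according to the mesh, and then uses jax.vmap to map the per-example kernel over each device's local batch shard. Below is a minimal, self-contained sketch of this shard_map + vmap pattern under illustrative assumptions: the toy dot-product attention function, the mesh axis name "data", and the tensor shapes are stand-ins, not the maxdiffusion implementation.

import functools

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils, shard_map
from jax.sharding import Mesh, PartitionSpec as P


def attention(query, key, value):
  # Plain dot-product attention for one example: (heads, seq, head_dim).
  # Stand-in for the cuDNN flash-attention layer in the commit.
  scores = jnp.einsum("hqd,hkd->hqk", query, key) / jnp.sqrt(query.shape[-1])
  return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(scores, axis=-1), value)


# One-dimensional mesh over all local devices; "data" is an assumed axis name.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("data",))
spec = P("data")  # shard the leading (batch) axis across the "data" axis


@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(spec, spec, spec),
    out_specs=spec,
    check_rep=False,
)
def wrap_attention(query, key, value):
  # Inside shard_map, arrays are per-device shards; vmap maps the
  # single-example kernel over the local batch dimension.
  return jax.vmap(attention)(query, key, value)


# Batch must be divisible by the device count for this sharding.
batch, heads, seq, dim = jax.device_count() * 2, 4, 128, 64
q = jnp.ones((batch, heads, seq, dim))
k = jnp.ones((batch, heads, seq, dim))
v = jnp.ones((batch, heads, seq, dim))
out = wrap_attention(q, k, v)
print(out.shape)  # (batch, heads, seq, dim)

check_rep=False disables shard_map's replication-consistency check, which is presumably needed here because the opaque attention call gives shard_map no way to verify how outputs are replicated.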
