
Commit b7b7727

wip
1 parent 65ca402 commit b7b7727

1 file changed: 5 additions & 1 deletion

src/maxdiffusion/models/attention_flax.py
@@ -120,7 +120,11 @@ def _reshape_data_for_flash(tensor, heads):
   blocks is divisible by the number of shards.
   """
   if tensor.ndim != 4:
-    tensor = _unflatten_heads(tensor, heads, divisor=2)
+    tensor = _unflatten_heads(tensor, heads)
+  else:
+    b, h, s, d = tensor.shape
+    if d != 256:
+      tensor = tensor.reshape(b, h//2, s, d*2)
   return tensor
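A minimal sketch of what the new else branch does, under the assumption that the incoming tensor is already 4D with layout (batch, heads, seq_len, head_dim); the shapes and the jnp usage below are illustrative only and are not taken from the repo.

import jax.numpy as jnp

# Illustration only: made-up shapes. The else branch added in this commit
# handles tensors that already have 4 dimensions.
b, h, s, d = 2, 8, 128, 64   # head_dim != 256, so the reshape applies
x = jnp.zeros((b, h, s, d))

# Same operation as tensor.reshape(b, h//2, s, d*2) in the commit:
# halve the head count and double the last dimension, leaving the
# total number of elements unchanged.
y = x.reshape(b, h // 2, s, d * 2)
print(y.shape)   # (2, 4, 128, 128)

If head_dim is already 256, the commit leaves the tensor untouched, so only smaller head dimensions get folded this way before the flash-attention path.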