Commit 65ca402

wip head dim 256
Parent: 4e362c3

1 file changed: 3 additions & 3 deletions

src/maxdiffusion/models/attention_flax.py

@@ -104,10 +104,10 @@ def _reshape_heads_to_head_dim(tensor):
   return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))


-def _unflatten_heads(tensor, heads):
+def _unflatten_heads(tensor, heads, divisor=1):
   # reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format)
   batch, seq, heads_and_dim_head = tensor.shape
-  tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads)
+  tensor = tensor.reshape(batch, seq, heads // divisor, divisor * heads_and_dim_head // heads)
   # Transpose to ('batch', 'heads', 'length', 'kv')
   tensor = jnp.transpose(tensor, (0, 2, 1, 3))
   return tensor
@@ -120,7 +120,7 @@ def _reshape_data_for_flash(tensor, heads):
   blocks is divisible by the number of shards.
   """
   if tensor.ndim != 4:
-    tensor = _unflatten_heads(tensor, heads)
+    tensor = _unflatten_heads(tensor, heads, divisor=2)
   return tensor
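
For reference, a minimal runnable sketch of the reshape introduced above (the shapes in the example are illustrative assumptions, not taken from the commit). With divisor=2, groups of two adjacent heads are folded into a single head of twice the head dimension: the flattened [b, s, h * d] input layout is unchanged, but the flash path sees half as many heads at head dim 256.

import jax.numpy as jnp

def _unflatten_heads(tensor, heads, divisor=1):
  # Reshape [b, s, h * d] -> [b, h // divisor, s, divisor * d]; divisor
  # folds groups of `divisor` adjacent heads into one wider head.
  batch, seq, heads_and_dim_head = tensor.shape
  tensor = tensor.reshape(batch, seq, heads // divisor, divisor * heads_and_dim_head // heads)
  # Transpose to ('batch', 'heads', 'length', 'kv')
  return jnp.transpose(tensor, (0, 2, 1, 3))

# Hypothetical shapes: batch 2, sequence 1024, 16 heads of dim 128.
x = jnp.zeros((2, 1024, 16 * 128))
print(_unflatten_heads(x, 16).shape)             # (2, 16, 1024, 128)
print(_unflatten_heads(x, 16, divisor=2).shape)  # (2, 8, 1024, 256)

The total element count is preserved; only the head grouping changes. The call site hardcodes divisor=2 rather than deriving it, consistent with the work-in-progress ("wip") nature of the commit.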