Skip to content

Commit 5f2434d

Browse files
fix attention bug for lower frames.
1 parent 5cc2e49 commit 5f2434d

1 file changed

Lines changed: 10 additions & 5 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -75,11 +75,16 @@ def _reshape_batch_dim_to_heads(tensor, heads):
75 75
return tensor
76 76

77 77
def _reshape_heads_to_batch_dim(tensor, heads):
78-
batch_size, seq_len, dim = tensor.shape
79-
head_size = heads
80-
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
81-
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
82-
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
78+
if tensor.ndim == 3:
79+
batch_size, seq_len, dim = tensor.shape
80+
head_size = heads
81+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
82+
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
83+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
84+
else:
85+
batch_size, head_size, seq_len, head_dim = tensor.shape
86+
tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim)
87+
83 88
return tensor
84 89

85 90
def _reshape_heads_to_head_dim(tensor):

0 commit comments

Comments (0)