We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 5cc2e49; commit 5f2434d (copy full SHA: 5f2434d)
1 file changed
src/maxdiffusion/models/attention_flax.py
@@ -75,11 +75,16 @@ def _reshape_batch_dim_to_heads(tensor, heads):
75
return tensor
76
77
def _reshape_heads_to_batch_dim(tensor, heads):
78
- batch_size, seq_len, dim = tensor.shape
79
- head_size = heads
80
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
81
- tensor = jnp.transpose(tensor, (0, 2, 1, 3))
82
- tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+ if tensor.ndim == 3:
+ batch_size, seq_len, dim = tensor.shape
+ head_size = heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
83
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
84
+ else:
85
+ batch_size, head_size, seq_len, head_dim = tensor.shape
86
+ tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim)
87
+
88
89
90
def _reshape_heads_to_head_dim(tensor):
0 commit comments