Skip to content

Commit ebef575

Browse files
Fix Gpt3 MultiHeadAttention out projection dimension
1 parent 9e16c99 commit ebef575

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/MaxText/layers/gpt3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def __init__(
246246
self.key = self.create_projection_layer(feature_dim, (self.num_heads, self.head_dim), ("embed", "heads", "kv"))
247247
self.value = self.create_projection_layer(feature_dim, (self.num_heads, self.head_dim), ("embed", "heads", "kv"))
248248
self.out = self.create_projection_layer(
249-
(self.num_heads, self.head_dim), self.num_heads * self.head_dim, ("heads", "kv", "embed"), axis=(-2, -1)
249+
(self.num_heads, self.head_dim), feature_dim[-1], ("heads", "kv", "embed"), axis=(-2, -1)
250250
)
251251
self.attention_op = AttentionOp(
252252
config=config,

0 commit comments

Comments (0)