Skip to content

Commit a581e00

Browse files
Merge pull request #2812 from CIeNET-International:fix/Gpt3-projection-dimension
PiperOrigin-RevId: 843367659
2 parents fdcb3c9 + ebef575 commit a581e00

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/MaxText/layers/gpt3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def __init__(
246246
self.key = self.create_projection_layer(feature_dim, (self.num_heads, self.head_dim), ("embed", "heads", "kv"))
247247
self.value = self.create_projection_layer(feature_dim, (self.num_heads, self.head_dim), ("embed", "heads", "kv"))
248248
self.out = self.create_projection_layer(
249-
(self.num_heads, self.head_dim), self.num_heads * self.head_dim, ("heads", "kv", "embed"), axis=(-2, -1)
249+
(self.num_heads, self.head_dim), feature_dim[-1], ("heads", "kv", "embed"), axis=(-2, -1)
250250
)
251251
self.attention_op = AttentionOp(
252252
config=config,

0 commit comments

Comments
 (0)