Commit 060fecf
committed
Initial commit for the optimized gdn impl
refactor comments
replaced gdn to align w/ pytorch gdn
Add kl div test
Use seq_len instead of S
Fixed transpose error and removed testing functions
Keep naive gated delta rule impl for ref
refactor comments on GDN functions
Update tflops calc to align w/ megatron & update dtype for GDN
Fix spacing issues in linter
Run precommit for pylint fix
Convert lambda func to explicit function using def
Add megatron ref for gdn tflops
ran pyink formatter1 parent fa4e1e7 commit 060fecf
3 files changed
Lines changed: 239 additions & 170 deletions
File tree
- src
- MaxText/layers
- maxtext/utils
- tests/unit
0 commit comments