Skip to content

Commit d8d81d0

Browse files
committed
check learnable registers
1 parent b5526e7 commit d8d81d0

1 file changed

Lines changed: 18 additions & 0 deletions

File tree

check_learnable_registers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import os
2+
os.environ["JAX_PLATFORMS"] = "cpu"
3+
import jax
4+
from flax import nnx
5+
from flax.traverse_util import flatten_dict
6+
from maxdiffusion.models.ltx2.text_encoders.text_encoders_ltx2 import LTX2AudioVideoGemmaTextEncoder
7+
8+
def main():
9+
rngs = nnx.Rngs(0)
10+
encoder = LTX2AudioVideoGemmaTextEncoder(rngs=rngs)
11+
_, state = nnx.split(encoder)
12+
flat_state = flatten_dict(state)
13+
for k in flat_state.keys():
14+
if "learnable_registers" in k:
15+
print(k)
16+
17+
if __name__ == "__main__":
18+
main()

0 commit comments

Comments
 (0)