We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b5526e7 commit d8d81d0Copy full SHA for d8d81d0
1 file changed
check_learnable_registers.py
@@ -0,0 +1,18 @@
1
+import os
2
+os.environ["JAX_PLATFORMS"] = "cpu"
3
+import jax
4
+from flax import nnx
5
+from flax.traverse_util import flatten_dict
6
+from maxdiffusion.models.ltx2.text_encoders.text_encoders_ltx2 import LTX2AudioVideoGemmaTextEncoder
7
+
8
+def main():
9
+ rngs = nnx.Rngs(0)
10
+ encoder = LTX2AudioVideoGemmaTextEncoder(rngs=rngs)
11
+ _, state = nnx.split(encoder)
12
+ flat_state = flatten_dict(state)
13
+ for k in flat_state.keys():
14
+ if "learnable_registers" in k:
15
+ print(k)
16
17
+if __name__ == "__main__":
18
+ main()
0 commit comments