Skip to content

Commit a98bee6

Browse files
committed
check learnable registers
1 parent d8d81d0 commit a98bee6

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

check_learnable_registers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ def main():
99
rngs = nnx.Rngs(0)
1010
encoder = LTX2AudioVideoGemmaTextEncoder(rngs=rngs)
1111
_, state = nnx.split(encoder)
12-
flat_state = flatten_dict(state)
12+
13+
# Convert nnx State to dict
14+
flat_state = flatten_dict(state.to_pure_dict())
15+
1316
for k in flat_state.keys():
1417
if "learnable_registers" in k:
1518
print(k)

0 commit comments

Comments
 (0)