Skip to content

Commit 303c4f2

Browse files
committed
fix
1 parent 64cf8a4 commit 303c4f2

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

reproduce_key_mapping.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,14 @@
4848
print(f"Tuple Key: {pt_tuple_key}")
4949

5050
# 3. get_key_and_value
51-
# We need dummy tensor
52-
dummy_tensor = torch.zeros((10, 10))
51+
# We need dummy tensor (JAX array)
52+
import jax.numpy as jnp
53+
dummy_tensor = jnp.zeros((128, 128), dtype=jnp.float32)
5354
flax_state_dict = {} # Mock
5455

5556
# Need to simulate scan_layers=True
5657
scan_layers = True
57-
num_layers = 48
58+
num_layers = 1 # Use 1 for debug
5859

5960
flax_key, flax_tensor = get_key_and_value(
6061
pt_tuple_key, dummy_tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers

0 commit comments

Comments
 (0)