@@ -83,9 +83,20 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
8383
8484 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
8585
86+ if "blocks" in pt_tuple_key :
87+ new_key = ("blocks" ,) + pt_tuple_key [2 :]
88+ block_index = int (pt_tuple_key [1 ])
89+ pt_tuple_key = new_key
8690 flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict , model_type = WAN_MODEL )
8791 flax_key = rename_for_nnx (flax_key )
8892 flax_key = _tuple_str_to_int (flax_key )
93+
94+ if "blocks" in flax_key :
95+ if flax_key in flax_state_dict :
96+ new_tensor = flax_state_dict [flax_key ]
97+ else :
98+ new_tensor = jnp .zeros ((40 ,) + flax_tensor .shape )
99+ flax_tensor = new_tensor .at [block_index ].set (flax_tensor )
89100 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
90101 validate_flax_state_dict (eval_shapes , flax_state_dict )
91102 flax_state_dict = unflatten_dict (flax_state_dict )
@@ -118,9 +129,21 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
118129
119130 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
120131
132+ if "blocks" in pt_tuple_key :
133+ new_key = ("blocks" ,) + pt_tuple_key [2 :]
134+ block_index = int (pt_tuple_key [1 ])
135+ pt_tuple_key = new_key
121136 flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict , model_type = WAN_MODEL )
122137 flax_key = rename_for_nnx (flax_key )
123138 flax_key = _tuple_str_to_int (flax_key )
139+
140+
141+ if "blocks" in flax_key :
142+ if flax_key in flax_state_dict :
143+ new_tensor = flax_state_dict [flax_key ]
144+ else :
145+ new_tensor = jnp .zeros ((40 ,) + flax_tensor .shape )
146+ flax_tensor = new_tensor .at [block_index ].set (flax_tensor )
124147 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
125148 validate_flax_state_dict (eval_shapes , flax_state_dict )
126149 flax_state_dict = unflatten_dict (flax_state_dict )
0 commit comments