@@ -28,6 +28,7 @@ def rename_for_nnx(key):
2828 new_key = key [:- 1 ] + ("scale" ,)
2929 return new_key
3030
31+
3132def rename_for_custom_trasformer (key ):
3233 renamed_pt_key = key .replace ("model.diffusion_model." , "" )
3334
@@ -53,6 +54,7 @@ def rename_for_custom_trasformer(key):
5354
5455 return renamed_pt_key
5556
57+
5658def load_fusionx_transformer (pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True ):
5759 device = jax .devices (device )[0 ]
5860 with jax .default_device (device ):
@@ -74,7 +76,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
7476 random_flax_state_dict [string_tuple ] = flattened_dict [key ]
7577 for pt_key , tensor in tensors .items ():
7678 renamed_pt_key = rename_key (pt_key )
77-
79+
7880 renamed_pt_key = rename_for_custom_trasformer (renamed_pt_key )
7981
8082 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
@@ -89,6 +91,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
8991 jax .clear_caches ()
9092 return flax_state_dict
9193
94+
9295def load_causvid_transformer (pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True ):
9396 device = jax .devices (device )[0 ]
9497 with jax .default_device (device ):
0 commit comments