@@ -114,19 +114,20 @@ def _load_kwargs(self, argv: list[str]):
114114
115115 @staticmethod
116116 def wan_init (raw_keys ):
117- transformer_pretrained_model_name_or_path = raw_keys ["transformer_pretrained_model_name_or_path" ]
118- if transformer_pretrained_model_name_or_path == "" :
119- raw_keys ["transformer_pretrained_model_name_or_path" ] = raw_keys ["pretrained_model_name_or_path" ]
120- elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH :
121- # Set correct parameters for CausVid in case of user error.
122- raw_keys ["guidance_scale" ] = 1.0
123- num_inference_steps = raw_keys ["num_inference_steps" ]
124- if num_inference_steps > 10 :
125- max_logging .log (
126- f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting { num_inference_steps } steps."
127- )
128- else :
129- raise ValueError (f"{ transformer_pretrained_model_name_or_path } transformer model is not supported for Wan 2.1" )
117+ if "wan_transformer_pretrained_model_name_or_path" in raw_keys :
118+ transformer_pretrained_model_name_or_path = raw_keys ["wan_transformer_pretrained_model_name_or_path" ]
119+ if transformer_pretrained_model_name_or_path == "" :
120+ raw_keys ["wan_transformer_pretrained_model_name_or_path" ] = raw_keys ["pretrained_model_name_or_path" ]
121+ elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH :
122+ # Set correct parameters for CausVid in case of user error.
123+ raw_keys ["guidance_scale" ] = 1.0
124+ num_inference_steps = raw_keys ["num_inference_steps" ]
125+ if num_inference_steps > 10 :
126+ max_logging .log (
127+ f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting { num_inference_steps } steps."
128+ )
129+ else :
130+ raise ValueError (f"{ transformer_pretrained_model_name_or_path } transformer model is not supported for Wan 2.1" )
130131
131132 @staticmethod
132133 def user_init (raw_keys ):
0 commit comments