@@ -608,54 +608,54 @@ def test_mask_video_preprocessing_matches_diffusers(self):
608608
609609 def test_check_inputs_matches_diffusers_validation (self ):
610610 invalid_calls = [
611- dict (
612- prompt = "prompt" ,
613- negative_prompt = None ,
614- image = PIL .Image .new ("RGB" , (16 , 16 )),
615- pose_video = [PIL .Image .new ("RGB" , (16 , 16 ))],
616- face_video = [PIL .Image .new ("RGB" , (16 , 16 ))],
617- background_video = None ,
618- mask_video = None ,
619- height = 16 ,
620- width = 16 ,
621- prompt_embeds = jnp .zeros ((1 , 1 , 1 )),
622- negative_prompt_embeds = None ,
623- image_embeds = None ,
624- mode = "animate" ,
625- prev_segment_conditioning_frames = 1 ,
626- ) ,
627- dict (
628- prompt = "prompt" ,
629- negative_prompt = None ,
630- image = PIL .Image .new ("RGB" , (16 , 16 )),
631- pose_video = [PIL .Image .new ("RGB" , (16 , 16 ))],
632- face_video = [PIL .Image .new ("RGB" , (16 , 16 ))],
633- background_video = None ,
634- mask_video = None ,
635- height = 18 ,
636- width = 16 ,
637- prompt_embeds = None ,
638- negative_prompt_embeds = None ,
639- image_embeds = None ,
640- mode = "animate" ,
641- prev_segment_conditioning_frames = 1 ,
642- ) ,
643- dict (
644- prompt = "prompt" ,
645- negative_prompt = None ,
646- image = PIL .Image .new ("RGB" , (16 , 16 )),
647- pose_video = [PIL .Image .new ("RGB" , (16 , 16 ))],
648- face_video = [PIL .Image .new ("RGB" , (16 , 16 ))],
649- background_video = None ,
650- mask_video = None ,
651- height = 16 ,
652- width = 16 ,
653- prompt_embeds = None ,
654- negative_prompt_embeds = None ,
655- image_embeds = None ,
656- mode = "replace" ,
657- prev_segment_conditioning_frames = 3 ,
658- ) ,
611+ {
612+ " prompt" : "prompt" ,
613+ " negative_prompt" : None ,
614+ " image" : PIL .Image .new ("RGB" , (16 , 16 )),
615+ " pose_video" : [PIL .Image .new ("RGB" , (16 , 16 ))],
616+ " face_video" : [PIL .Image .new ("RGB" , (16 , 16 ))],
617+ " background_video" : None ,
618+ " mask_video" : None ,
619+ " height" : 16 ,
620+ " width" : 16 ,
621+ " prompt_embeds" : jnp .zeros ((1 , 1 , 1 )),
622+ " negative_prompt_embeds" : None ,
623+ " image_embeds" : None ,
624+ " mode" : "animate" ,
625+ " prev_segment_conditioning_frames" : 1 ,
626+ } ,
627+ {
628+ " prompt" : "prompt" ,
629+ " negative_prompt" : None ,
630+ " image" : PIL .Image .new ("RGB" , (16 , 16 )),
631+ " pose_video" : [PIL .Image .new ("RGB" , (16 , 16 ))],
632+ " face_video" : [PIL .Image .new ("RGB" , (16 , 16 ))],
633+ " background_video" : None ,
634+ " mask_video" : None ,
635+ " height" : 18 ,
636+ " width" : 16 ,
637+ " prompt_embeds" : None ,
638+ " negative_prompt_embeds" : None ,
639+ " image_embeds" : None ,
640+ " mode" : "animate" ,
641+ " prev_segment_conditioning_frames" : 1 ,
642+ } ,
643+ {
644+ " prompt" : "prompt" ,
645+ " negative_prompt" : None ,
646+ " image" : PIL .Image .new ("RGB" , (16 , 16 )),
647+ " pose_video" : [PIL .Image .new ("RGB" , (16 , 16 ))],
648+ " face_video" : [PIL .Image .new ("RGB" , (16 , 16 ))],
649+ " background_video" : None ,
650+ " mask_video" : None ,
651+ " height" : 16 ,
652+ " width" : 16 ,
653+ " prompt_embeds" : None ,
654+ " negative_prompt_embeds" : None ,
655+ " image_embeds" : None ,
656+ " mode" : "replace" ,
657+ " prev_segment_conditioning_frames" : 3 ,
658+ } ,
659659 ]
660660
661661 for kwargs in invalid_calls :
@@ -780,7 +780,7 @@ def _scalar(x):
780780 hf_negative = torch .tensor (to_numpy (max_negative ))
781781 hf_image = torch .tensor (to_numpy (max_image ))
782782
783- scheduler_config = dict ( prediction_type = " flow_prediction" , use_flow_sigmas = True , flow_shift = 5.0 )
783+ scheduler_config = { " prediction_type" : " flow_prediction" , " use_flow_sigmas" : True , " flow_shift" : 5.0 }
784784 max_scheduler = FlaxUniPCMultistepScheduler (** scheduler_config )
785785 max_state = max_scheduler .create_state ()
786786 max_state = max_scheduler .set_timesteps (max_state , num_inference_steps = timestep_count , shape = max_latents .shape )
@@ -852,7 +852,7 @@ def _scalar(x):
852852 np .testing .assert_allclose (to_numpy (max_next ), hf_channel_first_to_last (hf_next ), atol = 1e-5 , rtol = 1e-5 )
853853
854854 def test_flax_unipc_flow_sigmas_match_diffusers (self ):
855- scheduler_config = dict ( prediction_type = " flow_prediction" , use_flow_sigmas = True , flow_shift = 5.0 )
855+ scheduler_config = { " prediction_type" : " flow_prediction" , " use_flow_sigmas" : True , " flow_shift" : 5.0 }
856856
857857 max_scheduler = FlaxUniPCMultistepScheduler (** scheduler_config )
858858 max_state = max_scheduler .create_state ()
0 commit comments