@@ -1377,8 +1377,10 @@ def _sched_cfg_get(key: str, default):
13771377
13781378 do_cfg = guidance_scale > 1.0
13791379 do_stg = stg_scale > 0.0
1380+ print (f"DEBUG: do_cfg={ do_cfg } , do_stg={ do_stg } , guidance_scale={ guidance_scale } , stg_scale={ stg_scale } " )
13801381
13811382 if do_cfg and do_stg :
1383+ print ("DEBUG: Pipeline: Branching into do_cfg AND do_stg" )
13821384 negative_prompt_embeds_jax = negative_prompt_embeds
13831385 negative_prompt_attention_mask_jax = negative_prompt_attention_mask
13841386
@@ -1395,6 +1397,7 @@ def _sched_cfg_get(key: str, default):
13951397 perturbation_mask = jnp .concatenate ([jnp .ones ((2 * N , 1 , 1 ), dtype = dtype ), jnp .zeros ((N , 1 , 1 ), dtype = dtype ), jnp .ones ((N , 1 , 1 ), dtype = dtype )], axis = 0 )
13961398
13971399 elif do_cfg :
1400+ print ("DEBUG: Pipeline: Branching into do_cfg only" )
13981401 negative_prompt_embeds_jax = negative_prompt_embeds
13991402 negative_prompt_attention_mask_jax = negative_prompt_attention_mask
14001403 if isinstance (prompt_embeds_jax , list ):
@@ -1408,6 +1411,7 @@ def _sched_cfg_get(key: str, default):
14081411 perturbation_mask = None
14091412
14101413 elif do_stg :
1414+ print ("DEBUG: Pipeline: Branching into do_stg only" )
14111415 if isinstance (prompt_embeds_jax , list ):
14121416 prompt_embeds_jax = [jnp .concatenate ([p , p ], axis = 0 ) for p in prompt_embeds_jax ]
14131417 else :
@@ -1420,6 +1424,7 @@ def _sched_cfg_get(key: str, default):
14201424 N = latents .shape [0 ]
14211425 perturbation_mask = jnp .concatenate ([jnp .ones ((N , 1 , 1 ), dtype = dtype ), jnp .zeros ((N , 1 , 1 ), dtype = dtype )], axis = 0 )
14221426 else :
1427+ print ("DEBUG: Pipeline: No guidance branch (Standard path)" )
14231428 perturbation_mask = None
14241429
14251430 if hasattr (self , "mesh" ) and self .mesh is not None :
0 commit comments