@@ -109,6 +109,8 @@ def set_timesteps(
109109 denoising_strength : float = 1.0 ,
110110 training : bool = False ,
111111 shift : Optional [float ] = None ,
112+ timesteps : Optional [jnp .ndarray ] = None ,
113+ sigmas : Optional [jnp .ndarray ] = None ,
112114 ) -> FlowMatchSchedulerState :
113115 """
114116 Sets the discrete timesteps used for the diffusion chain.
@@ -126,27 +128,42 @@ def set_timesteps(
126128 Whether the scheduler is being used for training.
127129 shift (`Optional[float]`):
128130 An optional shift value to override the one in the config.
131+ timesteps (`Optional[jnp.ndarray]`):
132+ Custom timesteps to use for the denoising process.
133+ sigmas (`Optional[jnp.ndarray]`):
134+ Custom sigmas to use for the denoising process.
129135
130136 Returns:
131137 `FlowMatchSchedulerState`: The updated scheduler state.
132138 """
133139 current_shift = shift if shift is not None else self .config .shift
134- sigma_start = self .config .sigma_min + (self .config .sigma_max - self .config .sigma_min ) * denoising_strength
135140
136- if self .config .extra_one_step :
137- sigmas = jnp .linspace (sigma_start , self .config .sigma_min , num_inference_steps + 1 , dtype = self .dtype )[:- 1 ]
141+ if timesteps is not None and sigmas is not None :
142+ pass
143+ elif timesteps is not None :
144+ sigmas = timesteps / self .config .num_train_timesteps
145+ elif sigmas is not None :
146+ timesteps = sigmas * self .config .num_train_timesteps
138147 else :
139- sigmas = jnp . linspace ( sigma_start , self .config .sigma_min , num_inference_steps , dtype = self .dtype )
148+ sigma_start = self . config . sigma_min + ( self .config .sigma_max - self .config . sigma_min ) * denoising_strength
140149
141- if self .config .inverse_timesteps :
142- sigmas = jnp .flip (sigmas , dims = [0 ])
150+ if self .config .extra_one_step :
151+ sigmas = jnp .linspace (sigma_start , self .config .sigma_min , num_inference_steps + 1 , dtype = self .dtype )[:- 1 ]
152+ else :
153+ sigmas = jnp .linspace (sigma_start , self .config .sigma_min , num_inference_steps , dtype = self .dtype )
143154
144- sigmas = current_shift * sigmas / (1 + (current_shift - 1 ) * sigmas )
155+ if self .config .inverse_timesteps :
156+ sigmas = jnp .flip (sigmas , dims = [0 ])
145157
146- if self .config .reverse_sigmas :
147- sigmas = 1 - sigmas
158+ sigmas = current_shift * sigmas / (1 + (current_shift - 1 ) * sigmas )
148159
149- timesteps = sigmas * self .config .num_train_timesteps
160+ if self .config .reverse_sigmas :
161+ sigmas = 1 - sigmas
162+
163+ timesteps = sigmas * self .config .num_train_timesteps
164+
165+ if timesteps is not None :
166+ num_inference_steps = len (timesteps )
150167
151168 linear_timesteps_weights = None
152169 if training :
0 commit comments