From 63b318b30b79d60f77a885f6b5ab9fa941a63ef8 Mon Sep 17 00:00:00 2001 From: Quinn Date: Tue, 13 May 2025 22:13:35 +0000 Subject: [PATCH 1/4] Add the unipc multistep scheduler. --- .../scheduling_unipc_multistep_flax.py | 817 ++++++++++++++++++ 1 file changed, 817 insertions(+) create mode 100644 src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py diff --git a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py new file mode 100644 index 000000000..f0733ca3f --- /dev/null +++ b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py @@ -0,0 +1,817 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: reference pytorch implementation: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_unipc_multistep.py + +from typing import List, Optional, Tuple, Union +from dataclasses import dataclass + +import flax +import jax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import is_scipy_available +from .scheduling_utils_flax import ( + CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) + + +@flax.struct.dataclass +class UniPCMultistepSchedulerState: + """ + Data class to hold the mutable state of the FlaxUniPCMultistepScheduler. 
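+
+    A minimal usage sketch (argument values are illustrative):
+
+        scheduler = FlaxUniPCMultistepScheduler()
+        state = scheduler.create_state()
+        state = scheduler.set_timesteps(
+            state, num_inference_steps=25, shape=(1, 4, 64, 64)
+        )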
+    """
+
+    common: CommonSchedulerState
+
+    # Core schedule parameters (derived from CommonSchedulerState in create_state)
+    sigmas: jnp.ndarray
+    alpha_t: jnp.ndarray
+    sigma_t: jnp.ndarray
+    lambda_t: jnp.ndarray
+    init_noise_sigma: jnp.ndarray
+
+    # History buffers for the multi-step solver
+    # `model_outputs` stores previous converted model outputs (e.g., predicted x0 or epsilon)
+    timesteps: Optional[jnp.ndarray] = None
+    model_outputs: Optional[jnp.ndarray] = None
+    timestep_list: Optional[jnp.ndarray] = (
+        None  # Stores corresponding timesteps for `model_outputs`
+    )
+
+    # State variables for tracking progress and solver order
+    lower_order_nums: int = 0
+    last_sample: Optional[jnp.ndarray] = None  # Sample fed to the previous predictor step
+    step_index: Optional[int] = None
+    begin_index: Optional[int] = None  # Used for img2img/inpainting
+    this_order: int = 0  # Current effective order of the UniPC solver for this step
+
+    @classmethod
+    def create(
+        cls,
+        common_state: CommonSchedulerState,
+        alpha_t: jnp.ndarray,
+        sigma_t: jnp.ndarray,
+        lambda_t: jnp.ndarray,
+        sigmas: jnp.ndarray,
+        init_noise_sigma: jnp.ndarray,
+    ):
+        return cls(
+            common=common_state,
+            alpha_t=alpha_t,
+            sigma_t=sigma_t,
+            lambda_t=lambda_t,
+            sigmas=sigmas,
+            init_noise_sigma=init_noise_sigma,
+            lower_order_nums=0,
+            last_sample=None,
+            step_index=None,
+            begin_index=None,
+            this_order=0,
+        )
+
+
+@dataclass
+class FlaxUniPCMultistepSchedulerOutput(FlaxSchedulerOutput):
+    state: UniPCMultistepSchedulerState
+
+
+class FlaxUniPCMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
+    """
+    `FlaxUniPCMultistepScheduler` is a training-free JAX/Flax scheduler for fast sampling of diffusion models.
+    It implements the UniPC (Unified Predictor-Corrector) algorithm.
+    """
+
+    dtype: jnp.dtype
+
+    @property
+    def has_state(self) -> bool:
+        return True
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[Union[jnp.ndarray, List[float]]] = None,
+        solver_order: int = 2,
+        prediction_type: str = "epsilon",
+        thresholding: bool = False,
+        dynamic_thresholding_ratio: float = 0.995,
+        sample_max_value: float = 1.0,
+        predict_x0: bool = True,
+        solver_type: str = "bh2",
+        lower_order_final: bool = True,
+        disable_corrector: List[int] = [],
+        solver_p: Optional[FlaxSchedulerMixin] = None,
+        use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
+        timestep_spacing: str = "linspace",
+        steps_offset: int = 0,
+        final_sigmas_type: Optional[str] = "zero",
+        rescale_zero_terminal_snr: bool = False,
+        dtype: jnp.dtype = jnp.float32,
+    ):
+        self.dtype = dtype
+
+        # Validation checks from the reference __init__
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError(
+                "Make sure to install scipy if you want to use beta sigmas."
+            )
+        if (
+            sum(
+                [
+                    self.config.use_beta_sigmas,
+                    self.config.use_exponential_sigmas,
+                    self.config.use_karras_sigmas,
+                ]
+            )
+            > 1
+        ):
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
+        if self.config.solver_type not in ["bh1", "bh2"]:
+            raise NotImplementedError(
+                f"{self.config.solver_type} is not implemented for {self.__class__}"
+            )
+
+    def create_state(
+        self, common: Optional[CommonSchedulerState] = None
+    ) -> UniPCMultistepSchedulerState:
+        if common is None:
+            common = CommonSchedulerState.create(self)
+
+        # Currently we only support VP-type noise schedules
+        alpha_t = jnp.sqrt(common.alphas_cumprod)
+        sigma_t = jnp.sqrt(1 - common.alphas_cumprod)
+        lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t)
+        sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5
+
+        # standard deviation of the initial noise distribution
+        init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
+
+        if self.config.solver_type not in ["bh1", "bh2"]:
+            if self.config.solver_type in ["midpoint", "heun", "logrho"]:
+                self.config.solver_type = "bh2"
+            else:
+                raise NotImplementedError(
+                    f"{self.config.solver_type} is not implemented for {self.__class__}"
+                )
+
+        return UniPCMultistepSchedulerState.create(
+            common_state=common,
+            alpha_t=alpha_t,
+            sigma_t=sigma_t,
+            lambda_t=lambda_t,
+            sigmas=sigmas,
+            init_noise_sigma=init_noise_sigma,
+        )
+
+    def set_begin_index(
+        self, state: UniPCMultistepSchedulerState, begin_index: int = 0
+    ) -> UniPCMultistepSchedulerState:
+        """
+        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
+        """
+        return state.replace(begin_index=begin_index)
+
+    def set_timesteps(
+        self,
+        state: UniPCMultistepSchedulerState,
+        num_inference_steps: int,
+        shape: Tuple,
+    ) -> UniPCMultistepSchedulerState:
+        """
+        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+        """
+        #### Copied from scheduling_dpmsolver_multistep_flax
+        last_timestep = self.config.num_train_timesteps
+        if self.config.timestep_spacing == "linspace":
+            timesteps = (
+                jnp.linspace(0, last_timestep - 1, num_inference_steps + 1)
+                .round()[::-1][:-1]
+                .astype(jnp.int32)
+            )
+        elif self.config.timestep_spacing == "leading":
+            step_ratio = last_timestep // (num_inference_steps + 1)
+            # creates integer timesteps by multiplying by the ratio
+            # casting to int to avoid issues when num_inference_steps is a power of 3
+            timesteps = (
+                (jnp.arange(0, num_inference_steps + 1) * step_ratio)
+                .round()[::-1][:-1]
+                .copy()
+                .astype(jnp.int32)
+            )
+            timesteps += self.config.steps_offset
+        elif self.config.timestep_spacing == "trailing":
+            step_ratio = self.config.num_train_timesteps / num_inference_steps
+            # creates integer timesteps by multiplying by the ratio
+            # casting to int to avoid issues when num_inference_steps is a power of 3
+            timesteps = (
+                jnp.arange(last_timestep, 0, -step_ratio)
+                .round()
+                .copy()
+                .astype(jnp.int32)
+            )
+            timesteps -= 1
+        else:
+            raise ValueError(
+                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ ) + + # initial running values + sigmas = state.sigmas + + # TODO + # # Apply Karras/Exponential/Beta/Flow Sigmas if configured + # if self.config.use_karras_sigmas: + # sigmas = _convert_to_karras_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + # elif self.config.use_exponential_sigmas: + # sigmas = _convert_to_exponential_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + # elif self.config.use_beta_sigmas: + # sigmas = _convert_to_beta_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + if self.config.use_flow_sigmas: + alphas = jnp.linspace( + 1, 1 / self.config.num_train_timesteps, num_inference_steps + 1 + ) + sigmas = 1.0 - alphas + sigmas = jnp.flip( + self.config.flow_shift + * sigmas + / (1 + (self.config.flow_shift - 1) * sigmas) + )[:-1].copy() + timesteps = ( + (sigmas * self.config.num_train_timesteps).copy().astype(jnp.int64) + ) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype( + jnp.float32 + ) + else: # Default case if none of the specialized sigmas are used + sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ( + (1 - state.common.alphas_cumprod[0]) + / state.common.alphas_cumprod[0] + ) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype( + jnp.float32 + ) + + model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype) + timestep_list = jnp.zeros( + (self.config.solver_order,), dtype=jnp.int32 # Timesteps are integers + ) + # Update the state with the new schedule and re-initialized history + return state.replace( + timesteps=timesteps, + sigmas=sigmas, + model_outputs=model_outputs, + timestep_list=timestep_list, + lower_order_nums=0, # Reset counters for a new inference run + step_index=None, + begin_index=None, + last_sample=None, + this_order=0, + ) + + def convert_model_output( + self, + state: UniPCMultistepSchedulerState, + model_output: jnp.ndarray, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + Converts the model output based on the prediction type and current state. + """ + sigma = state.sigmas[state.step_index] # Current sigma + + # Ensure sigma is a JAX array for _sigma_to_alpha_sigma_t + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.config.predict_x0: + if self.config.prediction_type == "epsilon": + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + x0_pred = alpha_t * sample - sigma_t * model_output + elif self.config.prediction_type == "flow_prediction": + # Original code has `sigma_t = self.sigmas[self.step_index]`. 
+ # This implies current sigma `sigma` is used as sigma_t for flow. + x0_pred = sample - sigma * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + raise ValueError(f"Dynamic thresholding isn't implemented.") + # x0_pred = self._threshold_sample(x0_pred) + return x0_pred + else: # self.config.predict_x0 is False + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + def multistep_uni_p_bh_update( + self, + state: UniPCMultistepSchedulerState, + model_output: jnp.ndarray, # Original model output from the diffusion model + sample: jnp.ndarray, + order: int, + ) -> jnp.ndarray: + """ + One step for the UniP (B(h) version) - the Predictor. + """ + if self.config.solver_p: + raise NotImplementedError( + "Nested `solver_p` is not implemented in JAX version yet." + ) + + m0 = state.model_outputs[ + self.config.solver_order - 1 + ] # Most recent stored converted model output + x = sample + + sigma_t_val, sigma_s0_val = ( + state.sigmas[state.step_index + 1], + state.sigmas[state.step_index], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val) + + lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10) + lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10) + + h = lambda_t - lambda_s0 + + rks_list = [] + D1s_list = [] + + for i in range(1, order): + history_idx = ( + self.config.solver_order - 1 - i + ) # Correct index for history array + + mi = state.model_outputs[history_idx] + si_val = state.timestep_list[ + history_idx + ] # This is the actual timestep value + + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t( + state.sigmas[self.index_for_timestep(state, si_val)] + ) + lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10) + + rk = (lambda_si - lambda_s0) / h + rks_list.append(rk) + D1s_list.append((mi - m0) / rk) + + rks_list.append(1.0) # Append the last 1.0 for r_order + rks = jnp.stack(rks_list) # Shape (order,) + + R_list = [] + b_list = [] + + hh = -h if self.config.predict_x0 else h + h_phi_1 = jnp.expm1(hh) + + current_h_phi_k = h_phi_1 / hh - 1.0 + factorial_val = 1.0 # factorial(1) is 1. 
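+        # Coefficient recurrence (a sketch of the math): with phi_1(h) = (e^h - 1)/h
+        # and phi_{k+1}(h) = (phi_k(h) - 1/k!) / h, `current_h_phi_k` tracks
+        # h * phi_{k+1}(h); the loop below advances it through this recurrence
+        # while `factorial_val` accumulates k! via `factorial_val *= i + 1`.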
+
+        if self.config.solver_type == "bh1":
+            B_h = hh
+        elif self.config.solver_type == "bh2":
+            B_h = jnp.expm1(hh)
+        else:
+            raise NotImplementedError(
+                f"{self.config.solver_type} is not implemented for {self.__class__}"
+            )
+
+        for i in range(1, order + 1):  # Loop from i=1 to order
+            R_list.append(jnp.power(rks, i - 1))
+            b_list.append(current_h_phi_k * factorial_val / B_h)
+
+            if i < order:  # Update for the next iteration (i+1)
+                factorial_val *= i + 1
+                current_h_phi_k = (
+                    current_h_phi_k / hh - 1.0 / factorial_val
+                )  # Advance the phi recurrence for the next i
+
+        R = jnp.stack(R_list)  # Shape (order, order)
+        b = jnp.stack(b_list)  # Shape (order,)
+
+        D1s = None
+        if len(D1s_list) > 0:
+            D1s = jnp.stack(D1s_list, axis=1)  # Resulting shape (B, K, C, H, W)
+
+        if order == 2:  # Special case for order 2 from the reference implementation
+            rhos_p = jnp.array([0.5], dtype=x.dtype)
+        else:  # General case, solve the linear system
+            rhos_p = jnp.linalg.solve(R[:-1, :-1], b[:-1]).astype(x.dtype)
+
+        if self.config.predict_x0:
+            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+            if D1s is not None:
+                # einsum `k,bkc...->bc...` where k is the rhos_p dim, b is batch, c is channel, ...
+                pred_res = jnp.einsum("k,bkc...->bc...", rhos_p, D1s)
+            else:
+                pred_res = 0.0
+            x_t = x_t_ - alpha_t * B_h * pred_res
+        else:  # Predict epsilon
+            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+            if D1s is not None:
+                pred_res = jnp.einsum("k,bkc...->bc...", rhos_p, D1s)
+            else:
+                pred_res = 0.0
+            x_t = x_t_ - sigma_t * B_h * pred_res
+
+        return x_t.astype(x.dtype)
+
+    def multistep_uni_c_bh_update(
+        self,
+        state: UniPCMultistepSchedulerState,
+        this_model_output: jnp.ndarray,
+        last_sample: jnp.ndarray,  # Sample fed to the previous predictor step
+        this_sample: jnp.ndarray,  # Sample produced by that predictor step (pre-correction)
+        order: int,
+    ) -> jnp.ndarray:
+        """
+        One step for the UniC (B(h) version) - the Corrector.
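+
+        The corrector refines `this_sample` (the previous predictor's output)
+        using the fresh model output evaluated at that sample: it solves
+        R @ rhos_c = b for the combination weights and applies them to the
+        stored finite differences together with the new difference `D1_t`.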
+        """
+        model_output_list = state.model_outputs  # History buffer
+        m0 = model_output_list[
+            self.config.solver_order - 1
+        ]  # Most recent model output from history
+
+        x = last_sample  # Sample from before the previous predictor step
+        x_t = this_sample  # Sample produced by that predictor step
+        model_t = this_model_output  # The new model output evaluated at `x_t`
+
+        sigma_t_val = state.sigmas[state.step_index]
+        sigma_s0_val = state.sigmas[
+            state.step_index - 1
+        ]  # This is the sigma corresponding to `x` (last_sample)
+
+        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val)
+        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val)
+
+        lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10)
+        lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10)
+
+        h = lambda_t - lambda_s0
+
+        rks_list = []
+        D1s_list = []
+
+        for i in range(1, order):
+            history_idx = self.config.solver_order - (
+                i + 1
+            )  # Index in the fixed-size history array
+
+            mi = state.model_outputs[history_idx]
+            si_val = state.timestep_list[
+                history_idx
+            ]  # This is the actual timestep value
+
+            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(
+                state.sigmas[self.index_for_timestep(state, si_val)]
+            )
+            lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10)
+
+            rk = (lambda_si - lambda_s0) / h
+            rks_list.append(rk)
+            D1s_list.append((mi - m0) / rk)
+
+        rks_list.append(1.0)
+        rks = jnp.stack(rks_list)
+
+        R_list = []
+        b_list = []
+
+        hh = -h if self.config.predict_x0 else h
+        h_phi_1 = jnp.expm1(hh)
+
+        # Calculate the h_phi_k values for the coefficients
+        current_h_phi_k = h_phi_1 / hh - 1.0  # Initial value for i=1
+        factorial_val = 1.0
+
+        if self.config.solver_type == "bh1":
+            B_h = hh
+        elif self.config.solver_type == "bh2":
+            B_h = jnp.expm1(hh)
+        else:
+            raise NotImplementedError(
+                f"{self.config.solver_type} is not implemented for {self.__class__}"
+            )
+
+        for i in range(1, order + 1):
+            R_list.append(jnp.power(rks, i - 1))
+            b_list.append(current_h_phi_k * factorial_val / B_h)
+
+            if i < order:
+                factorial_val *= i + 1
+                current_h_phi_k = current_h_phi_k / hh - 1.0 / factorial_val
+
+        R = jnp.stack(R_list)
+        b = jnp.stack(b_list)
+
+        D1s = None
+        if len(D1s_list) > 0:
+            D1s = jnp.stack(D1s_list, axis=1)  # (B, K, C, H, W)
+
+        if order == 1:
+            rhos_c = jnp.array([0.5], dtype=x.dtype)
+        else:
+            rhos_c = jnp.linalg.solve(R, b).astype(
+                x.dtype
+            )  # Use all of R and b for the corrector
+
+        if self.config.predict_x0:
+            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+            if D1s is not None:
+                # einsum `k,bkc...->bc...` where k is the rhos_c[:-1] dim
+                corr_res = jnp.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
+            else:
+                corr_res = 0.0
+            D1_t = model_t - m0
+            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+        else:
+            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+            if D1s is not None:
+                corr_res = jnp.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
+            else:
+                corr_res = 0.0
+            D1_t = model_t - m0
+            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+
+        return x_t.astype(x.dtype)
+
+    def index_for_timestep(
+        self,
+        state: UniPCMultistepSchedulerState,
+        timestep: Union[int, jnp.ndarray],
+        schedule_timesteps: Optional[jnp.ndarray] = None,
+    ) -> int:
+        """Gets the step_index for a timestep."""
+        if schedule_timesteps is None:
+            schedule_timesteps = state.timesteps
+
+        timestep_val = (
+            timestep.item()
+            if isinstance(timestep, jnp.ndarray) and timestep.ndim == 0
+            else timestep
+        )
+
+        index_candidates = jnp.where(
+            schedule_timesteps == timestep_val, size=1, fill_value=-1
+        )[0]
+
+        if index_candidates[0] == -1:  # No match found
+            step_index = len(schedule_timesteps) - 1  # Default to the last index
+        elif len(index_candidates) > 1:
+            # Unreachable with size=1 above; kept to mirror the reference
+            # implementation, which takes the second match.
+            step_index = index_candidates[1].item()
+        else:
+            step_index = index_candidates[0].item()  # Take the first (and only) match
+        return step_index
+
+    def _init_step_index(
+        self, state: UniPCMultistepSchedulerState, timestep: Union[int, jnp.ndarray]
+    ) -> UniPCMultistepSchedulerState:
+        """Initializes the step_index counter for the scheduler."""
+        if state.begin_index is None:
+            step_index_val = self.index_for_timestep(state, timestep)
+            return state.replace(step_index=step_index_val)
+        else:
+            return state.replace(step_index=state.begin_index)
+
+    def step(
+        self,
+        state: UniPCMultistepSchedulerState,
+        model_output: jnp.ndarray,  # Direct output from the diffusion model (e.g., noise prediction)
+        timestep: Union[
+            int, jnp.ndarray
+        ],  # Current discrete timestep from the scheduler's sequence
+        sample: jnp.ndarray,  # Current noisy sample (latent)
+        return_dict: bool = True,
+        generator: Optional[jax.random.PRNGKey] = None,  # JAX random key
+    ) -> Union[FlaxUniPCMultistepSchedulerOutput, Tuple[jnp.ndarray]]:
+        """
+        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
+        the multistep UniPC.
+        """
+        if state.timesteps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+        # Ensure timestep is a scalar for indexing/comparison
+        timestep_scalar = (
+            timestep.item()
+            if isinstance(timestep, jnp.ndarray) and timestep.ndim == 0
+            else int(timestep)
+        )  # Ensure int type
+
+        # Initialize step_index if this is the first step
+        if state.step_index is None:
+            state = self._init_step_index(state, timestep_scalar)
+
+        # Determine whether the corrector should be used
+        use_corrector = (
+            state.step_index > 0
+            and state.step_index - 1 not in self.config.disable_corrector
+            and state.last_sample
+            is not None  # `last_sample` holds the sample fed to the previous predictor
+        )
+
+        # Convert model_output (noise/v_pred) to x0_pred or epsilon_pred, based on prediction_type
+        model_output_for_history = self.convert_model_output(
+            state, model_output, sample
+        )
+
+        # Apply the corrector if applicable
+        if use_corrector:
+            corrected_sample = self.multistep_uni_c_bh_update(
+                state=state,
+                this_model_output=model_output_for_history,
+                last_sample=state.last_sample,
+                this_sample=sample,
+                order=state.this_order,
+            )
+            sample = corrected_sample
+
+        # Update history buffers (model_outputs and timestep_list)
+        # Shift existing elements to the left and add the new one at the end.
+        # `state.model_outputs` and `state.timestep_list` are fixed-size arrays.
+        # Example:
+        # t0:[None,...,model_output0]
+        # t1:[None,...,model_output0,model_output1]
+        # ...
+ # tn:[model_output0,model_output1,...,model_output_n] + if state.step_index == 0: + updated_model_outputs_history = state.model_outputs.at[-1].set( + model_output_for_history + ) + updated_timestep_list_history = state.timestep_list.at[-1].set( + timestep_scalar + ) + else: + updated_model_outputs_history = jnp.roll( + state.model_outputs, shift=-1, axis=0 + ) + updated_model_outputs_history = updated_model_outputs_history.at[-1].set( + model_output_for_history + ) + + updated_timestep_list_history = jnp.roll(state.timestep_list, shift=-1) + updated_timestep_list_history = updated_timestep_list_history.at[-1].set( + timestep_scalar + ) + + state = state.replace( + model_outputs=updated_model_outputs_history, + timestep_list=updated_timestep_list_history, + ) + + # Determine the order for the current step (warmup phase logic) + if self.config.lower_order_final: + this_order = jnp.minimum( + self.config.solver_order, len(state.timesteps) - state.step_index + ) + else: + this_order = self.config.solver_order + + # Warmup for multistep: `this_order` can't exceed `lower_order_nums + 1` + new_this_order = jnp.minimum(this_order, state.lower_order_nums + 1) + state = state.replace(this_order=new_this_order) + + # Ensure `this_order` is positive, should always be. + assert new_this_order > 0, "Solver order must be positive." + + # Store current sample as `last_sample` for the *next* step's corrector + state = state.replace(last_sample=sample) + + # UniP predictor step + prev_sample = self.multistep_uni_p_bh_update( + state=state, + model_output=model_output, + sample=sample, + order=state.this_order, + ) + + # Update lower_order_nums for warmup + if state.lower_order_nums < self.config.solver_order: + state = state.replace(lower_order_nums=state.lower_order_nums + 1) + + # Upon completion, increase step index by one + state = state.replace(step_index=state.step_index + 1) + + # Return the updated sample and state + if not return_dict: + return (prev_sample, state) + + return FlaxUniPCMultistepSchedulerOutput(prev_sample=prev_sample, state=state) + + def scale_model_input( + self, state: UniPCMultistepSchedulerState, sample: jnp.ndarray, *args, **kwargs + ) -> jnp.ndarray: + """ + UniPC does not scale model input, so it returns the sample unchanged. + """ + return sample + + def add_noise( + self, + state: UniPCMultistepSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, # JAX array for timesteps (scalar or 0-dim array) + ) -> jnp.ndarray: + """ + Adds noise to `original_samples` based on the provided `timesteps` and `noise`. + """ + if timesteps.ndim > 0: + raise ValueError( + f"For `add_noise`, `timesteps` must be a scalar (0-dim JAX array), but got shape {timesteps.shape}." + ) + + # Get `step_indices` for the current timestep. 
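+        # Fallback order: with no begin_index set, look the timestep up in the
+        # schedule; during an in-progress sampling run reuse the live step_index;
+        # otherwise (e.g. img2img after `set_begin_index`) use begin_index.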
+ if state.begin_index is None: + step_idx = self.index_for_timestep(state, timesteps.item(), state.timesteps) + step_indices = jnp.array( + [step_idx] + ) # Make it a 1-element array for indexing sigmas + elif state.step_index is not None: + step_indices = jnp.array([state.step_index]) + else: + step_indices = jnp.array([state.begin_index]) + + sigma = state.sigmas[step_indices].flatten() + sigma = broadcast_to_shape_from_left(sigma, noise.shape) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def _sigma_to_alpha_sigma_t(self, sigma): + if self.config.use_flow_sigmas: + alpha_t = 1 - sigma + sigma_t = sigma + else: + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + def __len__(self) -> int: + return self.config.num_train_timesteps From 5cd45956b1b10d896f99b24434ad2f7d86b0876c Mon Sep 17 00:00:00 2001 From: Quinn Date: Thu, 15 May 2025 18:57:22 +0000 Subject: [PATCH 2/4] Add unit tests to the unipc multistep scheduler. --- .../scheduling_unipc_multistep_flax.py | 67 +- tests/schedulers/test_scheduler_unipc.py | 680 ++++++++++++++++++ 2 files changed, 709 insertions(+), 38 deletions(-) create mode 100644 tests/schedulers/test_scheduler_unipc.py diff --git a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py index f0733ca3f..d80ddb3aa 100644 --- a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py @@ -25,10 +25,9 @@ from ..utils import is_scipy_available from .scheduling_utils_flax import ( CommonSchedulerState, - FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, - broadcast_to_shape_from_left, + add_noise_common, ) @@ -164,6 +163,13 @@ def create_state( if common is None: common = CommonSchedulerState.create(self) + if self.config.get("rescale_zero_terminal_snr", False): + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + alphas_cumprod = common.alphas_cumprod + alphas_cumprod = alphas_cumprod.at[-1].set(2**-24) + common = common.replace(alphas_cumprod=alphas_cumprod) + # Currently we only support VP-type noise schedule alpha_t = jnp.sqrt(common.alphas_cumprod) sigma_t = jnp.sqrt(1 - common.alphas_cumprod) @@ -247,15 +253,24 @@ def set_timesteps( # TODO # # Apply Karras/Exponential/Beta/Flow Sigmas if configured - # if self.config.use_karras_sigmas: - # sigmas = _convert_to_karras_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) - # elif self.config.use_exponential_sigmas: - # sigmas = _convert_to_exponential_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) - # elif self.config.use_beta_sigmas: - # sigmas = _convert_to_beta_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + if self.config.use_karras_sigmas: + # sigmas = _convert_to_karras_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + raise NotImplementedError( + 
"`use_karras_sigmas` is not implemented in JAX version yet." + ) + elif self.config.use_exponential_sigmas: + # sigmas = _convert_to_exponential_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + raise NotImplementedError( + "`use_exponential_sigmas` is not implemented in JAX version yet." + ) + elif self.config.use_beta_sigmas: + # sigmas = _convert_to_beta_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + raise NotImplementedError( + "`use_beta_sigmas` is not implemented in JAX version yet." + ) if self.config.use_flow_sigmas: alphas = jnp.linspace( 1, 1 / self.config.num_train_timesteps, num_inference_steps + 1 @@ -346,7 +361,7 @@ def convert_model_output( ) if self.config.thresholding: - raise ValueError(f"Dynamic thresholding isn't implemented.") + raise NotImplementedError("Dynamic thresholding isn't implemented.") # x0_pred = self._threshold_sample(x0_pred) return x0_pred else: # self.config.predict_x0 is False @@ -775,33 +790,9 @@ def add_noise( state: UniPCMultistepSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, - timesteps: jnp.ndarray, # JAX array for timesteps (scalar or 0-dim array) + timesteps: jnp.ndarray, ) -> jnp.ndarray: - """ - Adds noise to `original_samples` based on the provided `timesteps` and `noise`. - """ - if timesteps.ndim > 0: - raise ValueError( - f"For `add_noise`, `timesteps` must be a scalar (0-dim JAX array), but got shape {timesteps.shape}." - ) - - # Get `step_indices` for the current timestep. - if state.begin_index is None: - step_idx = self.index_for_timestep(state, timesteps.item(), state.timesteps) - step_indices = jnp.array( - [step_idx] - ) # Make it a 1-element array for indexing sigmas - elif state.step_index is not None: - step_indices = jnp.array([state.step_index]) - else: - step_indices = jnp.array([state.begin_index]) - - sigma = state.sigmas[step_indices].flatten() - sigma = broadcast_to_shape_from_left(sigma, noise.shape) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - noisy_samples = alpha_t * original_samples + sigma_t * noise - return noisy_samples + return add_noise_common(state.common, original_samples, noise, timesteps) def _sigma_to_alpha_sigma_t(self, sigma): if self.config.use_flow_sigmas: diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py new file mode 100644 index 000000000..559fb546f --- /dev/null +++ b/tests/schedulers/test_scheduler_unipc.py @@ -0,0 +1,680 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info +# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/tests/schedulers/test_scheduler_unipc.py + +import tempfile + +import torch +import jax.numpy as jnp +from typing import Dict, List, Tuple + +from maxdiffusion.schedulers.scheduling_unipc_multistep_flax import ( + FlaxUniPCMultistepScheduler, +) +from maxdiffusion import FlaxDPMSolverMultistepScheduler + +from .test_scheduler_flax import FlaxSchedulerCommonTest + + +class FlaxUniPCMultistepSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxUniPCMultistepScheduler,) + forward_default_kwargs = (("num_inference_steps", 25),) + + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + sample = torch.rand((batch_size, num_channels, height, width)) + jax_sample= jnp.asarray(sample) + return jax_sample + + @property + def dummy_noise_deter(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + num_elems = batch_size * num_channels * height * width + sample = torch.arange(num_elems).flip(-1) + sample = sample.reshape(num_channels, height, width, batch_size) + sample = sample / num_elems + sample = sample.permute(3, 0, 1, 2) + + jax_sample= jnp.asarray(sample) + return jax_sample + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + num_elems = batch_size * num_channels * height * width + sample = torch.arange(num_elems) + sample = sample.reshape(num_channels, height, width, batch_size) + sample = sample / num_elems + sample = sample.permute(3, 0, 1, 2) + + jax_sample= jnp.asarray(sample) + return jax_sample + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "solver_order": 2, + "solver_type": "bh2", + "final_sigmas_type": "sigma_min", + } + + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10] + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + state = scheduler.set_timesteps( + state, num_inference_steps, sample.shape + ) + new_state = new_scheduler.set_timesteps( + new_state, num_inference_steps, sample.shape + ) + # copy over dummy past residuals + initial_model_outputs = jnp.stack(dummy_past_model_outputs[ + : scheduler.config.solver_order + ]) + state = state.replace(model_outputs=initial_model_outputs) + # Copy over dummy past residuals to new_state as well + new_state = new_state.replace(model_outputs=initial_model_outputs) + + + output_sample, output_state = sample, state + new_output_sample, new_output_state = sample, new_state + # Need to iterate through the steps as UniPC maintains history over steps + # The loop for solver_order + 1 steps is crucial for UniPC's history logic. 
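+            # With solver_order=2 this covers an order-1 warmup step, the first
+            # full-order step, and a corrector-enabled step, so the reloaded
+            # scheduler must reproduce the history handling exactly.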
+            for i in range(time_step, time_step + scheduler.config.solver_order + 1):
+                # Ensure time_step + i is within the bounds of timesteps
+                if i >= len(output_state.timesteps):
+                    break
+                t = output_state.timesteps[i]
+                step_output = scheduler.step(
+                    state=output_state,
+                    model_output=residual,
+                    timestep=t,  # Pass the current timestep from the scheduler's sequence
+                    sample=output_sample,
+                    return_dict=True,  # Return a SchedulerOutput dataclass
+                )
+                output_sample = step_output.prev_sample
+                output_state = step_output.state
+
+                new_step_output = new_scheduler.step(
+                    state=new_output_state,
+                    model_output=residual,
+                    timestep=t,  # Pass the current timestep from the scheduler's sequence
+                    sample=new_output_sample,
+                    return_dict=True,  # Return a SchedulerOutput dataclass
+                )
+                new_output_sample = new_step_output.prev_sample
+                new_output_state = new_step_output.state
+
+            self.assertTrue(
+                jnp.allclose(output_sample, new_output_sample, atol=1e-5),
+                "Scheduler outputs are not identical",
+            )
+            # Also assert that the states are identical
+            self.assertEqual(output_state.step_index, new_output_state.step_index)
+            self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps))
+            self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5))
+            # Compare model_outputs (history) directly:
+            if output_state.model_outputs is not None and new_output_state.model_outputs is not None:
+                for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs):
+                    self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical")
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            state = scheduler.create_state()
+
+            state = scheduler.set_timesteps(
+                state, num_inference_steps, sample.shape
+            )
+
+            # copy over dummy past residuals
+            initial_model_outputs = jnp.stack(dummy_past_model_outputs[
+                : scheduler.config.solver_order
+            ])
+            state = state.replace(model_outputs=initial_model_outputs)
+
+            # Round-trip the config through save/load so the reloaded scheduler
+            # can be stepped side by side with the original.
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
+
+            if num_inference_steps is not None and hasattr(new_scheduler, "set_timesteps"):
+                new_state = new_scheduler.set_timesteps(
+                    new_state, num_inference_steps, sample.shape
+                )
+                # Copy over the dummy past residuals to new_state as well
+                new_state = new_state.replace(model_outputs=initial_model_outputs)
+
+            output_sample, output_state = sample, state
+            new_output_sample, new_output_state = sample, new_state
+
+            # Need to iterate through the steps as UniPC maintains history over steps.
+            # The loop for solver_order + 1 steps is crucial for UniPC's history logic.
+ for i in range(time_step, time_step + scheduler.config.solver_order + 1): + # Ensure time_step + i is within the bounds of timesteps + if i >= len(output_state.timesteps): + break + + t = output_state.timesteps[i] + + step_output = scheduler.step( + state=output_state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=output_sample, + return_dict=True, # Return a SchedulerOutput dataclass + **kwargs, + ) + output_sample = step_output.prev_sample + output_state = step_output.state + + new_step_output = new_scheduler.step( + state=new_output_state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=new_output_sample, + return_dict=True, # Return a SchedulerOutput dataclass + **kwargs, + ) + new_output_sample = new_step_output.prev_sample + new_output_state = new_step_output.state + + self.assertTrue( + jnp.allclose(output_sample, new_output_sample, atol=1e-5), + "Scheduler outputs are not identical", + ) + # Also assert that states are identical + self.assertEqual(output_state.step_index, new_output_state.step_index) + self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps)) + self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5)) + # Comparing model_outputs (history) directly: + if output_state.model_outputs is not None and new_output_state.model_outputs is not None: + for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs): + self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical") + + + def full_loop(self, scheduler=None, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + if scheduler is None: + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + else: + state = scheduler.create_state() # Ensure state is fresh for the loop + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + for i, t in enumerate(state.timesteps): + residual = model(sample, t) + + # scheduler.step in common test receives state, residual, t, sample + step_output = scheduler.step( + state=state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=sample, + return_dict=True, # Return a SchedulerOutput dataclass + ) + sample = step_output.prev_sample + state = step_output.state # Update state for next iteration + + return sample + + def test_from_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = 
num_inference_steps + + output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() # Create initial state + + sample = self.dummy_sample # Get sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + elif ( + num_inference_steps is not None + and not hasattr(scheduler, "set_timesteps") + ): + kwargs["num_inference_steps"] = num_inference_steps + + # Copy over dummy past residuals (must be done after set_timesteps) + dummy_past_model_outputs = [ + 0.2 * sample, + 0.15 * sample, + 0.10 * sample, + ] + initial_model_outputs = jnp.stack(dummy_past_model_outputs[ + : scheduler.config.solver_order + ]) + state = state.replace(model_outputs=initial_model_outputs) + + time_step_0 = state.timesteps[5] + time_step_1 = state.timesteps[6] + + output_0 = scheduler.step(state, residual, time_step_0, sample).prev_sample + output_1 = scheduler.step(state, residual, time_step_1, sample).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
+ ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def test_switch(self): + # make sure that iterating over schedulers with same config names gives same results + # for defaults + scheduler_config = self.get_scheduler_config() + scheduler_1 = FlaxUniPCMultistepScheduler(**scheduler_config) + sample_1 = self.full_loop(scheduler=scheduler_1) + result_mean_1 = jnp.mean(jnp.abs(sample_1)) + + assert abs(result_mean_1.item() - 0.2464) < 1e-3 + + scheduler_2 = FlaxUniPCMultistepScheduler(**scheduler_config) # New instance + sample_2 = self.full_loop(scheduler=scheduler_2) + result_mean_2 = jnp.mean(jnp.abs(sample_2)) + + self.assertTrue(jnp.allclose(result_mean_1, result_mean_2, atol=1e-3)) # Check consistency + + assert abs(result_mean_2.item() - 0.2464) < 1e-3 + + def test_timesteps(self): + for timesteps in [25, 50, 100, 999, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_thresholding(self): + self.check_over_configs(thresholding=False) + for order in [1, 2, 3]: + for solver_type in ["bh1", "bh2"]: + for threshold in [0.5, 1.0, 2.0]: + for prediction_type in ["epsilon", "sample"]: + with self.assertRaises(NotImplementedError): + self.check_over_configs( + thresholding=True, + prediction_type=prediction_type, + sample_max_value=threshold, + solver_order=order, + solver_type=solver_type, + ) + + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_rescale_betas_zero_snr(self): + for rescale_zero_terminal_snr in [True, False]: + self.check_over_configs(rescale_zero_terminal_snr=rescale_zero_terminal_snr) + + def test_solver_order_and_type(self): + for solver_type in ["bh1", "bh2"]: + for order in [1, 2, 3]: + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + ) + sample = self.full_loop( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + ) + assert not jnp.any(jnp.isnan(sample)), "Samples have nan numbers" + + + def test_lower_order_final(self): + self.check_over_configs(lower_order_final=True) + self.check_over_configs(lower_order_final=False) + + def test_inference_steps(self): + for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: + self.check_over_forward(time_step 
= 0, num_inference_steps=num_inference_steps) + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2464) < 1e-3 + + def test_full_loop_with_karras(self): + # sample = self.full_loop(use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.2925) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(use_karras_sigmas=True) + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.1014) < 1e-3 + + def test_full_loop_with_karras_and_v_prediction(self): + # sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.1966) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + + def test_fp16_support(self): + scheduler_class = self.scheduler_classes[0] + for order in [1, 2, 3]: + for solver_type in ["bh1", "bh2"]: + for prediction_type in ["epsilon", "sample", "v_prediction"]: + scheduler_config = self.get_scheduler_config( + thresholding=False, + dynamic_thresholding_ratio=0, + prediction_type=prediction_type, + solver_order=order, + solver_type=solver_type, + ) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter.astype(jnp.bfloat16) + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + for i, t in enumerate(state.timesteps): + residual = model(sample, t) + step_output = scheduler.step(state, residual, t, sample) + sample = step_output.prev_sample + state = step_output.state + + self.assertEqual(sample.dtype, jnp.bfloat16) + + def test_full_loop_with_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + t_start_index = 8 + + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + # add noise + noise = self.dummy_noise_deter + timesteps_for_noise = state.timesteps[t_start_index :] + sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) + + for i, t in enumerate(timesteps_for_noise): + residual = model(sample, t) + step_output = scheduler.step(state, residual, t, sample) + sample = step_output.prev_sample + state = step_output.state + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}" + assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}" + + +class FlaxUniPCMultistepScheduler1DTest(FlaxUniPCMultistepSchedulerTest): + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + width = 8 + + torch_sample = torch.rand((batch_size, num_channels, width)) + jax_sample= jnp.asarray(torch_sample) + return jax_sample + + @property + def dummy_noise_deter(self): + batch_size = 4 + num_channels = 3 + width = 8 + + num_elems = batch_size * num_channels * width + sample = torch.arange(num_elems).flip(-1) + sample = 
sample.reshape(num_channels, width, batch_size) + sample = sample / num_elems + sample = sample.permute(2, 0, 1) + + jax_sample= jnp.asarray(sample) + return jax_sample + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + width = 8 + + num_elems = batch_size * num_channels * width + sample = torch.arange(num_elems) + sample = sample.reshape(num_channels, width, batch_size) + sample = sample / num_elems + sample = sample.permute(2, 0, 1) + jax_sample= jnp.asarray(sample) + return jax_sample + + def test_switch(self): + # make sure that iterating over schedulers with same config names gives same results + # for defaults + scheduler = FlaxUniPCMultistepScheduler(**self.get_scheduler_config()) + sample = self.full_loop(scheduler=scheduler) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + scheduler = FlaxDPMSolverMultistepScheduler.from_config(scheduler.config) + scheduler = FlaxUniPCMultistepScheduler.from_config(scheduler.config) + + sample = self.full_loop(scheduler=scheduler) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + def test_full_loop_with_karras(self): + # sample = self.full_loop(use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.2898) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(use_karras_sigmas=True) + + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.1014) < 1e-3 + + def test_full_loop_with_karras_and_v_prediction(self): + # sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.1944) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + + def test_full_loop_with_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + t_start_index = 8 + + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + # add noise + noise = self.dummy_noise_deter + timesteps_for_noise = state.timesteps[t_start_index :] + sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) + + for i, t in enumerate(timesteps_for_noise): + residual = model(sample, t) + step_output = scheduler.step(state, residual, t, sample) + sample = step_output.prev_sample + state = step_output.state + + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}" + assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}" + + def test_beta_sigmas(self): + # self.check_over_configs(use_beta_sigmas=True) + with self.assertRaises(NotImplementedError): + self.full_loop(use_beta_sigmas=True) + + def test_exponential_sigmas(self): + #self.check_over_configs(use_exponential_sigmas=True) + with 
self.assertRaises(NotImplementedError): + self.full_loop(use_exponential_sigmas=True) From 57c0fcc6d5f617809208ee4cbab9a1d2dc32dd40 Mon Sep 17 00:00:00 2001 From: Quinn Date: Mon, 19 May 2025 19:48:39 +0000 Subject: [PATCH 3/4] Add an __init__.py to the test directory --- tests/schedulers/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/schedulers/__init__.py diff --git a/tests/schedulers/__init__.py b/tests/schedulers/__init__.py new file mode 100644 index 000000000..e69de29bb From ecdece09524b378fb015f9e4d0e2c47120645550 Mon Sep 17 00:00:00 2001 From: Quinn Date: Mon, 19 May 2025 20:44:59 +0000 Subject: [PATCH 4/4] Add license to init file. --- tests/schedulers/__init__.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/schedulers/__init__.py b/tests/schedulers/__init__.py index e69de29bb..b392d39a5 100644 --- a/tests/schedulers/__init__.py +++ b/tests/schedulers/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """
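
A minimal end-to-end usage sketch for the new scheduler (illustrative only;
`my_denoiser` stands in for the diffusion model and the latent shape is
arbitrary):

    import jax.numpy as jnp
    from maxdiffusion.schedulers.scheduling_unipc_multistep_flax import (
        FlaxUniPCMultistepScheduler,
    )

    scheduler = FlaxUniPCMultistepScheduler(solver_order=2, solver_type="bh2")
    state = scheduler.create_state()

    sample = jnp.zeros((1, 4, 64, 64))  # initial noisy latent
    state = scheduler.set_timesteps(state, num_inference_steps=25, shape=sample.shape)

    for t in state.timesteps:
        model_output = my_denoiser(sample, t)  # hypothetical denoiser call
        out = scheduler.step(state, model_output, t, sample)
        sample, state = out.prev_sample, out.state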