67 changes: 29 additions & 38 deletions src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py
@@ -25,10 +25,9 @@
from ..utils import is_scipy_available
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
add_noise_common,
)


@@ -164,6 +163,13 @@ def create_state(
if common is None:
common = CommonSchedulerState.create(self)

if self.config.get("rescale_zero_terminal_snr", False):
# Close to 0 without being 0 so first sigma is not inf
# FP16 smallest positive subnormal works well here
alphas_cumprod = common.alphas_cumprod
alphas_cumprod = alphas_cumprod.at[-1].set(2**-24)
common = common.replace(alphas_cumprod=alphas_cumprod)

# Currently we only support VP-type noise schedule
alpha_t = jnp.sqrt(common.alphas_cumprod)
sigma_t = jnp.sqrt(1 - common.alphas_cumprod)
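
A quick illustration, separate from the PR's code, of why the zero-terminal-SNR branch above clamps the last cumulative alpha to 2**-24 rather than leaving it at exactly 0: with alpha_bar_T = 0, the Karras-style sigma sqrt((1 - alpha_bar)/alpha_bar) and the scheduler's log-SNR term both become infinite, while the smallest positive FP16 subnormal keeps them finite and still effectively zero.

import jax.numpy as jnp

alphas_cumprod = jnp.array([0.99, 0.5, 0.0])
print(jnp.sqrt((1 - alphas_cumprod) / alphas_cumprod))  # last entry is inf

# Clamp the terminal value to the smallest positive FP16 subnormal, as the diff does.
alphas_cumprod = alphas_cumprod.at[-1].set(2**-24)
print(jnp.sqrt((1 - alphas_cumprod) / alphas_cumprod))  # last entry is ~4096, finite
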
@@ -247,15 +253,24 @@ def set_timesteps(

# TODO
# # Apply Karras/Exponential/Beta/Flow Sigmas if configured
# if self.config.use_karras_sigmas:
# sigmas = _convert_to_karras_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
# timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64)
# elif self.config.use_exponential_sigmas:
# sigmas = _convert_to_exponential_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
# timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64)
# elif self.config.use_beta_sigmas:
# sigmas = _convert_to_beta_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
# timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64)
if self.config.use_karras_sigmas:
# sigmas = _convert_to_karras_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
# timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64)
raise NotImplementedError(
"`use_karras_sigmas` is not implemented in JAX version yet."
)
elif self.config.use_exponential_sigmas:
# sigmas = _convert_to_exponential_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
# timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64)
raise NotImplementedError(
"`use_exponential_sigmas` is not implemented in JAX version yet."
)
elif self.config.use_beta_sigmas:
# sigmas = _convert_to_beta_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
# timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64)
raise NotImplementedError(
"`use_beta_sigmas` is not implemented in JAX version yet."
)
if self.config.use_flow_sigmas:
alphas = jnp.linspace(
1, 1 / self.config.num_train_timesteps, num_inference_steps + 1
@@ -346,7 +361,7 @@ def convert_model_output(
)

if self.config.thresholding:
raise ValueError(f"Dynamic thresholding isn't implemented.")
raise NotImplementedError("Dynamic thresholding isn't implemented.")
# x0_pred = self._threshold_sample(x0_pred)
return x0_pred
else: # self.config.predict_x0 is False
@@ -775,33 +790,9 @@ def add_noise(
state: UniPCMultistepSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray, # JAX array for timesteps (scalar or 0-dim array)
timesteps: jnp.ndarray,
) -> jnp.ndarray:
"""
Adds noise to `original_samples` based on the provided `timesteps` and `noise`.
"""
if timesteps.ndim > 0:
raise ValueError(
f"For `add_noise`, `timesteps` must be a scalar (0-dim JAX array), but got shape {timesteps.shape}."
)

# Get `step_indices` for the current timestep.
if state.begin_index is None:
step_idx = self.index_for_timestep(state, timesteps.item(), state.timesteps)
step_indices = jnp.array(
[step_idx]
) # Make it a 1-element array for indexing sigmas
elif state.step_index is not None:
step_indices = jnp.array([state.step_index])
else:
step_indices = jnp.array([state.begin_index])

sigma = state.sigmas[step_indices].flatten()
sigma = broadcast_to_shape_from_left(sigma, noise.shape)

alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
return add_noise_common(state.common, original_samples, noise, timesteps)

def _sigma_to_alpha_sigma_t(self, sigma):
if self.config.use_flow_sigmas:
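
For the `use_karras_sigmas` branch that now raises `NotImplementedError`: the commented-out `_convert_to_karras_jax` helper referenced in the TODO would presumably follow the standard Karras et al. (2022) rho-interpolation between sigma_max and sigma_min. A minimal sketch under that assumption (the helper name comes from the commented-out code; the body below is not from this PR):

import jax.numpy as jnp

def _convert_to_karras_jax(in_sigmas: jnp.ndarray, num_inference_steps: int, rho: float = 7.0) -> jnp.ndarray:
  # Interpolate between sigma_max and sigma_min in rho-space (rho = 7 in the paper).
  sigma_min = in_sigmas[-1]
  sigma_max = in_sigmas[0]
  ramp = jnp.linspace(0, 1, num_inference_steps)
  min_inv_rho = sigma_min ** (1 / rho)
  max_inv_rho = sigma_max ** (1 / rho)
  return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho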
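
Likewise, the `thresholding` path in `convert_model_output` now raises `NotImplementedError` instead of a `ValueError`. The `_threshold_sample` call left commented out refers to Imagen-style dynamic thresholding (clip each predicted x0 to a per-sample quantile, then rescale); a hedged sketch of that recipe, not this PR's code:

import jax.numpy as jnp

def _threshold_sample_sketch(sample: jnp.ndarray,
                             dynamic_thresholding_ratio: float = 0.995,
                             sample_max_value: float = 1.0) -> jnp.ndarray:
  # Per-sample dynamic threshold s: the chosen quantile of |x0| over all elements.
  batch = sample.shape[0]
  abs_flat = jnp.abs(sample).reshape(batch, -1)
  s = jnp.quantile(abs_flat, dynamic_thresholding_ratio, axis=1)
  s = jnp.maximum(s, sample_max_value)  # never threshold below the static max value
  s = s.reshape(batch, *([1] * (sample.ndim - 1)))
  # Clip to [-s, s] and rescale so the result stays within [-sample_max_value, sample_max_value].
  return jnp.clip(sample, -s, s) / s * sample_max_value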
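
Finally, the rewritten `add_noise` delegates to `add_noise_common` from `scheduling_utils_flax` instead of indexing `state.sigmas` by hand. That shared helper is expected to implement the usual forward-diffusion mix x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise for a batch of timesteps; a self-contained sketch of that behaviour (an approximation of the common Flax scheduler utility, not code copied from this diff):

import jax.numpy as jnp

def add_noise_sketch(alphas_cumprod: jnp.ndarray, original_samples: jnp.ndarray,
                     noise: jnp.ndarray, timesteps: jnp.ndarray) -> jnp.ndarray:
  # Gather the cumulative alphas for the batch of timesteps: shape (batch,).
  sqrt_alpha_prod = jnp.sqrt(alphas_cumprod[timesteps])
  sqrt_one_minus_alpha_prod = jnp.sqrt(1.0 - alphas_cumprod[timesteps])
  # Broadcast from the left, (batch,) -> (batch, 1, 1, ...), to match the sample shape.
  extra_dims = (1,) * (original_samples.ndim - 1)
  sqrt_alpha_prod = sqrt_alpha_prod.reshape(-1, *extra_dims)
  sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.reshape(-1, *extra_dims)
  return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise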