Skip to content

Commit 682cbce

Browse files
committed
changes needed for scheduler
1 parent 351d345 commit 682cbce

2 files changed

Lines changed: 37 additions & 28 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -186,25 +186,17 @@ def retrieve_timesteps(
186186
):
187187
if timesteps is not None and sigmas is not None:
188188
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
189-
190-
if timesteps is not None:
191-
# TODO: Support custom timesteps in FlaxFlowMatchScheduler
192-
raise NotImplementedError("Custom timesteps not yet supported in FlaxFlowMatchScheduler wrapper.")
193-
elif sigmas is not None:
194-
# Manually create state with custom sigmas
195-
# Replicates logic from diffusers but for Flax state
196-
sigmas = jnp.array(sigmas, dtype=scheduler.dtype)
197-
# Assuming scheduler.config.num_train_timesteps exists
198-
timesteps = sigmas * scheduler.config.num_train_timesteps
199189

200-
# We need to update the state with these new values
201-
scheduler_state = scheduler_state.replace(
202-
sigmas=sigmas,
203-
timesteps=timesteps,
204-
num_inference_steps=len(sigmas)
205-
)
206-
else:
207-
scheduler_state = scheduler.set_timesteps(scheduler_state, num_inference_steps, **kwargs)
190+
timesteps = jnp.array(timesteps, dtype=scheduler.dtype) if timesteps is not None else None
191+
sigmas = jnp.array(sigmas, dtype=scheduler.dtype) if sigmas is not None else None
192+
193+
scheduler_state = scheduler.set_timesteps(
194+
scheduler_state,
195+
num_inference_steps=num_inference_steps,
196+
timesteps=timesteps,
197+
sigmas=sigmas,
198+
**kwargs,
199+
)
208200

209201
return scheduler_state
210202

src/maxdiffusion/schedulers/scheduling_flow_match_flax.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def set_timesteps(
109109
denoising_strength: float = 1.0,
110110
training: bool = False,
111111
shift: Optional[float] = None,
112+
timesteps: Optional[jnp.ndarray] = None,
113+
sigmas: Optional[jnp.ndarray] = None,
112114
) -> FlowMatchSchedulerState:
113115
"""
114116
Sets the discrete timesteps used for the diffusion chain.
@@ -126,27 +128,42 @@ def set_timesteps(
126128
Whether the scheduler is being used for training.
127129
shift (`Optional[float]`):
128130
An optional shift value to override the one in the config.
131+
timesteps (`Optional[jnp.ndarray]`):
132+
Custom timesteps to use for the denoising process.
133+
sigmas (`Optional[jnp.ndarray]`):
134+
Custom sigmas to use for the denoising process.
129135
130136
Returns:
131137
`FlowMatchSchedulerState`: The updated scheduler state.
132138
"""
133139
current_shift = shift if shift is not None else self.config.shift
134-
sigma_start = self.config.sigma_min + (self.config.sigma_max - self.config.sigma_min) * denoising_strength
135140

136-
if self.config.extra_one_step:
137-
sigmas = jnp.linspace(sigma_start, self.config.sigma_min, num_inference_steps + 1, dtype=self.dtype)[:-1]
141+
if timesteps is not None and sigmas is not None:
142+
pass
143+
elif timesteps is not None:
144+
sigmas = timesteps / self.config.num_train_timesteps
145+
elif sigmas is not None:
146+
timesteps = sigmas * self.config.num_train_timesteps
138147
else:
139-
sigmas = jnp.linspace(sigma_start, self.config.sigma_min, num_inference_steps, dtype=self.dtype)
148+
sigma_start = self.config.sigma_min + (self.config.sigma_max - self.config.sigma_min) * denoising_strength
140149

141-
if self.config.inverse_timesteps:
142-
sigmas = jnp.flip(sigmas, dims=[0])
150+
if self.config.extra_one_step:
151+
sigmas = jnp.linspace(sigma_start, self.config.sigma_min, num_inference_steps + 1, dtype=self.dtype)[:-1]
152+
else:
153+
sigmas = jnp.linspace(sigma_start, self.config.sigma_min, num_inference_steps, dtype=self.dtype)
143154

144-
sigmas = current_shift * sigmas / (1 + (current_shift - 1) * sigmas)
155+
if self.config.inverse_timesteps:
156+
sigmas = jnp.flip(sigmas, dims=[0])
145157

146-
if self.config.reverse_sigmas:
147-
sigmas = 1 - sigmas
158+
sigmas = current_shift * sigmas / (1 + (current_shift - 1) * sigmas)
148159

149-
timesteps = sigmas * self.config.num_train_timesteps
160+
if self.config.reverse_sigmas:
161+
sigmas = 1 - sigmas
162+
163+
timesteps = sigmas * self.config.num_train_timesteps
164+
165+
if timesteps is not None:
166+
num_inference_steps = len(timesteps)
150167

151168
linear_timesteps_weights = None
152169
if training:

0 commit comments

Comments
 (0)