Skip to content

Commit dfd0c89

Browse files
committed
jitting the step() method in scheduler
1 parent c905459 commit dfd0c89

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/maxdiffusion/schedulers/scheduling_flow_match_flax.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# DISCLAIMER: This is a JAX/Flax conversion of a PyTorch implementation.
1616
# The original PyTorch code was provided by the user.
1717

18+
from functools import partial
1819
from typing import Optional, Tuple, Union
1920

2021
import flax
@@ -243,6 +244,7 @@ def _find_timestep_id(self, state: FlowMatchSchedulerState, timestep: jnp.ndarra
243244
diffs = jnp.abs(state.timesteps[None, :] - timestep[:, None])
244245
return jnp.argmin(diffs, axis=1)
245246

247+
@partial(jax.jit, static_argnums=(0, 5, 6))
246248
def step(
247249
self,
248250
state: FlowMatchSchedulerState,

0 commit comments

Comments
 (0)