diff --git a/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index 3673317b6..e3d96d6e9 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -239,7 +239,7 @@ def loop_body(step, args): # predict the noise residual noise_pred = self.unet.apply( - {"params": params["unet"], "aqt": params["unet"]["aqt"]}, + {"params": params["unet"]}, jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), encoder_hidden_states=prompt_embeds,