From 3b38c40c274efff75a7eb976bffaf4429658eabb Mon Sep 17 00:00:00 2001 From: Vijaya Date: Thu, 10 Apr 2025 21:53:15 +0000 Subject: [PATCH] Fix bug with AQT changes --- .../stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index 3673317b6..e3d96d6e9 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -239,7 +239,7 @@ def loop_body(step, args): # predict the noise residual noise_pred = self.unet.apply( - {"params": params["unet"], "aqt": params["unet"]["aqt"]}, + {"params": params["unet"]}, jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), encoder_hidden_states=prompt_embeds,