diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 1659d3bb5..78e3322d1 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -286,7 +286,7 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline return model max_logging.log("Quantizing transformer with Qwix.") - batch_size = int(config.per_device_batch_size * jax.local_device_count()) + batch_size = jnp.ceil(config.per_device_batch_size * jax.local_device_count()).astype(jnp.int32) latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size) model_inputs = (latents, timesteps, prompt_embeds) with mesh: