Skip to content

Commit 224a951

Browse files
authored
Solve the qwix bugs when config.per_device_batch_size * jax.local_device_count() < 1 (#237)
* fix wan unit test bugs * line problems * bug fix under low per_device_batch_size
1 parent 955bd86 commit 224a951

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline
286286
return model
287287
max_logging.log("Quantizing transformer with Qwix.")
288288

289-
batch_size = int(config.per_device_batch_size * jax.local_device_count())
289+
batch_size = jnp.ceil(config.per_device_batch_size * jax.local_device_count()).astype(jnp.int32)
290290
latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)
291291
model_inputs = (latents, timesteps, prompt_embeds)
292292
with mesh:

0 commit comments

Comments
 (0)