
Commit a385c8e

fp8 bug for batch_size setting error (#317)
* fp8 bug
* format
* unit test
* unit test
1 parent 5a05e75 commit a385c8e

2 files changed: 2 additions & 1 deletion

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -376,7 +376,7 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline
       return model
     max_logging.log("Quantizing transformer with Qwix.")
 
-    batch_size = jnp.ceil(config.per_device_batch_size * jax.local_device_count()).astype(jnp.int32)
+    batch_size = config.global_batch_size_to_train_on
     latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)
     model_inputs = (latents, timesteps, prompt_embeds)
     with mesh:
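
A short note on the change above: the old expression builds the dummy-input batch size with jnp.ceil(per_device_batch_size * local_device_count), which yields a JAX array rather than a plain Python int and scales only by the devices on the local host; either can make the dummy inputs used for Qwix fp8 quantization disagree with the configured global batch size. The config field global_batch_size_to_train_on already carries the intended value as an integer. The sketch below is illustrative only (it is not part of the commit) and assumes global_batch_size_to_train_on is derived from per_device_batch_size and the total device count, as in MaxText-style configs.

# Illustrative sketch, not repository code: contrasts the old per-host,
# array-valued batch size with a plain-int global batch size.
import jax
import jax.numpy as jnp

per_device_batch_size = 0.5  # fractional values are why the original used jnp.ceil

# Old behaviour: a device array, scaled only by this host's local devices.
old_batch_size = jnp.ceil(per_device_batch_size * jax.local_device_count()).astype(jnp.int32)
print(type(old_batch_size), old_batch_size)

# New behaviour (assumed derivation): a plain int covering every device in the run.
global_batch_size_to_train_on = max(1, int(jnp.ceil(per_device_batch_size * jax.device_count())))
print(type(global_batch_size_to_train_on), global_batch_size_to_train_on)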

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 1 addition & 0 deletions
@@ -396,6 +396,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     mock_config.weight_quantization_calibration_method = "fixed,-224,224"
     mock_config.act_quantization_calibration_method = "fixed,-224,224"
     mock_config.bwd_quantization_calibration_method = "absmax"
+    mock_config.global_batch_size_to_train_on = 32
 
     mock_model = Mock(spec=WanModel)
     mock_pipeline = Mock()
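
The test change simply sets the new config attribute on the mock so quantize_transformer can read it. The fragment below is a simplified, self-contained stand-in (assumed names, not the actual wan_transformer_test.py code) showing how a test can assert that the dummy-input helper receives the configured global batch size rather than a per-device value.

# Illustrative sketch, not repository code.
from unittest.mock import Mock

def quantize_transformer_stub(config, get_dummy_wan_inputs):
    # Stand-in for the pipeline logic touched by this commit: the dummy
    # inputs are built with the global batch size from the config.
    batch_size = config.global_batch_size_to_train_on
    return get_dummy_wan_inputs(config, None, batch_size)

def test_dummy_inputs_use_global_batch_size():
    mock_config = Mock()
    mock_config.global_batch_size_to_train_on = 32

    mock_get_dummy_inputs = Mock(return_value=("latents", "prompt_embeds", "timesteps"))
    quantize_transformer_stub(mock_config, mock_get_dummy_inputs)

    # The helper must be called with the global batch size, not a per-device value.
    mock_get_dummy_inputs.assert_called_once_with(mock_config, None, 32)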
