Commit 9846a13

Merge branch 'main' into update_readme_wan
2 parents: 286e452 + fb1c00b

11 files changed: 153 additions & 76 deletions

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 3 additions & 0 deletions
@@ -236,6 +236,7 @@ enable_profiler: False
 # the iteration time a chance to stabilize.
 skip_first_n_steps_for_profiler: 5
 profiler_steps: 10
+profiler: ""

 # Generation parameters
 prompt: "A magical castle in the middle of a forest, artistic drawing"
@@ -284,3 +285,5 @@ quantization: ''
 quantization_local_shard_count: -1
 use_qwix_quantization: False
 compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
+
+save_final_checkpoint: False
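
Both new keys are off by default and can be flipped per run. A minimal override sketch; the "xplane" backend name is an assumption, not something this commit establishes, so check the supported profiler values before relying on it:

profiler: "xplane"            # assumed backend name; verify against the supported profiler options
save_final_checkpoint: True   # write a final checkpoint when the run finishes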

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 2 deletions
@@ -313,10 +313,9 @@ compile_topology_num_slices: -1 # Number of target slices, set to a positive int
 use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
 # Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
 quantization_calibration_method: "absmax"
+qwix_module_path: ".*"

 # Eval model on per eval_every steps. -1 means don't eval.
 eval_every: -1
 eval_data_dir: ""
 enable_generate_video_for_eval: False # This will increase the used TPU memory.
-eval_max_number_of_samples_in_bucket: 60
-eval_max_processed_batch_size: 8 # This is the max batch size per device for eval step. If the global eval batch size is larger than this, the eval step will be run multiple times.
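
The new qwix_module_path key controls which modules the qwix rules match. A minimal sketch of narrowing quantization to a subset of modules; the regex below is a hypothetical example, not taken from this commit:

use_qwix_quantization: True
quantization: "fp8_full"
qwix_module_path: ".*blocks.*"   # hypothetical pattern; match it against the WAN transformer's actual module paths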

src/maxdiffusion/generate_wan.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
 from absl import app
 from maxdiffusion.utils import export_to_video
 from google.cloud import storage
+import flax


 def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -161,6 +162,7 @@ def run(config, pipeline=None, filename_prefix=""):

 def main(argv: Sequence[str]) -> None:
   pyconfig.initialize(argv)
+  flax.config.update('flax_always_shard_variable', False)
   run(pyconfig.config)

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 24 additions & 11 deletions
@@ -243,34 +243,48 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
     return wan_vae, vae_cache

   @classmethod
-  def get_basic_config(cls, dtype):
+  def get_basic_config(cls, dtype, config: HyperParameters):
     rules = [
         qwix.QtRule(
-            module_path=".*",  # Apply to all modules
+            module_path=config.qwix_module_path,
             weight_qtype=dtype,
             act_qtype=dtype,
+            op_names=("dot_general", "einsum", "conv_general_dilated"),
         )
     ]
     return rules

   @classmethod
-  def get_fp8_config(cls, quantization_calibration_method: str):
+  def get_fp8_config(cls, config: HyperParameters):
     """
     fp8 config rules with per-tensor calibration.
     FLAX API (https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api):
     The autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training, which is the recommended practice.
     """
     rules = [
         qwix.QtRule(
-            module_path=".*",  # Apply to all modules
+            module_path=config.qwix_module_path,
             weight_qtype=jnp.float8_e4m3fn,
             act_qtype=jnp.float8_e4m3fn,
+            bwd_qtype=jnp.float8_e5m2,
+            bwd_use_original_residuals=True,
+            disable_channelwise_axes=True,  # per_tensor calibration
+            weight_calibration_method=config.quantization_calibration_method,
+            act_calibration_method=config.quantization_calibration_method,
+            bwd_calibration_method=config.quantization_calibration_method,
+            op_names=("dot_general", "einsum"),
+        ),
+        qwix.QtRule(
+            module_path=config.qwix_module_path,
+            weight_qtype=jnp.float8_e4m3fn,  # conv_general_dilated requires the same dtypes
+            act_qtype=jnp.float8_e4m3fn,
             bwd_qtype=jnp.float8_e4m3fn,
             bwd_use_original_residuals=True,
             disable_channelwise_axes=True,  # per_tensor calibration
-            weight_calibration_method=quantization_calibration_method,
-            act_calibration_method=quantization_calibration_method,
-            bwd_calibration_method=quantization_calibration_method,
+            weight_calibration_method=config.quantization_calibration_method,
+            act_calibration_method=config.quantization_calibration_method,
+            bwd_calibration_method=config.quantization_calibration_method,
+            op_names=("conv_general_dilated",),
         )
     ]
     return rules
@@ -281,14 +295,13 @@ def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]:
     if not getattr(config, "use_qwix_quantization", False):
       return None

-    quantization_calibration_method = getattr(config, "quantization_calibration_method", "absmax")
     match config.quantization:
       case "int8":
-        return qwix.QtProvider(cls.get_basic_config(jnp.int8))
+        return qwix.QtProvider(cls.get_basic_config(jnp.int8, config))
       case "fp8":
-        return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn))
+        return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn, config))
       case "fp8_full":
-        return qwix.QtProvider(cls.get_fp8_config(quantization_calibration_method))
+        return qwix.QtProvider(cls.get_fp8_config(config))
     return None

   @classmethod
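
The rule selection above can be exercised on its own. A minimal sketch using only symbols shown in this diff (WanPipeline.get_qt_provider, qwix.QtProvider); the SimpleNamespace stands in for a real HyperParameters config and is purely illustrative:

# Build a qwix provider from a config-like object (hypothetical stand-in for HyperParameters).
from types import SimpleNamespace

from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline

config = SimpleNamespace(
    use_qwix_quantization=True,
    quantization="fp8_full",                   # selects the two-rule fp8 config above
    quantization_calibration_method="absmax",
    qwix_module_path=".*",                     # quantize every module; narrow with a regex if needed
)

provider = WanPipeline.get_qt_provider(config)
# "fp8_full" yields two QtRules: e5m2 backward for dot_general/einsum, e4m3 for conv_general_dilated.
assert provider is not None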

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 34 additions & 5 deletions
@@ -19,7 +19,7 @@
 import jax.numpy as jnp
 import pytest
 import unittest
-from unittest.mock import Mock, patch
+from unittest.mock import Mock, patch, call
 from absl.testing import absltest
 from flax import nnx
 from jax.sharding import Mesh
@@ -37,6 +37,11 @@
 from ..models.attention_flax import FlaxWanAttention
 from maxdiffusion.pyconfig import HyperParameters
 from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
+import qwix
+import flax
+
+flax.config.update('flax_always_shard_variable', False)
+RealQtRule = qwix.QtRule


 IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
@@ -282,6 +287,10 @@ def test_get_qt_provider(self, mock_qt_rule):
     """
     Tests the provider logic for all config branches.
     """
+    def create_real_rule_instance(*args, **kwargs):
+      return RealQtRule(*args, **kwargs)
+    mock_qt_rule.side_effect = create_real_rule_instance
+
     # Case 1: Quantization disabled
     config_disabled = Mock(spec=HyperParameters)
     config_disabled.use_qwix_quantization = False
@@ -291,28 +300,43 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_int8 = Mock(spec=HyperParameters)
     config_int8.use_qwix_quantization = True
     config_int8.quantization = "int8"
+    config_int8.qwix_module_path = ".*"
     provider_int8 = WanPipeline.get_qt_provider(config_int8)
     self.assertIsNotNone(provider_int8)
-    mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8)
+    mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8, op_names=("dot_general", "einsum", "conv_general_dilated"))

     # Case 3: Quantization enabled, type 'fp8'
     mock_qt_rule.reset_mock()
     config_fp8 = Mock(spec=HyperParameters)
     config_fp8.use_qwix_quantization = True
     config_fp8.quantization = "fp8"
+    config_fp8.qwix_module_path = ".*"
     provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
     self.assertIsNotNone(provider_fp8)
-    mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn)
+    mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, op_names=("dot_general", "einsum", "conv_general_dilated"))

     # Case 4: Quantization enabled, type 'fp8_full'
     mock_qt_rule.reset_mock()
     config_fp8_full = Mock(spec=HyperParameters)
     config_fp8_full.use_qwix_quantization = True
     config_fp8_full.quantization = "fp8_full"
     config_fp8_full.quantization_calibration_method = "absmax"
+    config_fp8_full.qwix_module_path = ".*"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
-    mock_qt_rule.assert_called_once_with(
+    expected_calls = [
+        call(module_path=".*",  # Apply to all modules
+            weight_qtype=jnp.float8_e4m3fn,
+            act_qtype=jnp.float8_e4m3fn,
+            bwd_qtype=jnp.float8_e5m2,
+            bwd_use_original_residuals=True,
+            disable_channelwise_axes=True,  # per_tensor calibration
+            weight_calibration_method=config_fp8_full.quantization_calibration_method,
+            act_calibration_method=config_fp8_full.quantization_calibration_method,
+            bwd_calibration_method=config_fp8_full.quantization_calibration_method,
+            op_names=("dot_general", "einsum"),
+        ),
+        call(
            module_path=".*",  # Apply to all modules
            weight_qtype=jnp.float8_e4m3fn,
            act_qtype=jnp.float8_e4m3fn,
@@ -322,7 +346,10 @@ def test_get_qt_provider(self, mock_qt_rule):
            weight_calibration_method=config_fp8_full.quantization_calibration_method,
            act_calibration_method=config_fp8_full.quantization_calibration_method,
            bwd_calibration_method=config_fp8_full.quantization_calibration_method,
-        )
+            op_names=("conv_general_dilated",),
+        ),
+    ]
+    mock_qt_rule.assert_has_calls(expected_calls, any_order=True)

     # Case 5: Invalid quantization type
     config_invalid = Mock(spec=HyperParameters)
@@ -341,7 +368,9 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     mock_config = Mock(spec=HyperParameters)
     mock_config.use_qwix_quantization = True
     mock_config.quantization = "fp8_full"
+    mock_config.qwix_module_path = ".*"
     mock_config.per_device_batch_size = 1
+    mock_config.quantization_calibration_method = "absmax"

     mock_model = Mock(spec=WanModel)
     mock_pipeline = Mock()

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 2 additions & 1 deletion
@@ -46,12 +46,13 @@
 from ..models.wan.wan_utils import load_wan_vae
 from ..utils import load_video
 from ..video_processor import VideoProcessor
+import flax

 THIS_DIR = os.path.dirname(os.path.abspath(__file__))

 CACHE_T = 2

-
+flax.config.update('flax_always_shard_variable', False)
 class TorchWanRMS_norm(nn.Module):
   r"""
   A custom RMS normalization layer.

src/maxdiffusion/train_flux.py

Lines changed: 3 additions & 2 deletions
@@ -22,6 +22,7 @@

 from maxdiffusion.train_utils import (
     validate_train_config,
+    transformer_engine_context,
 )


@@ -39,6 +40,6 @@ def main(argv: Sequence[str]) -> None:
   max_logging.log(f"Found {jax.device_count()} devices.")
   train(config)

-
 if __name__ == "__main__":
-  app.run(main)
+  with transformer_engine_context():
+    app.run(main)

src/maxdiffusion/train_sdxl.py

Lines changed: 3 additions & 1 deletion
@@ -27,6 +27,7 @@

 from maxdiffusion.train_utils import (
     validate_train_config,
+    transformer_engine_context,
 )


@@ -51,4 +52,5 @@ def main(argv: Sequence[str]) -> None:
   tf.config.set_visible_devices([], "GPU")
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
   torch.set_default_device("cpu")
-  app.run(main)
+  with transformer_engine_context():
+    app.run(main)

src/maxdiffusion/train_utils.py

Lines changed: 20 additions & 1 deletion
@@ -20,7 +20,7 @@
 import queue

 from maxdiffusion import max_utils, max_logging
-
+from contextlib import contextmanager

 def get_first_step(state):
   return int(state.step)
@@ -196,3 +196,22 @@ def generate_timestep_weights(config, num_timesteps):
   weights[bias_indices] *= timestep_bias_config["multiplier"]
   weights /= weights.sum()
   return jnp.array(weights)
+
+
+@contextmanager
+def transformer_engine_context():
+  """If TransformerEngine is available, this context manager provides the library with the MaxDiffusion-specific details it needs for correct operation."""
+  try:
+    from transformer_engine.jax.sharding import global_shard_guard, MeshResource
+    # Inform TransformerEngine of MaxDiffusion's physical mesh resources.
+    mesh_resource = MeshResource(
+        dp_resource="data",
+        tp_resource="tensor",
+        fsdp_resource="fsdp",
+        pp_resource=None,
+        cp_resource=None,
+    )
+    with global_shard_guard(mesh_resource):
+      yield
+  except ImportError:
+    yield
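
Because the TransformerEngine import sits inside try/except ImportError, the guard degrades to a no-op on hosts where the library is not installed, so entry points can wrap app.run() unconditionally. A minimal usage sketch; only transformer_engine_context comes from this commit, the rest is illustrative:

# The context manager yields exactly once whether or not transformer_engine is importable.
from absl import app
from maxdiffusion.train_utils import transformer_engine_context

def main(argv):
    # Code here runs under TransformerEngine's global_shard_guard when the library
    # is installed, and behaves identically (minus TE) when it is not.
    del argv

if __name__ == "__main__":
    with transformer_engine_context():
        app.run(main)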

src/maxdiffusion/train_wan.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@
 from absl import app
 from maxdiffusion import max_logging, pyconfig
 from maxdiffusion.train_utils import validate_train_config
+import flax


 def train(config):
@@ -34,6 +35,7 @@ def main(argv: Sequence[str]) -> None:
   config = pyconfig.config
   validate_train_config(config)
   max_logging.log(f"Found {jax.device_count()} devices.")
+  flax.config.update('flax_always_shard_variable', False)
   train(config)
