Merged
3 changes: 3 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -40,6 +40,9 @@ weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
 
+# Replicates the VAE across devices instead of sharding it via the model's sharding annotations.
+replicate_vae: False
+
 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
 # Options are "DEFAULT", "HIGH", "HIGHEST"
 # fp32 activations and fp32 weights with HIGHEST will provide the best precision
4 changes: 4 additions & 0 deletions src/maxdiffusion/max_utils.py
@@ -221,6 +221,10 @@ def walk_and_upload_blobs(config, output_dir):


 def device_put_replicated(x, sharding):
+  """
+  Despite what the name indicates, this function can also be used to
+  shard an array across devices according to `sharding`.
+  """
   return jax.make_array_from_callback(x.shape, sharding, lambda index: x[index])


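As a side note on the helper above, here is a minimal sketch (not part of the PR) of how the same callback both replicates and shards depending on the `NamedSharding` it receives; the one-axis mesh and array sizes below are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ("data",))  # hypothetical 1-D mesh over all devices
x = jnp.arange(16.0).reshape(8, 2)

def device_put_replicated(x, sharding):
  return jax.make_array_from_callback(x.shape, sharding, lambda index: x[index])

# P() -> every device receives the full array (true replication).
replicated = device_put_replicated(x, NamedSharding(mesh, P()))

# P("data") -> dim 0 is split across the "data" axis (sharding, despite the name).
# Assumes dim 0 is divisible by the device count.
sharded = device_put_replicated(x, NamedSharding(mesh, P("data")))
print(replicated.sharding, sharded.sharding)
```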
25 changes: 15 additions & 10 deletions src/maxdiffusion/models/attention_flax.py
@@ -166,20 +166,25 @@ def _tpu_flash_attention(
     dtype: jnp.dtype = jnp.float32,
 ) -> jax.Array:
   """TPU Flash Attention"""
 
-  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  # Cross-attention: the kv dims are much smaller because they come from encoder_hidden_states;
+  # if the kv seq_len is padded too much, it causes issues in the attention calculation.
+  if key.shape[1] != query.shape[1]:
+    kv_max_block_size = key.shape[1]
+  else:
+    kv_max_block_size = q_max_block_size
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(max_block_size, query.shape[2]),
-        block_kv_compute=min(max_block_size, key.shape[2]),
-        block_kv=min(max_block_size, key.shape[2]),
-        block_q_dkv=min(max_block_size, query.shape[2]),
-        block_kv_dkv=min(max_block_size, key.shape[2]),
-        block_kv_dkv_compute=min(max_block_size, query.shape[2]),
-        block_q_dq=min(max_block_size, query.shape[2]),
-        block_kv_dq=min(max_block_size, query.shape[2]),
+        block_q=min(q_max_block_size, query.shape[2]),
+        block_kv_compute=min(kv_max_block_size, key.shape[2]),
+        block_kv=min(kv_max_block_size, key.shape[2]),
+        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_kv_dkv=min(kv_max_block_size, key.shape[2]),
+        block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
+        block_q_dq=min(q_max_block_size, query.shape[2]),
+        block_kv_dq=min(kv_max_block_size, query.shape[2]),
     )
 
   num_fsdp_shards = mesh.shape["fsdp"]
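To make the capping concrete, here is a toy recomputation of the block-size logic with hypothetical sequence lengths (not taken from the model):

```python
# Hypothetical cross-attention sizes: a long padded video query attending
# to a short text-encoder kv sequence, in bfloat16.
q_seq_len, kv_seq_len = 4096, 512

q_max_block_size = 1024                 # bfloat16 cap, as in the diff above
kv_max_block_size = kv_seq_len if kv_seq_len != q_seq_len else q_max_block_size

block_q = min(q_max_block_size, q_seq_len)      # 1024
block_kv = min(kv_max_block_size, kv_seq_len)   # 512: a kv block never spans padding
print(block_q, block_kv)
```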
32 changes: 17 additions & 15 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -220,6 +220,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
   for path, val in flax.traverse_util.flatten_dict(params).items():
     sharding = logical_state_sharding[path].value
+    if config.replicate_vae:
+      sharding = NamedSharding(mesh, P())
     state[path].value = device_put_replicated(val, sharding)
   state = nnx.from_flat_state(state)
 
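The new branch opts the VAE out of the annotated layout entirely: an empty `PartitionSpec` means every device holds a full copy. A minimal sketch (hypothetical one-axis mesh, not part of the PR) of the two layouts being chosen between:

```python
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ("fsdp",))  # hypothetical mesh axis name

# replicate_vae=True: a full copy of each VAE weight on every device.
# Costs more HBM per device, but can avoid cross-device communication when decoding.
replicated = NamedSharding(mesh, P())

# replicate_vae=False: the layout comes from the model's logical sharding
# annotations, e.g. a weight whose first axis is split over "fsdp".
annotated_example = NamedSharding(mesh, P("fsdp"))
```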
@@ -231,11 +233,11 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
   def get_basic_config(cls, dtype):
     rules = [
         qwix.QtRule(
-            module_path='.*',  # Apply to all modules
-            weight_qtype=dtype,
-            act_qtype=dtype,
+            module_path=".*",  # Apply to all modules
+            weight_qtype=dtype,
+            act_qtype=dtype,
         )
-        ]
+    ]
     return rules
 
   @classmethod
@@ -247,17 +249,17 @@ def get_fp8_config(cls, quantization_calibration_method: str):
"""
rules = [
qwix.QtRule(
module_path='.*', # Apply to all modules
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e5m2,
bwd_use_original_residuals=True,
disable_channelwise_axes=True, # per_tensor calibration
weight_calibration_method = quantization_calibration_method,
act_calibration_method = quantization_calibration_method,
bwd_calibration_method = quantization_calibration_method,
module_path=".*", # Apply to all modules
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e5m2,
bwd_use_original_residuals=True,
disable_channelwise_axes=True, # per_tensor calibration
weight_calibration_method=quantization_calibration_method,
act_calibration_method=quantization_calibration_method,
bwd_calibration_method=quantization_calibration_method,
)
]
]
return rules

@classmethod
@@ -286,7 +288,7 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline

     batch_size = int(config.per_device_batch_size * jax.local_device_count())
     latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)
-    model_inputs= (latents, timesteps, prompt_embeds)
+    model_inputs = (latents, timesteps, prompt_embeds)
     with mesh:
       quantized_model = qwix.quantize_model(model, q_rules, *model_inputs)
     max_logging.log("Qwix Quantization complete.")
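A note on the dtype pairing in `get_fp8_config`: e4m3 spends its bits on mantissa (precision) for forward weights and activations, while e5m2 keeps a wider exponent range for backward gradients. The ranges are easy to confirm:

```python
import jax.numpy as jnp

# float8_e4m3fn: 4 exponent bits, 3 mantissa bits -> finite max 448.
print(jnp.finfo(jnp.float8_e4m3fn).max)  # 448
# float8_e5m2: 5 exponent bits, 2 mantissa bits -> max 57344, 128x more range.
print(jnp.finfo(jnp.float8_e5m2).max)    # 57344
```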
4 changes: 3 additions & 1 deletion src/maxdiffusion/pyconfig.py
@@ -142,7 +142,9 @@ def wan_init(raw_keys):
if "quantization" not in raw_keys:
raise ValueError("Quantization type is not set when use_qwix_quantization is enabled.")
elif raw_keys["quantization"] not in ["int8", "fp8", "fp8_full"]:
raise ValueError(f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}")
raise ValueError(
f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}"
)

@staticmethod
def calculate_global_batch_sizes(per_device_batch_size):
30 changes: 9 additions & 21 deletions src/maxdiffusion/tests/wan_transformer_test.py
@@ -13,6 +13,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
+from qwix import QtProvider
 import os
 import jax
 import jax.numpy as jnp
@@ -290,7 +292,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_int8 = Mock(spec=HyperParameters)
     config_int8.use_qwix_quantization = True
     config_int8.quantization = "int8"
-    provider_int8 = WanPipeline.get_qt_provider(config_int8)
+    provider_int8: QtProvider = WanPipeline.get_qt_provider(config_int8)
     self.assertIsNotNone(provider_int8)
     mock_qt_rule.assert_called_once_with(
         module_path='.*',
@@ -305,11 +307,7 @@ def test_get_qt_provider(self, mock_qt_rule):
config_fp8.quantization = "fp8"
provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
self.assertIsNotNone(provider_fp8)
mock_qt_rule.assert_called_once_with(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't want to undo this change, sanbao fixed the test earlier. There is no attributes rules so was breaking the test.

module_path='.*',
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn
)
self.assertEqual(provider_fp8.rules[0].kwargs["weight_qtype"], jnp.float8_e4m3fn)

# Case 4: Quantization enabled, type 'fp8_full'
mock_qt_rule.reset_mock()
@@ -319,17 +317,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_fp8_full.quantization_calibration_method = "absmax"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
-    mock_qt_rule.assert_called_once_with(
-        module_path='.*',  # Apply to all modules
-        weight_qtype=jnp.float8_e4m3fn,
-        act_qtype=jnp.float8_e4m3fn,
-        bwd_qtype=jnp.float8_e5m2,
-        bwd_use_original_residuals=True,
-        disable_channelwise_axes=True,  # per_tensor calibration
-        weight_calibration_method = config_fp8_full.quantization_calibration_method,
-        act_calibration_method = config_fp8_full.quantization_calibration_method,
-        bwd_calibration_method = config_fp8_full.quantization_calibration_method,
-    )
+    self.assertEqual(provider_fp8_full.rules[0].kwargs["bwd_qtype"], jnp.float8_e5m2)

Collaborator (on the removed assertion): Ditto.

     # Case 5: Invalid quantization type
     config_invalid = Mock(spec=HyperParameters)
@@ -338,8 +326,8 @@ def test_get_qt_provider(self, mock_qt_rule):
     self.assertIsNone(WanPipeline.get_qt_provider(config_invalid))
 
   # To test quantize_transformer, we patch its external dependencies
-  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
-  @patch('maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs')
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs")
   def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model):
     """
     Tests that quantize_transformer calls qwix when quantization is enabled.
@@ -370,14 +358,14 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model):
     # Check that the model returned is the new quantized model
     self.assertIs(result, mock_quantized_model_obj)
 
-  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
   def test_quantize_transformer_disabled(self, mock_quantize_model):
     """
     Tests that quantize_transformer is skipped when quantization is disabled.
     """
     # Setup Mocks
     mock_config = Mock(spec=HyperParameters)
-    mock_config.use_qwix_quantization = False # Main condition for this test
+    mock_config.use_qwix_quantization = False  # Main condition for this test
 
     mock_model = Mock(spec=WanModel)
 
14 changes: 12 additions & 2 deletions src/maxdiffusion/trainers/wan_trainer.py
@@ -36,6 +36,7 @@
 from maxdiffusion.utils import load_video
 from skimage.metrics import structural_similarity as ssim
 from flax.training import train_state
+from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
 
 
 class TrainState(train_state.TrainState):
@@ -53,6 +54,12 @@ def generate_sample(config, pipeline, filename_prefix):
"""
Generates a video to validate training did not corrupt the model
"""
if not hasattr(pipeline, "vae"):
wan_vae, vae_cache = WanPipeline.load_vae(
pipeline.mesh.devices, pipeline.mesh, nnx.Rngs(jax.random.key(config.seed)), config
)
pipeline.vae = wan_vae
pipeline.vae_cache = vae_cache
return generate_wan(config, pipeline, filename_prefix)


@@ -140,10 +147,13 @@ def prepare_sample(features):
   def start_training(self):
 
     pipeline = self.load_checkpoint()
-    # del pipeline.vae
 
     # Generate a sample before training to compare against the sample generated after training.
     pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
 
+    # Save some memory; generate_sample reloads the VAE on demand.
+    del pipeline.vae
+    del pipeline.vae_cache
+
     mesh = pipeline.mesh
     data_iterator = self.load_dataset(mesh)
 
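Taken together, the trainer changes implement a free-then-lazily-reload pattern: the VAE is deleted after the pre-training sample to save HBM during training, and `generate_sample` restores it via the `hasattr` guard when validation needs it again. A minimal standalone sketch (names are illustrative, not the maxdiffusion API):

```python
class Pipeline:
  def __init__(self, vae):
    self.vae = vae

def generate_sample(pipeline, make_vae):
  # Reload the VAE only if an earlier step deleted it to free memory.
  if not hasattr(pipeline, "vae"):
    pipeline.vae = make_vae()
  return f"decoded with {pipeline.vae}"

pipe = Pipeline(vae="vae")
print(generate_sample(pipe, make_vae=lambda: "reloaded-vae"))  # uses existing vae
del pipe.vae                                                   # free memory for training
print(generate_sample(pipe, make_vae=lambda: "reloaded-vae"))  # lazily reloads
```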