3 changes: 3 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -40,6 +40,9 @@ weights_dtype: 'bfloat16'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'

# Replicates the VAE across devices instead of sharding it according to the model's sharding annotations.
replicate_vae: False

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
# fp32 activations and fp32 weights with HIGHEST will provide the best precision
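A minimal sketch of what the new replicate_vae flag changes downstream (the mesh axis and array shape here are illustrative assumptions; the actual wiring is in wan_pipeline.py below): when the flag is set, each VAE parameter is placed with an empty PartitionSpec, i.e. fully replicated, instead of the logical sharding derived from the model's annotations.

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Illustrative 1-axis mesh; maxdiffusion's real mesh axes may differ.
    mesh = Mesh(np.array(jax.devices()), ("fsdp",))

    replicate_vae = True  # value read from base_wan_14b.yml
    logical_sharding = NamedSharding(mesh, P("fsdp"))  # stand-in for logical_state_sharding[path].value

    sharding = NamedSharding(mesh, P()) if replicate_vae else logical_sharding
    param = jax.device_put(np.ones((8, 4), dtype=np.float32), sharding)
    # With P(), every device holds a full copy of the parameter; with P("fsdp"), rows are split across devices.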
4 changes: 4 additions & 0 deletions src/maxdiffusion/max_utils.py
@@ -221,6 +221,10 @@ def walk_and_upload_blobs(config, output_dir):


def device_put_replicated(x, sharding):
"""
Although the name indicates replication, this function can also be used
to shard an array, depending on the sharding that is passed in.
"""
return jax.make_array_from_callback(x.shape, sharding, lambda index: x[index])
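A hedged usage sketch of the behaviour described in the docstring (the mesh and axis name are assumptions, and device_put_replicated is assumed to be in scope from max_utils.py): the callback pulls x[index] for whatever index map the sharding produces, so an empty PartitionSpec replicates while a named axis shards.

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), ("fsdp",))
    x = np.arange(16, dtype=np.float32).reshape(8, 2)

    replicated = device_put_replicated(x, NamedSharding(mesh, P()))         # full copy on every device
    row_sharded = device_put_replicated(x, NamedSharding(mesh, P("fsdp")))  # dim 0 split across "fsdp"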


25 changes: 15 additions & 10 deletions src/maxdiffusion/models/attention_flax.py
@@ -166,20 +166,25 @@ def _tpu_flash_attention(
dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
"""TPU Flash Attention"""

max_block_size = 1024 if dtype == jnp.bfloat16 else 512
q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
# Cross-attention where kv dims are much smaller due to encoder_hidden_states.
# If kv seq_len is padded too much, it causes issues in attention calculations.
if key.shape[1] != query.shape[1]:
kv_max_block_size = key.shape[1]
else:
kv_max_block_size = q_max_block_size
if flash_block_sizes:
block_sizes = flash_block_sizes
else:
block_sizes = splash_attention_kernel.BlockSizes(
block_q=min(max_block_size, query.shape[2]),
block_kv_compute=min(max_block_size, key.shape[2]),
block_kv=min(max_block_size, key.shape[2]),
block_q_dkv=min(max_block_size, query.shape[2]),
block_kv_dkv=min(max_block_size, key.shape[2]),
block_kv_dkv_compute=min(max_block_size, query.shape[2]),
block_q_dq=min(max_block_size, query.shape[2]),
block_kv_dq=min(max_block_size, query.shape[2]),
block_q=min(q_max_block_size, query.shape[2]),
block_kv_compute=min(kv_max_block_size, key.shape[2]),
block_kv=min(kv_max_block_size, key.shape[2]),
block_q_dkv=min(q_max_block_size, query.shape[2]),
block_kv_dkv=min(kv_max_block_size, key.shape[2]),
block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
block_q_dq=min(q_max_block_size, query.shape[2]),
block_kv_dq=min(kv_max_block_size, query.shape[2]),
)

num_fsdp_shards = mesh.shape["fsdp"]
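A worked sketch of the block-size logic above for a cross-attention call (the sequence lengths are illustrative assumptions): with bfloat16 activations the query blocks stay capped at 1024, while the kv blocks are now capped at the actual encoder sequence length, so splash attention never pads kv beyond it.

    # Illustrative numbers only.
    q_seq_len, kv_seq_len = 4096, 512   # kv comes from encoder_hidden_states
    dtype_is_bf16 = True

    q_max_block_size = 1024 if dtype_is_bf16 else 512
    kv_max_block_size = kv_seq_len if kv_seq_len != q_seq_len else q_max_block_size

    block_q = min(q_max_block_size, q_seq_len)       # 1024
    block_kv = min(kv_max_block_size, kv_seq_len)    # 512, rather than a kv block padded up to 1024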
32 changes: 17 additions & 15 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -220,6 +220,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
for path, val in flax.traverse_util.flatten_dict(params).items():
sharding = logical_state_sharding[path].value
if config.replicate_vae:
sharding = NamedSharding(mesh, P())
state[path].value = device_put_replicated(val, sharding)
state = nnx.from_flat_state(state)

@@ -231,11 +233,11 @@ def get_basic_config(cls, dtype):
def get_basic_config(cls, dtype):
rules = [
qwix.QtRule(
module_path='.*', # Apply to all modules
weight_qtype=dtype,
act_qtype=dtype,
module_path=".*", # Apply to all modules
weight_qtype=dtype,
act_qtype=dtype,
)
]
]
return rules

@classmethod
@@ -247,17 +249,17 @@ def get_fp8_config(cls, quantization_calibration_method: str):
"""
rules = [
qwix.QtRule(
module_path='.*', # Apply to all modules
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e5m2,
bwd_use_original_residuals=True,
disable_channelwise_axes=True, # per_tensor calibration
weight_calibration_method = quantization_calibration_method,
act_calibration_method = quantization_calibration_method,
bwd_calibration_method = quantization_calibration_method,
module_path=".*", # Apply to all modules
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e5m2,
bwd_use_original_residuals=True,
disable_channelwise_axes=True, # per_tensor calibration
weight_calibration_method=quantization_calibration_method,
act_calibration_method=quantization_calibration_method,
bwd_calibration_method=quantization_calibration_method,
)
]
]
return rules

@classmethod
@@ -286,7 +288,7 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline

batch_size = int(config.per_device_batch_size * jax.local_device_count())
latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)
model_inputs= (latents, timesteps, prompt_embeds)
model_inputs = (latents, timesteps, prompt_embeds)
with mesh:
quantized_model = qwix.quantize_model(model, q_rules, *model_inputs)
max_logging.log("Qwix Quantization complete.")
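A hedged end-to-end sketch of how the pieces in this file fit together (only names that already appear in this PR are assumed to exist; this is not the literal quantize_transformer body): the config selects a rule set, get_qt_provider wraps it, and the model is traced once on dummy inputs.

    q_rules = WanPipeline.get_qt_provider(config)   # int8 / fp8 / fp8_full rules, or None when unsupported

    batch_size = int(config.per_device_batch_size * jax.local_device_count())
    latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)

    if q_rules is not None:
        with mesh:
            model = qwix.quantize_model(model, q_rules, latents, timesteps, prompt_embeds)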
4 changes: 3 additions & 1 deletion src/maxdiffusion/pyconfig.py
@@ -142,7 +142,9 @@ def wan_init(raw_keys):
if "quantization" not in raw_keys:
raise ValueError("Quantization type is not set when use_qwix_quantization is enabled.")
elif raw_keys["quantization"] not in ["int8", "fp8", "fp8_full"]:
raise ValueError(f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}")
raise ValueError(
f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}"
)

@staticmethod
def calculate_global_batch_sizes(per_device_batch_size):
15 changes: 8 additions & 7 deletions src/maxdiffusion/tests/wan_transformer_test.py
@@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from qwix import QtProvider
import os
import jax
@@ -290,7 +291,7 @@ def test_get_qt_provider(self):
config_int8 = Mock(spec=HyperParameters)
config_int8.use_qwix_quantization = True
config_int8.quantization = "int8"
provider_int8:QtProvider = WanPipeline.get_qt_provider(config_int8)
provider_int8: QtProvider = WanPipeline.get_qt_provider(config_int8)
self.assertIsNotNone(provider_int8)
self.assertEqual(provider_int8._rules[0].weight_qtype, jnp.int8)

@@ -300,7 +301,7 @@ def test_get_qt_provider(self):
config_fp8.quantization = "fp8"
provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
self.assertIsNotNone(provider_fp8)
self.assertEqual(provider_fp8.rules[0].kwargs['weight_qtype'], jnp.float8_e4m3fn)
self.assertEqual(provider_fp8.rules[0].kwargs["weight_qtype"], jnp.float8_e4m3fn)

# Case 4: Quantization enabled, type 'fp8_full'
config_fp8_full = Mock(spec=HyperParameters)
@@ -309,7 +310,7 @@ def test_get_qt_provider(self):
config_fp8_full.quantization_calibration_method = "absmax"
provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
self.assertIsNotNone(provider_fp8_full)
self.assertEqual(provider_fp8_full.rules[0].kwargs['bwd_qtype'], jnp.float8_e5m2)
self.assertEqual(provider_fp8_full.rules[0].kwargs["bwd_qtype"], jnp.float8_e5m2)

# Case 5: Invalid quantization type
config_invalid = Mock(spec=HyperParameters)
@@ -318,8 +319,8 @@ def test_get_qt_provider(self):
self.assertIsNone(WanPipeline.get_qt_provider(config_invalid))

# To test quantize_transformer, we patch its external dependencies
@patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
@patch('maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs')
@patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
@patch("maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs")
def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model):
"""
Tests that quantize_transformer calls qwix when quantization is enabled.
@@ -348,14 +349,14 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
# Check that the model returned is the new quantized model
self.assertIs(result, mock_quantized_model_obj)

@patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
@patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
def test_quantize_transformer_disabled(self, mock_quantize_model):
"""
Tests that quantize_transformer is skipped when quantization is disabled.
"""
# Setup Mocks
mock_config = Mock(spec=HyperParameters)
mock_config.use_qwix_quantization = False # Main condition for this test
mock_config.use_qwix_quantization = False # Main condition for this test

mock_model = Mock(spec=WanModel)

14 changes: 12 additions & 2 deletions src/maxdiffusion/trainers/wan_trainer.py
@@ -36,6 +36,7 @@
from maxdiffusion.utils import load_video
from skimage.metrics import structural_similarity as ssim
from flax.training import train_state
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline


class TrainState(train_state.TrainState):
@@ -53,6 +54,12 @@ def generate_sample(config, pipeline, filename_prefix):
"""
Generates a video to validate training did not corrupt the model
"""
if not hasattr(pipeline, "vae"):
wan_vae, vae_cache = WanPipeline.load_vae(
pipeline.mesh.devices, pipeline.mesh, nnx.Rngs(jax.random.key(config.seed)), config
)
pipeline.vae = wan_vae
pipeline.vae_cache = vae_cache
return generate_wan(config, pipeline, filename_prefix)


@@ -140,10 +147,13 @@ def prepare_sample(features):
def start_training(self):

pipeline = self.load_checkpoint()
# del pipeline.vae

# Generate a sample before training to compare against generated sample after training.
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")

# save some memory.
del pipeline.vae
del pipeline.vae_cache

mesh = pipeline.mesh
data_iterator = self.load_dataset(mesh)

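A hedged sketch of the VAE lifecycle this change introduces in the trainer (the post-training prefix is an assumption; the other names mirror the diff): the VAE is dropped after the pre-training sample to free memory, and generate_sample lazily restores it through WanPipeline.load_vae the next time it is needed.

    pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")

    # Save memory during training; generate_sample reloads the VAE on demand.
    del pipeline.vae
    del pipeline.vae_cache

    # ... training loop ...

    # Inside generate_sample: if not hasattr(pipeline, "vae"), WanPipeline.load_vae(...) re-attaches it.
    post_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")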