diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 3cbb0ccea..05e2db498 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -214,8 +214,8 @@ def _tpu_flash_attention(
 
   def wrap_flash_attention(query, key, value):
     query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv_compute)
-    value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv_compute)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv)
+    value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv)
 
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
diff --git a/src/maxdiffusion/tests/flop_calculations_test.py b/src/maxdiffusion/tests/flop_calculations_test.py
index db1216f72..5afa3bff8 100644
--- a/src/maxdiffusion/tests/flop_calculations_test.py
+++ b/src/maxdiffusion/tests/flop_calculations_test.py
@@ -1,12 +1,15 @@
 import os
 import unittest
+from unittest.mock import Mock
 import jax
 from jax.sharding import Mesh
 import flax.linen as nn
 from absl.testing import absltest
 from maxdiffusion.max_utils import calculate_model_tflops
 from maxdiffusion.models.attention_flax import FlaxAttention
+from maxdiffusion.models.wan.transformers.transformer_wan import WanModel
 from .. import pyconfig, max_utils
+from maxdiffusion.trainers.wan_trainer import WanTrainer
 
 THIS_DIR = os.path.dirname(os.path.abspath(__file__))
 
@@ -20,6 +23,39 @@ def setUp(self):
     devices_array = max_utils.create_device_mesh(self.config)
     self.mesh = Mesh(devices_array, self.config.mesh_axes)
 
+  def assertFlopsAlmostEqual(self, flops1, flops2, rel_tol=5e-2):
+    """Assert that two FLOPs values are almost equal, within 5% relative tolerance."""
+    self.assertTrue(
+        abs(flops1 - flops2) / max(abs(flops1), abs(flops2)) <= rel_tol,
+        f"FLOPs values are not equal: {flops1} != {flops2} (rel_tol={rel_tol:.2e})",
+    )
+
+  def test_wan_21_flops(self):
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+            "width=1280",
+            "height=720",
+            "num_frames=81",
+            "per_device_batch_size=1",
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")
+    pipeline = Mock()
+    pipeline.config = config
+    pipeline.vae_scale_factor_temporal = 4
+    transformer = Mock()
+    transformer.config = Mock()
+    transformer.config.configure_mock(**wan_config)
+    pipeline.transformer = transformer
+
+    calculated_tflops, attention_flops, seq_len = WanTrainer.calculate_tflops(pipeline)
+    golden_tflops = 19_573
+    self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
+
   def test_dense_layer_model_flops(self):
 
     class SimpleLinearModel(nn.Module):
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 2c6caf579..1b235f640 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -100,9 +100,56 @@ def create_scheduler(self):
     noise_scheduler_state = noise_scheduler.set_timesteps(noise_scheduler_state, num_inference_steps=1000, training=True)
     return noise_scheduler, noise_scheduler_state
 
-  def calculate_tflops(self, pipeline):
-    max_logging.log("WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0...")
-    return 0
+  @staticmethod
+  def calculate_tflops(pipeline):
+
+    maxdiffusion_config = pipeline.config
+    # Model configuration
+    height = pipeline.config.height
+    width = pipeline.config.width
+    num_frames = pipeline.config.num_frames
+
+    # Transformer dimensions
+    transformer_config = pipeline.transformer.config
+    num_layers = transformer_config.num_layers
+    heads = pipeline.transformer.config.num_attention_heads
+    head_dim = pipeline.transformer.config.attention_head_dim
+    ffn_dim = transformer_config.ffn_dim
+    seq_len = int(((height / 8) * (width / 8) * ((num_frames - 1) // pipeline.vae_scale_factor_temporal + 1)) / 4)
+    text_encoder_dim = 512
+    # Attention FLOPS
+    # Self
+    self_attn_qkv_proj_flops = 3 * (2 * seq_len * (heads * head_dim) ** 2)
+    self_attn_qk_v_flops = 2 * (2 * seq_len**2 * (heads * head_dim))
+    # Cross
+    cross_attn_kv_proj_flops = 3 * (2 * text_encoder_dim * (heads * head_dim) ** 2)
+    cross_attn_q_proj_flops = 1 * (2 * seq_len * (heads * head_dim) ** 2)
+    cross_attention_qk_v_flops = 2 * (2 * seq_len * text_encoder_dim * (heads * head_dim))
+
+    # Output_projection from attention
+    attn_output_proj_flops = 2 * (2 * seq_len * (heads * head_dim) ** 2)
+
+    total_attn_flops = (
+        self_attn_qkv_proj_flops
+        + self_attn_qk_v_flops
+        + cross_attn_kv_proj_flops
+        + cross_attn_q_proj_flops
+        + cross_attention_qk_v_flops
+        + attn_output_proj_flops
+    )
+
+    # FFN
+    ffn_flops = 2 * (2 * seq_len * (heads * head_dim) * ffn_dim)
+
+    flops_per_block = total_attn_flops + ffn_flops
+
+    total_transformer_flops = flops_per_block * num_layers
+
+    tflops = maxdiffusion_config.per_device_batch_size * total_transformer_flops / 1e12
+    train_tflops = 3 * tflops
+
+    max_logging.log(f"Calculated TFLOPs per pass: {train_tflops:.4f}")
+    return train_tflops, total_attn_flops, seq_len
 
   def get_data_shardings(self, mesh):
     data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
@@ -225,7 +272,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
     )
     # TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint.
     start_step = 0
-    per_device_tflops = self.calculate_tflops(pipeline)
+    per_device_tflops, _, _ = WanTrainer.calculate_tflops(pipeline)
     scheduler_state = pipeline.scheduler_state
 
     example_batch = load_next_batch(train_data_iterator, None, self.config)