Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ def _tpu_flash_attention(
def wrap_flash_attention(query, key, value):

query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q)
key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv_compute)
value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv_compute)
key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv)
value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv)

mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
Expand Down
36 changes: 36 additions & 0 deletions src/maxdiffusion/tests/flop_calculations_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
import unittest
from unittest.mock import Mock
import jax
from jax.sharding import Mesh
import flax.linen as nn
from absl.testing import absltest
from maxdiffusion.max_utils import calculate_model_tflops
from maxdiffusion.models.attention_flax import FlaxAttention
from maxdiffusion.models.wan.transformers.transformer_wan import WanModel
from .. import pyconfig, max_utils
from maxdiffusion.trainers.wan_trainer import WanTrainer

THIS_DIR = os.path.dirname(os.path.abspath(__file__))

Expand All @@ -20,6 +23,39 @@ def setUp(self):
devices_array = max_utils.create_device_mesh(self.config)
self.mesh = Mesh(devices_array, self.config.mesh_axes)

def assertFlopsAlmostEqual(self, flops1, flops2, rel_tol=5e-2):
"""Assert that two FLOPs values are almost equal, within 5% relative tolerance."""
self.assertTrue(
abs(flops1 - flops2) / max(abs(flops1), abs(flops2)) <= rel_tol,
f"FLOPs values are not equal: {flops1} != {flops2} (rel_tol={rel_tol:.2e})",
)

def test_wan_21_flops(self):
pyconfig.initialize(
[
None,
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
"width=1280",
"height=720",
"num_frames=81",
"per_device_batch_size=1",
],
unittest=True,
)
config = pyconfig.config
wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")
pipeline = Mock()
pipeline.config = config
pipeline.vae_scale_factor_temporal = 4
transformer = Mock()
transformer.config = Mock()
transformer.config.configure_mock(**wan_config)
pipeline.transformer = transformer

calculated_tflops, attention_flops, seq_len = WanTrainer.calculate_tflops(pipeline)
golden_tflops = 19_573
self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)

def test_dense_layer_model_flops(self):
class SimpleLinearModel(nn.Module):

Expand Down
55 changes: 51 additions & 4 deletions src/maxdiffusion/trainers/wan_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,56 @@ def create_scheduler(self):
noise_scheduler_state = noise_scheduler.set_timesteps(noise_scheduler_state, num_inference_steps=1000, training=True)
return noise_scheduler, noise_scheduler_state

def calculate_tflops(self, pipeline):
max_logging.log("WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0...")
return 0
@staticmethod
def calculate_tflops(pipeline):

maxdiffusion_config = pipeline.config
# Model configuration
height = pipeline.config.height
width = pipeline.config.width
num_frames = pipeline.config.num_frames

# Transformer dimensions
transformer_config = pipeline.transformer.config
num_layers = transformer_config.num_layers
heads = pipeline.transformer.config.num_attention_heads
head_dim = pipeline.transformer.config.attention_head_dim
ffn_dim = transformer_config.ffn_dim
seq_len = int(((height / 8) * (width / 8) * ((num_frames - 1) // pipeline.vae_scale_factor_temporal + 1)) / 4)
text_encoder_dim = 512
# Attention FLOPS
# Self
self_attn_qkv_proj_flops = 3 * (2 * seq_len * (heads * head_dim) ** 2)
self_attn_qk_v_flops = 2 * (2 * seq_len**2 * (heads * head_dim))
# Cross
cross_attn_kv_proj_flops = 3 * (2 * text_encoder_dim * (heads * head_dim) ** 2)
cross_attn_q_proj_flops = 1 * (2 * seq_len * (heads * head_dim) ** 2)
cross_attention_qk_v_flops = 2 * (2 * seq_len * text_encoder_dim * (heads * head_dim))

# Output_projection from attention
attn_output_proj_flops = 2 * (2 * seq_len * (heads * head_dim) ** 2)

total_attn_flops = (
self_attn_qkv_proj_flops
+ self_attn_qk_v_flops
+ cross_attn_kv_proj_flops
+ cross_attn_q_proj_flops
+ cross_attention_qk_v_flops
+ attn_output_proj_flops
)

# FFN
ffn_flops = 2 * (2 * seq_len * (heads * head_dim) * ffn_dim)

flops_per_block = total_attn_flops + ffn_flops

total_transformer_flops = flops_per_block * num_layers

tflops = maxdiffusion_config.per_device_batch_size * total_transformer_flops / 1e12
train_tflops = 3 * tflops

max_logging.log(f"Calculated TFLOPs per pass: {train_tflops:.4f}")
return train_tflops, total_attn_flops, seq_len

def get_data_shardings(self, mesh):
data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
Expand Down Expand Up @@ -225,7 +272,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
)
# TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint.
start_step = 0
per_device_tflops = self.calculate_tflops(pipeline)
per_device_tflops, _, _ = WanTrainer.calculate_tflops(pipeline)
scheduler_state = pipeline.scheduler_state
example_batch = load_next_batch(train_data_iterator, None, self.config)

Expand Down
Loading