155 changes: 73 additions & 82 deletions src/maxdiffusion/tests/wan_transformer_test.py
@@ -37,6 +37,10 @@
 from ..models.attention_flax import FlaxWanAttention
 from maxdiffusion.pyconfig import HyperParameters
 from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
+import qwix
+from flax.linen import partitioning as nn_partitioning
+
+RealQtRule = qwix.QtRule
 
 
 IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
@@ -48,6 +52,17 @@ class WanTransformerTest(unittest.TestCase):
 
   def setUp(self):
     WanTransformerTest.dummy_data = {}
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    self.config = pyconfig.config
+    devices_array = create_device_mesh(self.config)
+    self.mesh = Mesh(devices_array, self.config.mesh_axes)
+
 
   def test_rotary_pos_embed(self):
     batch_size = 1
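Review note (not part of the diff): the tests below all follow the same pattern, entering the shared self.mesh from setUp together with the config's logical axis rules before constructing a module, so parameter sharding annotations resolve against a real mesh. A minimal, self-contained sketch of that pattern; the mesh axis names and rules here are illustrative stand-ins for what base_wan_14b.yml provides:

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh

# Illustrative stand-ins for config.mesh_axes / config.logical_axis_rules.
mesh = Mesh(np.asarray(jax.devices()), ("data",))
logical_axis_rules = (("embed", None), ("mlp", None))

rngs = nnx.Rngs(jax.random.key(0))
with mesh, nn_partitioning.axis_rules(logical_axis_rules):
  # Modules whose parameters carry logical axis names resolve them
  # against the rules that are active here.
  layer = nnx.Linear(in_features=8, out_features=16, rngs=rngs)

assert layer(jnp.ones((1, 8))).shape == (1, 16)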
@@ -65,7 +80,9 @@ def test_nnx_pixart_alpha_text_projection(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     dummy_caption = jnp.ones((1, 512, 4096))
-    layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
+
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
     dummy_output = layer(dummy_caption)
     assert dummy_output.shape == (1, 512, 5120)
 
@@ -74,7 +91,8 @@ def test_nnx_timestep_embedding(self):
     rngs = nnx.Rngs(key)
 
     dummy_sample = jnp.ones((1, 256))
-    layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
     dummy_output = layer(dummy_sample)
     assert dummy_output.shape == (1, 5120)
 
@@ -97,9 +115,10 @@ def test_wan_time_text_embedding(self):
     time_freq_dim = 256
     time_proj_dim = 30720
     text_embed_dim = 4096
-    layer = WanTimeTextImageEmbedding(
-        rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
-    )
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = WanTimeTextImageEmbedding(
+          rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
+      )
 
     dummy_timestep = jnp.ones(batch_size)
 
@@ -115,20 +134,8 @@ def test_wan_time_text_embedding(self):
   def test_wan_block(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
-    pyconfig.initialize(
-        [
-            None,
-            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
-        ],
-        unittest=True,
-    )
-    config = pyconfig.config
-
-    devices_array = create_device_mesh(config)
-
-    flash_block_sizes = get_flash_block_sizes(config)
-
-    mesh = Mesh(devices_array, config.mesh_axes)
+    flash_block_sizes = get_flash_block_sizes(self.config)
 
     dim = 5120
     ffn_dim = 13824
@@ -158,33 +165,24 @@ def test_wan_block(self):
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
 
     dummy_temb = jnp.ones((batch_size, 6, dim))
-
-    wan_block = WanTransformerBlock(
-        rngs=rngs,
-        dim=dim,
-        ffn_dim=ffn_dim,
-        num_heads=num_heads,
-        qk_norm=qk_norm,
-        cross_attn_norm=cross_attn_norm,
-        eps=eps,
-        attention="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-    with mesh:
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_block = WanTransformerBlock(
+          rngs=rngs,
+          dim=dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          attention="flash",
+          mesh=self.mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
     assert dummy_output.shape == dummy_hidden_states.shape
 
   def test_wan_attention(self):
-    pyconfig.initialize(
-        [
-            None,
-            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
-        ],
-        unittest=True,
-    )
-    config = pyconfig.config
-
     batch_size = 1
     channels = 16
     frames = 21
@@ -197,59 +195,49 @@ def test_wan_attention(self):
 
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
-    devices_array = create_device_mesh(config)
 
-    flash_block_sizes = get_flash_block_sizes(config)
+    flash_block_sizes = get_flash_block_sizes(self.config)
 
-    mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     query_dim = 5120
-    attention = FlaxWanAttention(
-        rngs=rngs,
-        query_dim=query_dim,
-        heads=40,
-        dim_head=128,
-        attention_kernel="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      attention = FlaxWanAttention(
+          rngs=rngs,
+          query_dim=query_dim,
+          heads=40,
+          dim_head=128,
+          attention_kernel="flash",
+          mesh=self.mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
 
     dummy_hidden_states_shape = (batch_size, 75600, query_dim)
 
     dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
     dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    with mesh:
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       dummy_output = attention(
           hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
       )
     assert dummy_output.shape == dummy_hidden_states_shape
 
     # dot product
     try:
-      attention = FlaxWanAttention(
-          rngs=rngs,
-          query_dim=query_dim,
-          heads=40,
-          dim_head=128,
-          attention_kernel="dot_product",
-          split_head_dim=True,
-          mesh=mesh,
-          flash_block_sizes=flash_block_sizes,
-      )
+      with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+        attention = FlaxWanAttention(
+            rngs=rngs,
+            query_dim=query_dim,
+            heads=40,
+            dim_head=128,
+            attention_kernel="dot_product",
+            split_head_dim=True,
+            mesh=self.mesh,
+            flash_block_sizes=flash_block_sizes,
+        )
     except NotImplementedError:
       pass
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_model(self):
-    pyconfig.initialize(
-        [
-            None,
-            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
-        ],
-        unittest=True,
-    )
-    config = pyconfig.config
-
     batch_size = 1
     channels = 16
     frames = 1
@@ -260,18 +248,16 @@ def test_wan_model(self):
 
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
-    devices_array = create_device_mesh(config)
-
-    flash_block_sizes = get_flash_block_sizes(config)
-
-    mesh = Mesh(devices_array, config.mesh_axes)
+    flash_block_sizes = get_flash_block_sizes(self.config)
     batch_size = 1
     num_layers = 1
-    wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_model = WanModel(rngs=rngs, attention="flash", mesh=self.mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
 
     dummy_timestep = jnp.ones((batch_size))
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))
-    with mesh:
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       dummy_output = wan_model(
           hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states
       )
@@ -282,6 +268,10 @@ def test_get_qt_provider(self, mock_qt_rule):
     """
     Tests the provider logic for all config branches.
     """
+    def create_real_rule_instance(*args, **kwargs):
+      return RealQtRule(*args, **kwargs)
+    mock_qt_rule.side_effect = create_real_rule_instance
+
     # Case 1: Quantization disabled
     config_disabled = Mock(spec=HyperParameters)
     config_disabled.use_qwix_quantization = False
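Review note (not part of the diff): aliasing qwix.QtRule to RealQtRule at import time, before @patch swaps it out, lets the mock's side_effect hand back genuine rule objects while the mock still records every call for assert_called_once_with. A small self-contained sketch of the same pattern, using a hypothetical stand-in class:

from unittest.mock import patch

class Rule:
  """Hypothetical stand-in for qwix.QtRule."""

  def __init__(self, module_path):
    self.module_path = module_path

RealRule = Rule  # captured before patching, like RealQtRule above

with patch(f"{__name__}.Rule") as mock_rule:
  # Delegate construction to the real class so callers get real objects.
  mock_rule.side_effect = lambda *args, **kwargs: RealRule(*args, **kwargs)
  rule = Rule(module_path=".*")
  assert isinstance(rule, RealRule)  # a real instance came back
  mock_rule.assert_called_once_with(module_path=".*")  # and the call was recorded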
@@ -301,7 +291,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_fp8 = Mock(spec=HyperParameters)
     config_fp8.use_qwix_quantization = True
     config_fp8.quantization = "fp8"
-    config_int8.qwix_module_path = ".*"
+    config_fp8.qwix_module_path = ".*"
     provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
     self.assertIsNotNone(provider_fp8)
     mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, op_names=("dot_general", "einsum", "conv_general_dilated"))
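Review note (not part of the diff): for reference, the single rule the fp8 branch is expected to build, written directly against qwix with the exact arguments checked in the assertion above:

import jax.numpy as jnp
import qwix

# The rule shape asserted above for config.quantization == "fp8":
# weights and activations quantized to float8_e4m3fn for the three matmul-like ops.
rule = qwix.QtRule(
    module_path=".*",
    weight_qtype=jnp.float8_e4m3fn,
    act_qtype=jnp.float8_e4m3fn,
    op_names=("dot_general", "einsum", "conv_general_dilated"),
)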
@@ -312,7 +302,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_fp8_full.use_qwix_quantization = True
     config_fp8_full.quantization = "fp8_full"
     config_fp8_full.quantization_calibration_method = "absmax"
-    config_int8.qwix_module_path = ".*"
+    config_fp8_full.qwix_module_path = ".*"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
     expected_calls = [
@@ -361,6 +351,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     mock_config.quantization = "fp8_full"
     mock_config.qwix_module_path = ".*"
     mock_config.per_device_batch_size = 1
+    mock_config.quantization_calibration_method = "absmax"
 
     mock_model = Mock(spec=WanModel)
     mock_pipeline = Mock()