diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index ae638e059..68c0b79fb 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -37,6 +37,10 @@ from ..models.attention_flax import FlaxWanAttention from maxdiffusion.pyconfig import HyperParameters from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline +import qwix +from flax.linen import partitioning as nn_partitioning + +RealQtRule = qwix.QtRule IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" @@ -48,6 +52,17 @@ class WanTransformerTest(unittest.TestCase): def setUp(self): WanTransformerTest.dummy_data = {} + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + self.config = pyconfig.config + devices_array = create_device_mesh(self.config) + self.mesh = Mesh(devices_array, self.config.mesh_axes) + def test_rotary_pos_embed(self): batch_size = 1 @@ -65,7 +80,9 @@ def test_nnx_pixart_alpha_text_projection(self): key = jax.random.key(0) rngs = nnx.Rngs(key) dummy_caption = jnp.ones((1, 512, 4096)) - layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120) + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120) dummy_output = layer(dummy_caption) dummy_output.shape == (1, 512, 5120) @@ -74,7 +91,8 @@ def test_nnx_timestep_embedding(self): rngs = nnx.Rngs(key) dummy_sample = jnp.ones((1, 256)) - layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120) dummy_output = layer(dummy_sample) assert dummy_output.shape == (1, 5120) @@ -97,9 +115,10 @@ def test_wan_time_text_embedding(self): time_freq_dim = 256 time_proj_dim = 30720 text_embed_dim = 4096 - layer = WanTimeTextImageEmbedding( - rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim - ) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + layer = WanTimeTextImageEmbedding( + rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim + ) dummy_timestep = jnp.ones(batch_size) @@ -115,20 +134,8 @@ def test_wan_time_text_embedding(self): def test_wan_block(self): key = jax.random.key(0) rngs = nnx.Rngs(key) - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True, - ) - config = pyconfig.config - - devices_array = create_device_mesh(config) - flash_block_sizes = get_flash_block_sizes(config) - - mesh = Mesh(devices_array, config.mesh_axes) + flash_block_sizes = get_flash_block_sizes(self.config) dim = 5120 ffn_dim = 13824 @@ -158,33 +165,24 @@ def test_wan_block(self): dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim)) dummy_temb = jnp.ones((batch_size, 6, dim)) - - wan_block = WanTransformerBlock( - rngs=rngs, - dim=dim, - ffn_dim=ffn_dim, - num_heads=num_heads, - qk_norm=qk_norm, - cross_attn_norm=cross_attn_norm, - eps=eps, - attention="flash", - mesh=mesh, - flash_block_sizes=flash_block_sizes, - ) - with mesh: + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + wan_block = WanTransformerBlock( + rngs=rngs, + dim=dim, + ffn_dim=ffn_dim, + num_heads=num_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + attention="flash", + mesh=self.mesh, + flash_block_sizes=flash_block_sizes, + ) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb) assert dummy_output.shape == dummy_hidden_states.shape def test_wan_attention(self): - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True, - ) - config = pyconfig.config - batch_size = 1 channels = 16 frames = 21 @@ -197,28 +195,26 @@ def test_wan_attention(self): key = jax.random.key(0) rngs = nnx.Rngs(key) - devices_array = create_device_mesh(config) - - flash_block_sizes = get_flash_block_sizes(config) + flash_block_sizes = get_flash_block_sizes(self.config) - mesh = Mesh(devices_array, config.mesh_axes) batch_size = 1 query_dim = 5120 - attention = FlaxWanAttention( - rngs=rngs, - query_dim=query_dim, - heads=40, - dim_head=128, - attention_kernel="flash", - mesh=mesh, - flash_block_sizes=flash_block_sizes, - ) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + attention = FlaxWanAttention( + rngs=rngs, + query_dim=query_dim, + heads=40, + dim_head=128, + attention_kernel="flash", + mesh=self.mesh, + flash_block_sizes=flash_block_sizes, + ) dummy_hidden_states_shape = (batch_size, 75600, query_dim) dummy_hidden_states = jnp.ones(dummy_hidden_states_shape) dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape) - with mesh: + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): dummy_output = attention( hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb ) @@ -226,30 +222,22 @@ def test_wan_attention(self): # dot product try: - attention = FlaxWanAttention( - rngs=rngs, - query_dim=query_dim, - heads=40, - dim_head=128, - attention_kernel="dot_product", - split_head_dim=True, - mesh=mesh, - flash_block_sizes=flash_block_sizes, - ) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + attention = FlaxWanAttention( + rngs=rngs, + query_dim=query_dim, + heads=40, + dim_head=128, + attention_kernel="dot_product", + split_head_dim=True, + mesh=self.mesh, + flash_block_sizes=flash_block_sizes, + ) except NotImplementedError: pass @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_wan_model(self): - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True, - ) - config = pyconfig.config - batch_size = 1 channels = 16 frames = 1 @@ -260,18 +248,16 @@ def test_wan_model(self): key = jax.random.key(0) rngs = nnx.Rngs(key) - devices_array = create_device_mesh(config) - - flash_block_sizes = get_flash_block_sizes(config) - mesh = Mesh(devices_array, config.mesh_axes) + flash_block_sizes = get_flash_block_sizes(self.config) batch_size = 1 num_layers = 1 - wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + wan_model = WanModel(rngs=rngs, attention="flash", mesh=self.mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers) dummy_timestep = jnp.ones((batch_size)) dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096)) - with mesh: + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): dummy_output = wan_model( hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states ) @@ -282,6 +268,10 @@ def test_get_qt_provider(self, mock_qt_rule): """ Tests the provider logic for all config branches. """ + def create_real_rule_instance(*args, **kwargs): + return RealQtRule(*args, **kwargs) + mock_qt_rule.side_effect = create_real_rule_instance + # Case 1: Quantization disabled config_disabled = Mock(spec=HyperParameters) config_disabled.use_qwix_quantization = False @@ -301,7 +291,7 @@ def test_get_qt_provider(self, mock_qt_rule): config_fp8 = Mock(spec=HyperParameters) config_fp8.use_qwix_quantization = True config_fp8.quantization = "fp8" - config_int8.qwix_module_path = ".*" + config_fp8.qwix_module_path = ".*" provider_fp8 = WanPipeline.get_qt_provider(config_fp8) self.assertIsNotNone(provider_fp8) mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, op_names=("dot_general","einsum", "conv_general_dilated")) @@ -312,7 +302,7 @@ def test_get_qt_provider(self, mock_qt_rule): config_fp8_full.use_qwix_quantization = True config_fp8_full.quantization = "fp8_full" config_fp8_full.quantization_calibration_method = "absmax" - config_int8.qwix_module_path = ".*" + config_fp8_full.qwix_module_path = ".*" provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full) self.assertIsNotNone(provider_fp8_full) expected_calls = [ @@ -361,6 +351,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize mock_config.quantization = "fp8_full" mock_config.qwix_module_path = ".*" mock_config.per_device_batch_size = 1 + mock_config.quantization_calibration_method = "absmax" mock_model = Mock(spec=WanModel) mock_pipeline = Mock() diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 7b131e7fb..442fad4a7 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -46,6 +46,7 @@ from ..models.wan.wan_utils import load_wan_vae from ..utils import load_video from ..video_processor import VideoProcessor +from flax.linen import partitioning as nn_partitioning THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -160,6 +161,16 @@ class WanVaeTest(unittest.TestCase): def setUp(self): WanVaeTest.dummy_data = {} + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + self.config = pyconfig.config + devices_array = create_device_mesh(self.config) + self.mesh = Mesh(devices_array, self.config.mesh_axes) def test_wanrms_norm(self): """Test against the Pytorch implementation""" @@ -209,7 +220,8 @@ def test_zero_padded_conv(self): output_torch = resample(input) assert output_torch.shape == (1, 96, 240, 360) - model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) dummy_input = jnp.ones(input_shape) dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1)) output = model(dummy_input) @@ -247,7 +259,8 @@ def test_wan_resample(self): torch_output = torch_wan_resample(dummy_input) assert torch_output.shape == (batch, dim, t, h // 2, w // 2) - wan_resample = WanResample(dim, mode=mode, rngs=rngs) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + wan_resample = WanResample(dim, mode=mode, rngs=rngs) # channels is always last here input_shape = (batch, t, h, w, dim) dummy_input = jnp.ones(input_shape) @@ -257,16 +270,6 @@ def test_wan_resample(self): def test_3d_conv(self): key = jax.random.key(0) rngs = nnx.Rngs(key) - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True, - ) - config = pyconfig.config - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) batch_size = 1 in_depth, in_height, in_width = 10, 32, 32 @@ -283,14 +286,15 @@ def test_3d_conv(self): dummy_cache = jnp.zeros((batch_size, cache_depth, in_height, in_width, in_channels)) # Instantiate the module - causal_conv_layer = WanCausalConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=(kernel_d, kernel_h, kernel_w), - padding=(padding_d, padding_h, padding_w), - rngs=rngs, # Pass rngs for initialization, - mesh=mesh, - ) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + causal_conv_layer = WanCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(kernel_d, kernel_h, kernel_w), + padding=(padding_d, padding_h, padding_w), + rngs=rngs, # Pass rngs for initialization, + mesh=self.mesh, + ) # --- Test Case 1: No Cache --- output_no_cache = causal_conv_layer(dummy_input) @@ -309,16 +313,6 @@ def test_3d_conv(self): def test_wan_residual(self): key = jax.random.key(0) rngs = nnx.Rngs(key) - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True, - ) - config = pyconfig.config - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) # --- Test Case 1: same in/out dim --- in_dim = out_dim = 96 batch = 1 @@ -329,7 +323,8 @@ def test_wan_residual(self): input_shape = (batch, t, height, width, dim) expected_output_shape = (batch, t, height, width, dim) - wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=self.mesh) dummy_input = jnp.ones(input_shape) dummy_output = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape @@ -339,7 +334,8 @@ def test_wan_residual(self): out_dim = 196 expected_output_shape = (batch, t, height, width, out_dim) - wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=self.mesh) dummy_input = jnp.ones(input_shape) dummy_output = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape @@ -361,23 +357,14 @@ def test_wan_attention(self): def test_wan_midblock(self): key = jax.random.key(0) rngs = nnx.Rngs(key) - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True, - ) - config = pyconfig.config - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) batch = 1 t = 1 dim = 384 height = 60 width = 90 input_shape = (batch, t, height, width, dim) - wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=self.mesh) dummy_input = jnp.ones(input_shape) output = wan_midblock(dummy_input) assert output.shape == input_shape @@ -385,32 +372,23 @@ def test_wan_midblock(self): def test_wan_decode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True, - ) - config = pyconfig.config - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) dim = 96 z_dim = 16 dim_mult = [1, 2, 4, 4] num_res_blocks = 2 attn_scales = [] temperal_downsample = [False, True, True] - wan_vae = AutoencoderKLWan( - rngs=rngs, - base_dim=dim, - z_dim=z_dim, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_downsample=temperal_downsample, - mesh=mesh, - ) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + wan_vae = AutoencoderKLWan( + rngs=rngs, + base_dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + mesh=self.mesh, + ) vae_cache = AutoencoderKLWanCache(wan_vae) batch = 1 t = 13 @@ -429,32 +407,23 @@ def test_wan_decode(self): def test_wan_encode(self): key = jax.random.key(0) rngs = nnx.Rngs(key) - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True, - ) - config = pyconfig.config - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) dim = 96 z_dim = 16 dim_mult = [1, 2, 4, 4] num_res_blocks = 2 attn_scales = [] temperal_downsample = [False, True, True] - wan_vae = AutoencoderKLWan( - rngs=rngs, - base_dim=dim, - z_dim=z_dim, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_downsample=temperal_downsample, - mesh=mesh, - ) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + wan_vae = AutoencoderKLWan( + rngs=rngs, + base_dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + mesh=self.mesh, + ) vae_cache = AutoencoderKLWanCache(wan_vae) batch = 1 channels = 3 @@ -474,18 +443,9 @@ def vae_encode(video, wan_vae, vae_cache, key): key = jax.random.key(0) rngs = nnx.Rngs(key) - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True, - ) - config = pyconfig.config - devices_array = create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + wan_vae = AutoencoderKLWan.from_config(self.config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=self.mesh) vae_cache = AutoencoderKLWanCache(wan_vae) video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" video = load_video(video_path) @@ -499,7 +459,7 @@ def vae_encode(video, wan_vae, vae_cache, key): graphdef, state = nnx.split(wan_vae) params = state.to_pure_dict() # This replaces random params with the model. - params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") + params = load_wan_vae(self.config.pretrained_model_name_or_path, params, "cpu") params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) wan_vae = nnx.merge(graphdef, params)