2121import jax
2222import jax .numpy as jnp
2323from flax import nnx
24- from flax import linen as nn
2524from flax .linen import partitioning as nn_partitioning
25+ from flax .linen import logical_to_mesh_sharding
2626from jax .sharding import Mesh
2727from .. import pyconfig
2828from ..max_utils import (
@@ -224,7 +224,7 @@ def test_zero_padded_conv(self):
224224 output_torch = resample (input )
225225 assert output_torch .shape == (1 , 96 , 240 , 360 )
226226
227- with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
227+ with self .mesh , nn_partitioning .axis_rules (self .config .vae_logical_axis_rules ):
228228 model = ZeroPaddedConv2D (dim = dim , rngs = rngs , kernel_size = (1 , 3 , 3 ), stride = (1 , 2 , 2 ))
229229 dummy_input = jnp .ones (input_shape )
230230 dummy_input = jnp .transpose (dummy_input , (0 , 2 , 3 , 1 ))
@@ -262,7 +262,7 @@ def test_wan_resample(self):
262262 torch_wan_resample = TorchWanResample (dim = dim , mode = mode )
263263 torch_output = torch_wan_resample (dummy_input )
264264 assert torch_output .shape == (batch , dim , t , h // 2 , w // 2 )
265- with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
265+ with self .mesh , nn_partitioning .axis_rules (self .config .vae_logical_axis_rules ):
266266 wan_resample = WanResample (dim , mode = mode , rngs = rngs )
267267 # channels is always last here
268268 input_shape = (batch , t , h , w , dim )
@@ -305,7 +305,7 @@ def test_3d_conv(self):
305305 dummy_cache = jnp .zeros ((batch_size , cache_depth , in_height , in_width , in_channels ))
306306
307307 # Instantiate the module
308- with self .mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
308+ with self .mesh , nn_partitioning .axis_rules (config .vae_logical_axis_rules ):
309309 causal_conv_layer = WanCausalConv3d (
310310 in_channels = in_channels ,
311311 out_channels = out_channels ,
@@ -357,7 +357,7 @@ def test_wan_residual(self):
357357 dim = 96
358358 input_shape = (batch , t , height , width , dim )
359359 expected_output_shape = (batch , t , height , width , dim )
360- with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
360+ with mesh , nn_partitioning .axis_rules (config .vae_logical_axis_rules ):
361361 wan_residual_block = WanResidualBlock (in_dim = in_dim , out_dim = out_dim , rngs = rngs , mesh = mesh )
362362 dummy_input = jnp .ones (input_shape )
363363 dummy_output , _ , _ = wan_residual_block (dummy_input )
@@ -381,7 +381,7 @@ def test_wan_attention(self):
381381 height = 60
382382 width = 90
383383 input_shape = (batch , t , height , width , dim )
384- with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
384+ with self .mesh , nn_partitioning .axis_rules (self .config .vae_logical_axis_rules ):
385385 wan_attention = WanAttentionBlock (dim = dim , rngs = rngs )
386386 dummy_input = jnp .ones (input_shape )
387387 output , _ , _ = wan_attention (dummy_input )
@@ -412,7 +412,7 @@ def test_wan_midblock(self):
412412 height = 60
413413 width = 90
414414 input_shape = (batch , t , height , width , dim )
415- with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
415+ with mesh , nn_partitioning .axis_rules (config .vae_logical_axis_rules ):
416416 wan_midblock = WanMidBlock (dim = dim , rngs = rngs , mesh = mesh )
417417 dummy_input = jnp .ones (input_shape )
418418 output , _ , _ = wan_midblock (dummy_input )
@@ -443,7 +443,7 @@ def test_wan_decode(self):
443443 num_res_blocks = 2
444444 attn_scales = []
445445 temperal_downsample = [False , True , True ]
446- with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
446+ with mesh , nn_partitioning .axis_rules (config .vae_logical_axis_rules ):
447447 wan_vae = AutoencoderKLWan (
448448 rngs = rngs ,
449449 base_dim = dim ,
@@ -494,7 +494,7 @@ def test_wan_encode(self):
494494 num_res_blocks = 2
495495 attn_scales = []
496496 temperal_downsample = [False , True , True ]
497- with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
497+ with mesh , nn_partitioning .axis_rules (config .vae_logical_axis_rules ):
498498 wan_vae = AutoencoderKLWan (
499499 rngs = rngs ,
500500 base_dim = dim ,
@@ -540,7 +540,7 @@ def vae_encode(video, wan_vae, vae_cache, key):
540540 # Reshape devices to include vae_spatial (size 1 for test)
541541 devices_array = devices_array .reshape (devices_array .shape + (1 ,))
542542 mesh = Mesh (devices_array , mesh_axes )
543- with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
543+ with self .mesh , nn_partitioning .axis_rules (self .config .vae_logical_axis_rules ):
544544 wan_vae = AutoencoderKLWan .from_config (config .pretrained_model_name_or_path , subfolder = "vae" , rngs = rngs , mesh = mesh )
545545 vae_cache = AutoencoderKLWanCache (wan_vae )
546546 video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
@@ -559,7 +559,7 @@ def vae_encode(video, wan_vae, vae_cache, key):
559559 params = jax .tree_util .tree_map (lambda x : x .astype (jnp .bfloat16 ), params )
560560
561561 logical_state_spec = nnx .get_partition_spec (state )
562- logical_state_sharding = nn . logical_to_mesh_sharding (logical_state_spec , mesh , config .logical_axis_rules )
562+ logical_state_sharding = logical_to_mesh_sharding (logical_state_spec , mesh , config .vae_logical_axis_rules )
563563 logical_state_sharding = dict (nnx .to_flat_state (logical_state_sharding ))
564564
565565 state_flat = dict (nnx .to_flat_state (state ))
0 commit comments