3838from maxdiffusion .pyconfig import HyperParameters
3939from maxdiffusion .pipelines .wan .wan_pipeline import WanPipeline
4040import qwix
41+ import numpy as np
4142
4243RealQtRule = qwix .QtRule
4344
@@ -68,16 +69,9 @@ def test_nnx_pixart_alpha_text_projection(self):
6869 key = jax .random .key (0 )
6970 rngs = nnx .Rngs (key )
7071 dummy_caption = jnp .ones ((1 , 512 , 4096 ))
71- pyconfig .initialize (
72- [
73- None ,
74- os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
75- ],
76- unittest = True ,
77- )
78- config = pyconfig .config
79- devices_array = create_device_mesh (config )
80- mesh = Mesh (devices_array , config .mesh_axes )
72+ num_devices = len (jax .devices ())
73+ device_mesh = np .array (jax .devices ()).reshape ((1 , num_devices ))
74+ mesh = Mesh (device_mesh , axis_names = ('embed' , 'mlp' ))
8175
8276 with mesh :
8377 layer = NNXPixArtAlphaTextProjection (rngs = rngs , in_features = 4096 , hidden_size = 5120 )
@@ -87,16 +81,9 @@ def test_nnx_pixart_alpha_text_projection(self):
8781 def test_nnx_timestep_embedding (self ):
8882 key = jax .random .key (0 )
8983 rngs = nnx .Rngs (key )
90- pyconfig .initialize (
91- [
92- None ,
93- os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
94- ],
95- unittest = True ,
96- )
97- config = pyconfig .config
98- devices_array = create_device_mesh (config )
99- mesh = Mesh (devices_array , config .mesh_axes )
84+ num_devices = len (jax .devices ())
85+ device_mesh = np .array (jax .devices ()).reshape ((1 , num_devices ))
86+ mesh = Mesh (device_mesh , axis_names = ('embed' , 'mlp' ))
10087
10188 dummy_sample = jnp .ones ((1 , 256 ))
10289 with mesh :
0 commit comments