3838from maxdiffusion .pyconfig import HyperParameters
3939from maxdiffusion .pipelines .wan .wan_pipeline import WanPipeline
4040import qwix
41- import numpy as np
4241
4342RealQtRule = qwix .QtRule
4443
@@ -69,9 +68,17 @@ def test_nnx_pixart_alpha_text_projection(self):
6968 key = jax .random .key (0 )
7069 rngs = nnx .Rngs (key )
7170 dummy_caption = jnp .ones ((1 , 512 , 4096 ))
72- num_devices = len (jax .devices ())
73- device_mesh = np .array (jax .devices ()).reshape ((1 , num_devices ))
74- mesh = Mesh (device_mesh , axis_names = ('embed' , 'mlp' ))
71+ pyconfig .initialize (
72+ [
73+ None ,
74+ os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
75+ ],
76+ unittest = True ,
77+ )
78+ config = pyconfig .config
79+ devices_array = create_device_mesh (config )
80+ mesh = Mesh (devices_array , config .mesh_axes )
81+
7582 with mesh :
7683 layer = NNXPixArtAlphaTextProjection (rngs = rngs , in_features = 4096 , hidden_size = 5120 )
7784 dummy_output = layer (dummy_caption )
@@ -80,9 +87,20 @@ def test_nnx_pixart_alpha_text_projection(self):
8087 def test_nnx_timestep_embedding (self ):
8188 key = jax .random .key (0 )
8289 rngs = nnx .Rngs (key )
90+ pyconfig .initialize (
91+ [
92+ None ,
93+ os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
94+ ],
95+ unittest = True ,
96+ )
97+ config = pyconfig .config
98+ devices_array = create_device_mesh (config )
99+ mesh = Mesh (devices_array , config .mesh_axes )
83100
84101 dummy_sample = jnp .ones ((1 , 256 ))
85- layer = NNXTimestepEmbedding (rngs = rngs , in_channels = 256 , time_embed_dim = 5120 )
102+ with mesh :
103+ layer = NNXTimestepEmbedding (rngs = rngs , in_channels = 256 , time_embed_dim = 5120 )
86104 dummy_output = layer (dummy_sample )
87105 assert dummy_output .shape == (1 , 5120 )
88106
0 commit comments