 import numpy as np
 import jax
 import jax.numpy as jnp
-from jax.sharding import Mesh, PositionalSharding
+from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
 import flax
 import flax.linen as nn
 from flax import nnx
 from ...pyconfig import HyperParameters
 from ... import max_logging
 from ... import max_utils
+from ...max_utils import get_flash_block_sizes, get_precision
 from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae
 from ...models.wan.transformers.transformer_wan import WanModel
 from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache
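
Note: the new PartitionSpec import supports the NamedSharding path added in __call__ below. A minimal sketch of the two placements involved (the "data" axis name here is illustrative; the real mesh axes come from config):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
    x = jnp.ones((jax.device_count(), 16))
    replicated = jax.device_put(x, NamedSharding(mesh, P()))     # full copy on every device
    sharded = jax.device_put(x, NamedSharding(mesh, P("data")))  # batch dim split across "data"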
@@ -59,11 +60,12 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.VariableState:
 
 @partial(nnx.jit, static_argnums=(3,))
 def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
-  # breakpoint()
+
   def create_model(rngs: nnx.Rngs, wan_config: dict):
     wan_transformer = WanModel(**wan_config, rngs=rngs)
     return wan_transformer
 
+  # 1. Load the model config.
   wan_config = WanModel.load_config(
       config.pretrained_model_name_or_path,
       subfolder="transformer"
@@ -72,32 +74,39 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   wan_config["dtype"] = config.activations_dtype
   wan_config["weights_dtype"] = config.weights_dtype
   wan_config["attention"] = config.attention
+  wan_config["precision"] = get_precision(config)
+  wan_config["flash_block_sizes"] = get_flash_block_sizes(config)
+
+  # 2. eval_shape: initializes the model abstractly, spending no FLOPs and
+  # creating no weights on device, so no HBM memory is used.
   p_model_factory = partial(create_model, wan_config=wan_config)
   wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs)
   graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...)
-  # breakpoint()
+
+  # 3. Retrieve the state shardings, mapping logical names to mesh axis names.
   logical_state_spec = nnx.get_partition_spec(state)
   logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
   logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
   params = state.to_pure_dict()
   state = dict(nnx.to_flat_state(state))
-  # del state
+
+  # 4. Load the pretrained weights and move them to device using the state
+  # shardings from (3) above. This loads sharded weights directly onto the
+  # accelerators without first copying them all to one device, keeping HBM usage low.
   params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu")
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
   for path, val in flax.traverse_util.flatten_dict(params).items():
     sharding = logical_state_sharding[path].value
-    state[path].value = jax.device_put(val, sharding)
+    try:
+      state[path].value = jax.device_put(val, sharding)
+    except ValueError as e:
+      raise ValueError(f"Failed to place weight {path} with sharding {sharding}.") from e
   state = nnx.from_flat_state(state)
-  p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=config.logical_axis_rules)
-  state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState))
-  pspecs = nnx.get_partition_spec(state)
-  # breakpoint()
-  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
-  # breakpoint()
-  # wan_transformer = jax.jit(nnx.merge(graphdef, sharded_state, rest_of_state), in_shardings=None, out_shardings=sharded_state)
-  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
+
+  wan_transformer = nnx.merge(graphdef, state, rest_of_state)
   return wan_transformer
 
+
 @partial(nnx.jit, static_argnums=(1,))
 def create_sharded_logical_model(model, logical_axis_rules):
   graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...)
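
Note: create_sharded_logical_transformer above follows the standard NNX abstract-init pattern. A minimal sketch with a stand-in nnx.Linear (not the actual WanModel):

    from flax import nnx

    def make_model(rngs: nnx.Rngs):
      return nnx.Linear(4, 4, rngs=rngs)

    # eval_shape traces make_model without allocating weights; state leaves
    # are abstract shape/dtype descriptors, not device arrays.
    abstract_model = nnx.eval_shape(make_model, nnx.Rngs(0))
    graphdef, state = nnx.split(abstract_model)
    # After filling `state` with concrete (sharded) arrays, the usable model
    # is recovered with nnx.merge(graphdef, state).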
@@ -108,6 +117,7 @@ def create_sharded_logical_model(model, logical_axis_rules):
   wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
   return wan_transformer
 
+
 class WanPipeline:
   r"""
   Pipeline for text-to-video generation using Wan.
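
Note: a sketch of the logical-to-mesh mapping that nn.logical_to_mesh_sharding performs in the functions above. The rule list and axis names here are illustrative; the real ones come from config.logical_axis_rules:

    import flax.linen as nn

    rules = (("embed", None), ("mlp", "fsdp"))
    spec = nn.logical_to_mesh_axes(("embed", "mlp"), rules)
    # spec == PartitionSpec(None, "fsdp"): the "mlp" axis is split across the
    # "fsdp" mesh axis and the "embed" axis is replicated. Wrapping it as
    # NamedSharding(mesh, spec) gives the placement used for each parameter.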
@@ -155,6 +165,7 @@ def __init__(
 
     self.p_run_inference = None
 
+
   @classmethod
   def load_text_encoder(cls, config: HyperParameters):
     text_encoder = UMT5EncoderModel.from_pretrained(
@@ -163,6 +174,7 @@ def load_text_encoder(cls, config: HyperParameters):
     )
     return text_encoder
 
+
   @classmethod
   def load_tokenizer(cls, config: HyperParameters):
     tokenizer = AutoTokenizer.from_pretrained(
@@ -171,6 +183,7 @@ def load_tokenizer(cls, config: HyperParameters):
     )
     return tokenizer
 
+
   @classmethod
   def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
     wan_vae = AutoencoderKLWan.from_config(
@@ -196,33 +209,14 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
     wan_vae = p_create_sharded_logical_model(model=wan_vae)
     return wan_vae, vae_cache
 
+
   @classmethod
   def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
     with mesh:
       wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
-      # wan_transformer = WanModel.from_config(
-      #     config.pretrained_model_name_or_path,
-      #     subfolder="transformer",
-      #     rngs=rngs,
-      #     attention=config.attention,
-      #     mesh=mesh,
-      #     dtype=config.activations_dtype,
-      #     weights_dtype=config.weights_dtype
-      # )
-      # graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...)
-      # breakpoint()
-      # params = state.to_pure_dict()
-      # del state
-      # #params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu")
-      # params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
-      # #params = jax.device_put(params, PositionalSharding(devices_array).replicate())
-      # wan_transformer = nnx.merge(graphdef, params, rest_of_state)
-      # # Shard
-      # p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
-      # with mesh:
-      #   wan_transformer = p_create_sharded_logical_model(model=wan_transformer)
     return wan_transformer
 
+
   @classmethod
   def load_scheduler(cls, config):
     scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained(
@@ -232,6 +226,7 @@ def load_scheduler(cls, config):
     )
     return scheduler, scheduler_state
 
+
   @classmethod
   def from_pretrained(cls, config: HyperParameters, vae_only=False):
     devices_array = max_utils.create_device_mesh(config)
@@ -268,6 +263,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False):
         config=config
     )
 
+
   def _get_t5_prompt_embeds(
       self,
       prompt: Union[str, List[str]] = None,
@@ -302,6 +298,7 @@ def _get_t5_prompt_embeds(
 
     return prompt_embeds
 
+
   def encode_prompt(
       self,
       prompt: Union[str, List[str]],
@@ -333,6 +330,7 @@ def encode_prompt(
 
     return prompt_embeds, negative_prompt_embeds
 
+
   def prepare_latents(
       self,
       batch_size: int,
@@ -356,6 +354,7 @@ def prepare_latents(
 
     return latents
 
+
   def __call__(
       self,
      prompt: Union[str, List[str]] = None,
@@ -382,9 +381,9 @@ def __call__(
 
     # 2. Define call parameters
     if prompt is not None and isinstance(prompt, str):
-      batch_size = 1
-    elif prompt is not None and isinstance(prompt, list):
-      batch_size = len(prompt)
+      prompt = [prompt]
+
+    batch_size = len(prompt)
 
     prompt_embeds, negative_prompt_embeds = self.encode_prompt(
         prompt=prompt,
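
Note: with the rewrite above, a single string prompt is normalized to a one-element list before batch_size is computed, so both call styles take the same path:

    prompt = "a cat playing piano"   # or ["a cat", "a dog"]
    if prompt is not None and isinstance(prompt, str):
      prompt = [prompt]
    batch_size = len(prompt)         # 1 for a string, N for a list of N prompts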
@@ -406,12 +405,13 @@ def __call__(
         num_channels_latents=num_channel_latents
     )
 
-    prompt_embeds = jnp.concatenate([prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)
-    negative_prompt_embeds = jnp.concatenate([negative_prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)
-
-    latents = jax.device_put(latents, PositionalSharding(self.devices_array).replicate())
-    prompt_embeds = jax.device_put(prompt_embeds, PositionalSharding(self.devices_array).replicate())
-    negative_prompt_embeds = jax.device_put(negative_prompt_embeds, PositionalSharding(self.devices_array).replicate())
+    data_sharding = PositionalSharding(self.devices_array).replicate()
+    if len(prompt) % jax.device_count() == 0:
+      data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))
+
+    latents = jax.device_put(latents, data_sharding)
+    prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
+    negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
 
     scheduler_state = self.scheduler.set_timesteps(
         self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
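
Note: a sketch of the data-placement choice above: replicate inputs by default, but shard them along the mesh's data axes when the batch divides evenly across devices. The "data" axis name is illustrative; the real spec comes from config.data_sharding:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PositionalSharding, PartitionSpec as P

    devices_array = np.array(jax.devices())
    mesh = Mesh(devices_array, axis_names=("data",))
    batch = jnp.ones((jax.device_count(), 77, 4096))

    if batch.shape[0] % jax.device_count() == 0:
      sharding = NamedSharding(mesh, P("data"))                 # split batch across devices
    else:
      sharding = PositionalSharding(devices_array).replicate()  # full copy on every device
    batch = jax.device_put(batch, sharding)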