 from ...models.wan.transformers.transformer_wan import WanModel
 from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache
 from maxdiffusion.video_processor import VideoProcessor
-from ...utils import export_to_video
 from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
 from transformers import AutoTokenizer, UMT5EncoderModel
 import ftfy
@@ -314,75 +313,77 @@ def __call__(
       max_sequence_length: int = 512,
       latents: jax.Array = None,
       prompt_embeds: jax.Array = None,
-      negative_prompt_embeds: jax.Array = None
+      negative_prompt_embeds: jax.Array = None,
+      vae_only: bool = False
   ):
-    if num_frames % self.vae_scale_factor_temporal != 1:
-      max_logging.log(
-        f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+    if not vae_only:
+      if num_frames % self.vae_scale_factor_temporal != 1:
+        max_logging.log(
+          f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+        )
+        num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+      num_frames = max(num_frames, 1)
+
+      # 2. Define call parameters
+      if prompt is not None and isinstance(prompt, str):
+        batch_size = 1
+      elif prompt is not None and isinstance(prompt, list):
+        batch_size = len(prompt)
+
+      prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        max_sequence_length=max_sequence_length,
+        prompt_embeds=prompt_embeds,
+        negative_prompt_embeds=negative_prompt_embeds
       )
-      num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
-    num_frames = max(num_frames, 1)
-
-    # 2. Define call parameters
-    if prompt is not None and isinstance(prompt, str):
-      batch_size = 1
-    elif prompt is not None and isinstance(prompt, list):
-      batch_size = len(prompt)
-
-    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-      prompt=prompt,
-      negative_prompt=negative_prompt,
-      max_sequence_length=max_sequence_length,
-      prompt_embeds=prompt_embeds,
-      negative_prompt_embeds=negative_prompt_embeds
-    )
 
-    num_channel_latents = self.transformer.config.in_channels
-    if latents is None:
-      latents = self.prepare_latents(
-        batch_size=batch_size,
-        vae_scale_factor_temporal=self.vae_scale_factor_temporal,
-        vae_scale_factor_spatial=self.vae_scale_factor_spatial,
-        height=height,
-        width=width,
-        num_frames=num_frames,
-        num_channels_latents=num_channel_latents
+      num_channel_latents = self.transformer.config.in_channels
+      if latents is None:
+        latents = self.prepare_latents(
+          batch_size=batch_size,
+          vae_scale_factor_temporal=self.vae_scale_factor_temporal,
+          vae_scale_factor_spatial=self.vae_scale_factor_spatial,
+          height=height,
+          width=width,
+          num_frames=num_frames,
+          num_channels_latents=num_channel_latents
+        )
+
+      prompt_embeds = jnp.concatenate([prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)
+      negative_prompt_embeds = jnp.concatenate([negative_prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)
+
+      latents = jax.device_put(latents, PositionalSharding(self.devices_array).replicate())
+      prompt_embeds = jax.device_put(prompt_embeds, PositionalSharding(self.devices_array).replicate())
+      negative_prompt_embeds = jax.device_put(negative_prompt_embeds, PositionalSharding(self.devices_array).replicate())
+
+      scheduler_state = self.scheduler.set_timesteps(
+        self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
       )
 
-    prompt_embeds = jnp.concatenate([prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)
-    negative_prompt_embeds = jnp.concatenate([negative_prompt_embeds] * latents.shape[0], dtype=self.config.weights_dtype)
-
-    latents = jax.device_put(latents, PositionalSharding(self.devices_array).replicate())
-    prompt_embeds = jax.device_put(prompt_embeds, PositionalSharding(self.devices_array).replicate())
-    negative_prompt_embeds = jax.device_put(negative_prompt_embeds, PositionalSharding(self.devices_array).replicate())
-
-    scheduler_state = self.scheduler.set_timesteps(
-      self.scheduler_state, num_inference_steps=self.config.num_inference_steps, shape=latents.shape
-    )
-
-    graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)
+      graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)
 
-    p_run_inference = partial(
-      run_inference,
-      guidance_scale=self.config.guidance_scale,
-      num_inference_steps=self.config.num_inference_steps,
-      scheduler=self.scheduler,
-      scheduler_state=scheduler_state
-    )
-    with self.mesh:
-      latents = p_run_inference(
-        graphdef=graphdef,
-        sharded_state=state,
-        rest_of_state=rest_of_state,
-        latents=latents,
-        prompt_embeds=prompt_embeds,
-        negative_prompt_embeds=negative_prompt_embeds
+      p_run_inference = partial(
+        run_inference,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        scheduler=self.scheduler,
+        scheduler_state=scheduler_state
       )
-    latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim)
-    latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim)
-    latents = latents / latents_std + latents_mean
-
-    latents = latents.astype(self.config.weights_dtype)
+      with self.mesh:
+        latents = p_run_inference(
+          graphdef=graphdef,
+          sharded_state=state,
+          rest_of_state=rest_of_state,
+          latents=latents,
+          prompt_embeds=prompt_embeds,
+          negative_prompt_embeds=negative_prompt_embeds
+        )
+      latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim)
+      latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim)
+      latents = latents / latents_std + latents_mean
+
+      latents = latents.astype(self.config.weights_dtype)
 
     jitted_decode = jax.jit(
       partial(
@@ -396,9 +397,18 @@ def __call__(
     video = jnp.transpose(video, (0, 4, 1, 2, 3))
     video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16)
     video = self.video_processor.postprocess_video(video, output_type="np")
-    export_to_video(video[0], "jax_output.mp4", fps=24)
+    return video
+
+
+@jax.jit
+def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds):
+  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
+  return wan_transformer(
+    hidden_states=latents,
+    timestep=timestep,
+    encoder_hidden_states=prompt_embeds
+  )[0]
 
-
 #@partial(jax.jit, static_argnums=(6, 7, 8))
 def run_inference(
     graphdef,
@@ -411,26 +421,16 @@ def run_inference(
     num_inference_steps: int,
     scheduler: FlaxUniPCMultistepScheduler,
     scheduler_state):
-  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
   do_classifier_free_guidance = guidance_scale > 1.0
   for step in range(num_inference_steps):
    t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
    timestep = jnp.broadcast_to(t, latents.shape[0])
-
-    noise_pred = wan_transformer(
-      hidden_states=latents,
-      timestep=timestep,
-      encoder_hidden_states=prompt_embeds,
-      return_dict=False
-    )[0]
+
+    noise_pred = transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds)
 
     if do_classifier_free_guidance:
-      noise_uncond = wan_transformer(
-        hidden_states=latents,
-        timestep=timestep,
-        encoder_hidden_states=negative_prompt_embeds,
-        return_dict=False
-      )[0]
+      noise_uncond = transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, negative_prompt_embeds)
       noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
     latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
   return latents
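
The `vae_only` flag splits `__call__` into two entry points: the full text-to-video path (prompt encoding, denoising loop, VAE decode) and a decode-only path that feeds caller-supplied latents straight to the VAE. A minimal usage sketch, assuming an already-constructed `pipeline` instance of this class and a hypothetical `precomputed_latents` jax.Array that is already de-normalized with the VAE's latents_mean/latents_std:

# Full run: encode the prompt, denoise for num_inference_steps, decode.
video = pipeline(
    prompt="a panda surfing a wave at sunset",  # example prompt
    num_frames=81,
    height=480,
    width=832,
    num_inference_steps=50,
    guidance_scale=5.0,
)

# Decode-only run: skip prompt encoding and denoising entirely.
video = pipeline(latents=precomputed_latents, vae_only=True)

Either way, the call now returns the post-processed frames instead of writing jax_output.mp4 (the export_to_video import was removed), so saving the output is left to the caller.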
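
For reference, the split/merge pattern behind `transformer_forward_pass` in isolation: `nnx.split` turns a module into a static graphdef plus state pytrees, and a plain `jax.jit` function can `nnx.merge` them back and call the module, so the forward pass is traced once and reused for both the conditional and unconditional passes. A self-contained sketch with a tiny stand-in module (`TinyModel` is illustrative, not part of this codebase):

import jax
import jax.numpy as jnp
from flax import nnx

class TinyModel(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 4, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)

model = TinyModel(nnx.Rngs(0))
graphdef, state = nnx.split(model)

@jax.jit
def forward(graphdef, state, x):
  # Rebuild the module inside the traced function, as
  # transformer_forward_pass does, then call it.
  merged = nnx.merge(graphdef, state)
  return merged(x)

y_cond = forward(graphdef, state, jnp.ones((1, 4)))     # first call compiles
y_uncond = forward(graphdef, state, jnp.zeros((1, 4)))  # reuses the cached trace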