@@ -570,21 +570,22 @@ def convert_weight(pt_key_base, jax_key):
570570
571571 # 4. Run Forward
572572 print ("Running MaxDiffusion forward pass..." )
573- output = model (
574- hidden_states = jax_inputs ["hidden_states" ],
575- audio_hidden_states = jax_inputs ["audio_hidden_states" ],
576- encoder_hidden_states = jax_inputs ["encoder_hidden_states" ],
577- audio_encoder_hidden_states = jax_inputs ["audio_encoder_hidden_states" ],
578- timestep = jax_inputs ["timestep" ],
579- encoder_attention_mask = jax_inputs ["encoder_attention_mask" ],
580- audio_encoder_attention_mask = jax_inputs ["audio_encoder_attention_mask" ],
581- num_frames = config ["num_frames" ] if "num_frames" in config else 4 ,
582- height = config ["height" ] if "height" in config else 32 ,
583- width = config ["width" ] if "width" in config else 32 ,
584- audio_num_frames = 128 ,
585- fps = 24.0 ,
586- return_dict = True ,
587- )
573+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
574+ output = model (
575+ hidden_states = jax_inputs ["hidden_states" ],
576+ audio_hidden_states = jax_inputs ["audio_hidden_states" ],
577+ encoder_hidden_states = jax_inputs ["encoder_hidden_states" ],
578+ audio_encoder_hidden_states = jax_inputs ["audio_encoder_hidden_states" ],
579+ timestep = jax_inputs ["timestep" ],
580+ encoder_attention_mask = jax_inputs ["encoder_attention_mask" ],
581+ audio_encoder_attention_mask = jax_inputs ["audio_encoder_attention_mask" ],
582+ num_frames = config ["num_frames" ] if "num_frames" in config else 4 ,
583+ height = config ["height" ] if "height" in config else 32 ,
584+ width = config ["width" ] if "width" in config else 32 ,
585+ audio_num_frames = 128 ,
586+ fps = 24.0 ,
587+ return_dict = True ,
588+ )
588589
589590 max_sample = output ["sample" ]
590591 max_audio_sample = output ["audio_sample" ]
0 commit comments