@@ -530,6 +530,7 @@ def _load_and_init(cls, config: HyperParameters, restored_checkpoint, vae_only=F
530530 vocoder = components ["vocoder" ]
531531 )
532532 pipeline .mesh = components ["mesh" ]
533+ pipeline .config = config
533534 if load_transformer :
534535 pipeline .transformer = cls .quantize_transformer (config , pipeline .transformer , pipeline , pipeline .mesh )
535536 return pipeline , pipeline .transformer
@@ -1204,6 +1205,16 @@ def __call__(
12041205 prompt_attention_mask_jax = jnp .concatenate ([negative_prompt_attention_mask_jax , prompt_attention_mask_jax ], axis = 0 )
12051206 latents_jax = jnp .concatenate ([latents_jax ] * 2 , axis = 0 )
12061207 audio_latents_jax = jnp .concatenate ([audio_latents_jax ] * 2 , axis = 0 )
1208+
1209+ if hasattr (self , "mesh" ) and self .mesh is not None :
1210+ data_sharding = NamedSharding (self .mesh , P ())
1211+ if hasattr (self , "config" ) and hasattr (self .config , "data_sharding" ):
1212+ data_sharding = NamedSharding (self .mesh , P (* self .config .data_sharding ))
1213+ if isinstance (prompt_embeds_jax , list ):
1214+ prompt_embeds_jax = [jax .device_put (x , data_sharding ) for x in prompt_embeds_jax ]
1215+ else :
1216+ prompt_embeds_jax = jax .device_put (prompt_embeds_jax , data_sharding )
1217+ prompt_attention_mask_jax = jax .device_put (prompt_attention_mask_jax , data_sharding )
12071218
12081219 # GraphDef and State
12091220 graphdef , state = nnx .split (self .transformer )
0 commit comments