@@ -1207,14 +1207,16 @@ def __call__(
12071207 audio_latents_jax = jnp .concatenate ([audio_latents_jax ] * 2 , axis = 0 )
12081208
12091209 if hasattr (self , "mesh" ) and self .mesh is not None :
1210- data_sharding = NamedSharding (self .mesh , P ())
1210+ data_sharding_3d = NamedSharding (self .mesh , P ())
1211+ data_sharding_2d = NamedSharding (self .mesh , P ())
12111212 if hasattr (self , "config" ) and hasattr (self .config , "data_sharding" ):
1212- data_sharding = NamedSharding (self .mesh , P (* self .config .data_sharding ))
1213+ data_sharding_3d = NamedSharding (self .mesh , P (* self .config .data_sharding [:3 ]))
1214+ data_sharding_2d = NamedSharding (self .mesh , P (* self .config .data_sharding [:2 ]))
12131215 if isinstance (prompt_embeds_jax , list ):
1214- prompt_embeds_jax = [jax .device_put (x , data_sharding ) for x in prompt_embeds_jax ]
1216+ prompt_embeds_jax = [jax .device_put (x , data_sharding_3d ) for x in prompt_embeds_jax ]
12151217 else :
1216- prompt_embeds_jax = jax .device_put (prompt_embeds_jax , data_sharding )
1217- prompt_attention_mask_jax = jax .device_put (prompt_attention_mask_jax , data_sharding )
1218+ prompt_embeds_jax = jax .device_put (prompt_embeds_jax , data_sharding_3d )
1219+ prompt_attention_mask_jax = jax .device_put (prompt_attention_mask_jax , data_sharding_2d )
12181220
12191221 # GraphDef and State
12201222 graphdef , state = nnx .split (self .transformer )
0 commit comments