1919import functools
2020import numpy as np
2121import jax
22- from jax .sharding import Mesh , PositionalSharding
22+ from jax .sharding import Mesh , PositionalSharding , PartitionSpec as P
2323import jax .numpy as jnp
2424from chex import Array
2525from transformers import (
@@ -196,7 +196,7 @@ def run(config):
196196 devices_array = create_device_mesh (config )
197197 mesh = Mesh (devices_array , config .mesh_axes )
198198
199- per_host_number_of_images = config .per_device_batch_size * jax .local_device_count ()
199+ global_batch_size = config .per_device_batch_size * jax .local_device_count ()
200200
201201 # LOAD VAE
202202
@@ -225,7 +225,7 @@ def run(config):
225225
226226 num_channels_latents = transformer .in_channels // 4
227227 latents , latent_image_ids = prepare_latents (
228- batch_size = per_host_number_of_images ,
228+ batch_size = global_batch_size ,
229229 num_channels_latents = num_channels_latents ,
230230 height = config .resolution ,
231231 width = config .resolution ,
@@ -270,7 +270,7 @@ def run(config):
270270 clip_text_encoder = clip_text_encoder ,
271271 t5_tokenizer = t5_tokenizer ,
272272 t5_text_encoder = t5_encoder ,
273- num_images_per_prompt = per_host_number_of_images
273+ num_images_per_prompt = global_batch_size
274274 )
275275
276276 def validate_inputs (latents , latent_image_ids , prompt_embeds , text_ids , timesteps , guidance , pooled_prompt_embeds ):
@@ -282,8 +282,8 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
282282 print ("guidance.shape: " , guidance .shape , guidance .dtype )
283283 print ("pooled_prompt_embeds.shape: " , pooled_prompt_embeds .shape , pooled_prompt_embeds .dtype )
284284
285- timesteps = jnp .asarray ([1.0 ], dtype = jnp .bfloat16 )
286- guidance = jnp .asarray ([3.5 ], dtype = jnp .bfloat16 )
285+ timesteps = jnp .asarray ([1.0 ] * global_batch_size , dtype = jnp .bfloat16 )
286+ guidance = jnp .asarray ([3.5 ] * global_batch_size , dtype = jnp .bfloat16 )
287287 validate_inputs (
288288 latents ,
289289 latent_image_ids ,
@@ -293,13 +293,26 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
293293 guidance ,
294294 pooled_prompt_embeds
295295 )
296+
297+ # TODO - remove this later and figure out why t5x is returning wrong shape
298+ prompt_embeds = jnp .ones ((global_batch_size , 512 , 4096 ))
299+
300+ # move inputs to device and shard
301+ data_sharding = jax .sharding .NamedSharding (mesh , P (* config .data_sharding ))
302+ latents = jax .device_put (latents , data_sharding )
303+ latent_image_ids = jax .device_put (latent_image_ids , data_sharding )
304+ prompt_embeds = jax .device_put (prompt_embeds , data_sharding )
305+ text_ids = jax .device_put (text_ids , data_sharding )
306+ timesteps = jax .device_put (timesteps , data_sharding )
307+ guidance = jax .device_put (guidance , data_sharding )
308+ pooled_prompt_embeds = jax .device_put (pooled_prompt_embeds , data_sharding )
309+
296310 get_memory_allocations ()
297311 # evaluate shapes
298312 transformer_eval_params = transformer .init_weights (rngs = rng , max_sequence_length = 512 , eval_only = True )
299313
300314 # loads pretrained weights
301315 transformer_params = load_flow_model ("flux-dev" , transformer_eval_params , "cpu" )
302- get_memory_allocations ()
303316 # create transformer state
304317 weights_init_fn = functools .partial (transformer .init_weights , rngs = rng , max_sequence_length = 512 , eval_only = False )
305318 transformer_state , transformer_state_shardings = setup_initial_state (
@@ -308,24 +321,35 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
308321 config = config ,
309322 mesh = mesh ,
310323 weights_init_fn = weights_init_fn ,
311- model_params = None ,
324+ model_params = transformer_params ,
312325 training = False
313326 )
314- breakpoint ()
315- transformer_state = transformer_state .replace (params = transformer_params )
316- img = transformer .apply (
317- {"params" : transformer_state .params },
318- img = latents ,
319- img_ids = latent_image_ids ,
320- txt = prompt_embeds ,
321- txt_ids = text_ids ,
322- timesteps = timesteps ,
323- guidance = guidance ,
324- y = pooled_prompt_embeds
325- )
327+ #transformer_state = transformer_state.replace(params=transformer_params)
326328 get_memory_allocations ()
327- breakpoint ()
329+ def run_inference (state , transformer ):
330+ img = transformer .apply (
331+ {"params" : state .params },
332+ img = latents ,
333+ img_ids = latent_image_ids ,
334+ txt = prompt_embeds ,
335+ txt_ids = text_ids ,
336+ timesteps = timesteps ,
337+ guidance = guidance ,
338+ y = pooled_prompt_embeds
339+ )
340+ return img
341+
342+ p_run_inference = jax .jit (
343+ functools .partial (
344+ run_inference ,
345+ transformer = transformer
346+ ),
347+ in_shardings = (transformer_state_shardings ,),
348+ out_shardings = None
349+ )
328350
351+ img = p_run_inference (transformer_state )
352+ print ("img.shape: " , img .shape )
329353
330354
331355def main (argv : Sequence [str ]) -> None :
0 commit comments