 from transformers import (
   CLIPTokenizer,
   FlaxCLIPTextModel,
-  T5TokenizerFast,
   T5EncoderModel,
-  FlaxT5EncoderModel
+  FlaxT5EncoderModel,
+  AutoTokenizer
 )

 from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
@@ -235,7 +235,7 @@ def get_clip_prompt_embeds(
 def get_t5_prompt_embeds(
   prompt: Union[str, List[str]],
   num_images_per_prompt: int,
-  tokenizer: T5TokenizerFast,
+  tokenizer: AutoTokenizer,
   text_encoder: T5EncoderModel,
   max_sequence_length: int = 512
@@ -245,18 +245,20 @@ def get_t5_prompt_embeds(

   text_inputs = tokenizer(
     prompt,
-    padding="max_length",
-    max_length=max_sequence_length,
     truncation=True,
+    max_length=max_sequence_length,
     return_length=False,
     return_overflowing_tokens=False,
+    padding="max_length",
     return_tensors="np"
   )
   text_input_ids = text_inputs.input_ids
-  prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False)[0]
+  prompt_embeds = text_encoder(
+    text_input_ids,
+    attention_mask=None,
+    output_hidden_states=False)["last_hidden_state"]
   dtype = text_encoder.dtype
   prompt_embeds = prompt_embeds.astype(dtype)
-
   _, seq_len, _ = prompt_embeds.shape
   # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
   prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1))
@@ -270,7 +272,7 @@ def encode_prompt(
   prompt_2: Union[str, List[str]],
   clip_tokenizer: CLIPTokenizer,
   clip_text_encoder: FlaxCLIPTextModel,
-  t5_tokenizer: T5TokenizerFast,
+  t5_tokenizer: AutoTokenizer,
   t5_text_encoder: T5EncoderModel,
   num_images_per_prompt: int = 1,
   max_sequence_length: int = 512
@@ -368,13 +370,10 @@ def run(config):
   )

   t5_encoder = FlaxT5EncoderModel.from_pretrained(
-    config.clip_model_name_or_path,
+    config.t5xxl_model_name_or_path,
     dtype=config.weights_dtype
   )
-  t5_tokenizer = T5TokenizerFast.from_pretrained(
-    config.pretrained_model_name_or_path,
-    subfolder="tokenizer_2",
-  )
+  t5_tokenizer = AutoTokenizer.from_pretrained(config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True)

   encoders_sharding = PositionalSharding(devices_array).replicate()
   partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
@@ -405,9 +404,6 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
   timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16)
   guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16)

-  # TODO - remove this later and figure out why t5x is returning wrong shape
-  prompt_embeds = jnp.ones((global_batch_size, 512, 4096))
-
   validate_inputs(
     latents,
     latent_image_ids,
@@ -418,8 +414,6 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
     pooled_prompt_embeds
   )

-
-
   # move inputs to device and shard
   data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
   latents = jax.device_put(latents, data_sharding)
@@ -430,6 +424,10 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
   guidance = jax.device_put(guidance, data_sharding)
   pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding)

+  if config.offload_encoders:
+    cpus = jax.devices("cpu")
+    t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0])
+
   get_memory_allocations()
   # evaluate shapes
   transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=512, eval_only=True)
435433 transformer_eval_params = transformer .init_weights (rngs = rng , max_sequence_length = 512 , eval_only = True )
@@ -444,11 +442,11 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
444442 config = config ,
445443 mesh = mesh ,
446444 weights_init_fn = weights_init_fn ,
447- model_params = transformer_params ,
448- #model_params=None,
445+ model_params = None ,
449446 training = False
450447 )
451- #transformer_state = transformer_state.replace(params=transformer_params)
448+ transformer_state = transformer_state .replace (params = transformer_params )
449+ transformer_state = jax .device_put (transformer_state , transformer_state_shardings )
452450 get_memory_allocations ()
453451
454452 states = {}
0 commit comments