@@ -190,6 +190,8 @@ def get_t5_prompt_embeds(
190190 tokenizer : AutoTokenizer ,
191191 text_encoder : FlaxT5EncoderModel ,
192192 max_sequence_length : int = 512 ,
193+ encode_in_batches = False ,
194+ encode_batch_size = None ,
193195 ):
194196
195197 prompt = [prompt ] if isinstance (prompt , str ) else prompt
@@ -205,13 +207,23 @@ def get_t5_prompt_embeds(
205207 return_tensors = "np" ,
206208 )
207209 text_input_ids = text_inputs .input_ids
208- prompt_embeds = text_encoder (text_input_ids , attention_mask = None , output_hidden_states = False )["last_hidden_state" ]
210+ if encode_in_batches :
211+ prompt_embeds = None
212+ for i in range (0 , text_input_ids .shape [0 ], encode_batch_size ):
213+ batch_prompt_embeds = text_encoder (text_input_ids [i :i + encode_batch_size ], attention_mask = None , output_hidden_states = False )["last_hidden_state" ]
214+ if prompt_embeds is None :
215+ prompt_embeds = batch_prompt_embeds
216+ else :
217+ prompt_embeds = jnp .concatenate ([prompt_embeds , batch_prompt_embeds ])
218+ else :
219+ prompt_embeds = text_encoder (text_input_ids , attention_mask = None , output_hidden_states = False )["last_hidden_state" ]
220+ _ , seq_len , _ = prompt_embeds .shape
221+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
222+ prompt_embeds = jnp .tile (prompt_embeds , (1 , num_images_per_prompt , 1 ))
223+ prompt_embeds = jnp .reshape (prompt_embeds , (batch_size * num_images_per_prompt , seq_len , - 1 ))
224+
209225 dtype = text_encoder .dtype
210226 prompt_embeds = prompt_embeds .astype (dtype )
211- _ , seq_len , _ = prompt_embeds .shape
212- # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
213- prompt_embeds = jnp .tile (prompt_embeds , (1 , num_images_per_prompt , 1 ))
214- prompt_embeds = jnp .reshape (prompt_embeds , (batch_size * num_images_per_prompt , seq_len , - 1 ))
215227
216228 return prompt_embeds
217229
@@ -226,7 +238,12 @@ def encode_prompt(
226238 t5_text_encoder : FlaxT5EncoderModel ,
227239 num_images_per_prompt : int = 1 ,
228240 max_sequence_length : int = 512 ,
241+ encode_in_batches : bool = False ,
242+ encode_batch_size : int = None
229243 ):
244+
245+ if encode_in_batches :
246+ assert encode_batch_size is not None
230247
231248 prompt = [prompt ] if isinstance (prompt , str ) else prompt
232249 prompt_2 = prompt or prompt_2
@@ -242,6 +259,8 @@ def encode_prompt(
242259 tokenizer = t5_tokenizer ,
243260 text_encoder = t5_text_encoder ,
244261 max_sequence_length = max_sequence_length ,
262+ encode_in_batches = encode_in_batches ,
263+ encode_batch_size = encode_batch_size
245264 )
246265
247266 text_ids = jnp .zeros ((prompt_embeds .shape [0 ], prompt_embeds .shape [1 ], 3 )).astype (jnp .bfloat16 )
0 commit comments