
Commit f0a1904

jfacevedo-google and ksikiric authored and committed
batch text encoding.
1 parent bffc7dc commit f0a1904

4 files changed: 32 additions & 10 deletions


src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ diffusion_scheduler_config: {
 base_output_directory: ""
 
 # Hardware
-hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu'
+hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 
 # Parallelism
 mesh_axes: ['data', 'fsdp', 'tensor']

src/maxdiffusion/pipelines/flux/flux_pipeline.py

Lines changed: 24 additions & 5 deletions
@@ -190,6 +190,8 @@ def get_t5_prompt_embeds(
     tokenizer: AutoTokenizer,
     text_encoder: FlaxT5EncoderModel,
     max_sequence_length: int = 512,
+    encode_in_batches=False,
+    encode_batch_size=None,
 ):
 
   prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -205,13 +207,23 @@
       return_tensors="np",
   )
   text_input_ids = text_inputs.input_ids
-  prompt_embeds = text_encoder(text_input_ids, attention_mask=None, output_hidden_states=False)["last_hidden_state"]
+  if encode_in_batches:
+    prompt_embeds = None
+    for i in range(0, text_input_ids.shape[0], encode_batch_size):
+      batch_prompt_embeds = text_encoder(text_input_ids[i:i+encode_batch_size], attention_mask=None, output_hidden_states=False)["last_hidden_state"]
+      if prompt_embeds == None:
+        prompt_embeds = batch_prompt_embeds
+      else:
+        prompt_embeds = jnp.concatenate([prompt_embeds, batch_prompt_embeds])
+  else:
+    prompt_embeds = text_encoder(text_input_ids, attention_mask=None, output_hidden_states=False)["last_hidden_state"]
+  _, seq_len, _ = prompt_embeds.shape
+  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+  prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1))
+  prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1))
+
   dtype = text_encoder.dtype
   prompt_embeds = prompt_embeds.astype(dtype)
-  _, seq_len, _ = prompt_embeds.shape
-  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-  prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1))
-  prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1))
 
   return prompt_embeds
 
@@ -226,7 +238,12 @@ def encode_prompt(
     t5_text_encoder: FlaxT5EncoderModel,
     num_images_per_prompt: int = 1,
     max_sequence_length: int = 512,
+    encode_in_batches: bool = False,
+    encode_batch_size: int = None
 ):
+
+  if encode_in_batches:
+    assert encode_in_batches is not None
 
   prompt = [prompt] if isinstance(prompt, str) else prompt
   prompt_2 = prompt or prompt_2
@@ -242,6 +259,8 @@ def encode_prompt(
       tokenizer=t5_tokenizer,
       text_encoder=t5_text_encoder,
       max_sequence_length=max_sequence_length,
+      encode_in_batches=encode_in_batches,
+      encode_batch_size=encode_batch_size
   )
 
   text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16)
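The heart of the change is the new encode_in_batches path in get_t5_prompt_embeds: rather than pushing every tokenized prompt through the T5 encoder in one call, it walks the prompt axis in slices of encode_batch_size and concatenates the per-slice embeddings, so peak encoder activation memory is bounded by the slice size instead of the full prompt batch. Below is a minimal, self-contained sketch of that pattern; encode_in_chunks, toy_encoder, and the shapes are illustrative stand-ins, not the maxdiffusion API.

# Sketch of chunked text encoding; `toy_encoder` stands in for the T5 encoder call.
import jax.numpy as jnp
import numpy as np

def encode_in_chunks(encoder_fn, token_ids, chunk_size):
  # Encode `chunk_size` prompts at a time so activation memory scales with the
  # chunk, not with the full prompt batch.
  chunks = []
  for start in range(0, token_ids.shape[0], chunk_size):
    chunks.append(encoder_fn(token_ids[start:start + chunk_size]))
  # Stitch the per-chunk embeddings back into one [batch, seq_len, dim] array.
  return jnp.concatenate(chunks, axis=0)

def toy_encoder(ids):
  # Stand-in encoder: maps [n, seq_len] token ids to [n, seq_len, 8] embeddings.
  return jnp.ones((ids.shape[0], ids.shape[1], 8), dtype=jnp.float32)

token_ids = np.zeros((64, 512), dtype=np.int32)  # 64 prompts, 512 tokens each
embeds = encode_in_chunks(toy_encoder, token_ids, chunk_size=16)
print(embeds.shape)  # (64, 512, 8)

The sketch collects the slices in a list and concatenates once at the end; the committed loop instead grows prompt_embeds with a jnp.concatenate per iteration, which produces the same result but re-materializes the running array on every step.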

src/maxdiffusion/train_flux.py

Lines changed: 1 addition & 2 deletions
@@ -24,14 +24,13 @@
     mllog_utils,
 )
 
-from maxdiffusion.trainers.flux_trainer import FluxTrainer
-
 from maxdiffusion.train_utils import (
     validate_train_config,
 )
 
 
 def train(config):
+  from maxdiffusion.trainers.flux_trainer import FluxTrainer
   trainer = FluxTrainer(config)
   trainer.start_training()
 

src/maxdiffusion/trainers/flux_trainer.py

Lines changed: 6 additions & 2 deletions
@@ -105,6 +105,8 @@ def start_training(self):
    if self.config.dataset_type == "grain":
      data_iterator = self.restore_data_iterator_state(data_iterator)
 
+    # don't need this anymore, clear some memory.
+    del pipeline.t5_encoder
 
    flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state(
        # ambiguous here, but if self.params.get("unet") doesn't exist
@@ -138,7 +140,7 @@ def start_training(self):
    )
    # 6. save final checkpoint
    # Hook
-    self.post_training_steps(pipeline, params, train_states, "after_training")
+    #self.post_training_steps(pipeline, params, train_states, "after_training")
 
  def get_shaped_batch(self, config, pipeline=None):
    """Return the shape of the batch - this is what eval_shape would return for the
@@ -267,7 +269,9 @@ def load_dataset(self, pipeline, params, train_states):
        clip_tokenizer=pipeline.clip_tokenizer,
        t5_tokenizer=pipeline.t5_tokenizer,
        clip_text_encoder=pipeline.clip_encoder,
-        t5_text_encoder=pipeline.t5_encoder
+        t5_text_encoder=pipeline.t5_encoder,
+        encode_in_batches=True,
+        encode_batch_size=16
    )
    pack_latents_p = partial(pipeline.pack_latents)
    prepare_latent_image_ids_p = partial(pipeline.prepare_latent_image_ids)
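On the trainer side, prompts are now pre-encoded in batches of 16 during dataset loading, and the T5 encoder is deleted in start_training because it is no longer needed once its embeddings are cached. The del frees memory because JAX releases a device buffer only when no Python reference to it remains; below is a small illustrative sketch of that pattern, where ToyEncoder and ToyPipeline are stand-ins rather than maxdiffusion classes.

import gc
import jax.numpy as jnp

class ToyEncoder:
  def __init__(self):
    # Stand-in for encoder parameters held in device memory.
    self.params = jnp.ones((1024, 1024), dtype=jnp.bfloat16)

  def __call__(self, token_ids):
    # Stand-in for FlaxT5EncoderModel: [n, seq_len] ids -> [n, seq_len, 16] embeddings.
    return jnp.ones((token_ids.shape[0], token_ids.shape[1], 16), dtype=jnp.bfloat16)

class ToyPipeline:
  def __init__(self):
    self.t5_encoder = ToyEncoder()

pipeline = ToyPipeline()
token_ids = jnp.zeros((16, 512), dtype=jnp.int32)
cached_embeds = pipeline.t5_encoder(token_ids)  # computed once, reused during training

# Once the embeddings are cached, the encoder is dead weight; dropping the last
# reference (and collecting) lets JAX free the underlying device buffers.
del pipeline.t5_encoder
gc.collect()

The encode_batch_size=16 passed during dataset loading is a tuning knob: smaller values lower peak encoder memory at the cost of more, smaller encoder invocations.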
