Skip to content

Commit 93a3bb6

Browse files
e2e, encoder offloading.
1 parent 860e76e commit 93a3bb6

3 files changed

Lines changed: 24 additions & 22 deletions

File tree

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ max_sequence_length: 512
3232
time_shift: False
3333
base_shift: 0.5
3434
max_shift: 1.15
35+
# offloads t5 encoder after text encoding to save memory.
36+
offload_encoders: True
3537

3638

3739
unet_checkpoint: ''
@@ -210,7 +212,7 @@ do_classifier_free_guidance: True
210212
guidance_scale: 3.5
211213
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
212214
guidance_rescale: 0.0
213-
num_inference_steps: 20
215+
num_inference_steps: 50
214216

215217
# SDXL Lightning parameters
216218
lightning_from_pt: True

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ max_sequence_length: 256
3232
time_shift: False
3333
base_shift: 0.5
3434
max_shift: 1.15
35+
# offloads t5 encoder after text encoding to save memory.
36+
offload_encoders: True
3537

3638
unet_checkpoint: ''
3739
revision: 'refs/pr/95'

src/maxdiffusion/generate_flux.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
from transformers import (
3131
CLIPTokenizer,
3232
FlaxCLIPTextModel,
33-
T5TokenizerFast,
3433
T5EncoderModel,
35-
FlaxT5EncoderModel
34+
FlaxT5EncoderModel,
35+
AutoTokenizer
3636
)
3737

3838
from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
@@ -235,7 +235,7 @@ def get_clip_prompt_embeds(
235235
def get_t5_prompt_embeds(
236236
prompt: Union[str, List[str]],
237237
num_images_per_prompt: int,
238-
tokenizer: T5TokenizerFast,
238+
tokenizer: AutoTokenizer,
239239
text_encoder: T5EncoderModel,
240240
max_sequence_length: int = 512
241241
):
@@ -245,18 +245,20 @@ def get_t5_prompt_embeds(
245245

246246
text_inputs = tokenizer(
247247
prompt,
248-
padding="max_length",
249-
max_length=max_sequence_length,
250248
truncation=True,
249+
max_length=max_sequence_length,
251250
return_length=False,
252251
return_overflowing_tokens=False,
252+
padding="max_length",
253253
return_tensors="np"
254254
)
255255
text_input_ids = text_inputs.input_ids
256-
prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False)[0]
256+
prompt_embeds = text_encoder(
257+
text_input_ids,
258+
attention_mask=None,
259+
output_hidden_states=False)["last_hidden_state"]
257260
dtype = text_encoder.dtype
258261
prompt_embeds = prompt_embeds.astype(dtype)
259-
260262
_, seq_len, _ = prompt_embeds.shape
261263
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
262264
prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1))
@@ -270,7 +272,7 @@ def encode_prompt(
270272
prompt_2: Union[str, List[str]],
271273
clip_tokenizer: CLIPTokenizer,
272274
clip_text_encoder: FlaxCLIPTextModel,
273-
t5_tokenizer: T5TokenizerFast,
275+
t5_tokenizer: AutoTokenizer,
274276
t5_text_encoder: T5EncoderModel,
275277
num_images_per_prompt: int = 1,
276278
max_sequence_length: int = 512
@@ -368,13 +370,10 @@ def run(config):
368370
)
369371

370372
t5_encoder = FlaxT5EncoderModel.from_pretrained(
371-
config.clip_model_name_or_path,
373+
config.t5xxl_model_name_or_path,
372374
dtype=config.weights_dtype
373375
)
374-
t5_tokenizer = T5TokenizerFast.from_pretrained(
375-
config.pretrained_model_name_or_path,
376-
subfolder="tokenizer_2",
377-
)
376+
t5_tokenizer = AutoTokenizer.from_pretrained(config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True)
378377

379378
encoders_sharding = PositionalSharding(devices_array).replicate()
380379
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
@@ -405,9 +404,6 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
405404
timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16)
406405
guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16)
407406

408-
# TODO - remove this later and figure out why t5x is returning wrong shape
409-
prompt_embeds = jnp.ones((global_batch_size, 512, 4096))
410-
411407
validate_inputs(
412408
latents,
413409
latent_image_ids,
@@ -418,8 +414,6 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
418414
pooled_prompt_embeds
419415
)
420416

421-
422-
423417
# move inputs to device and shard
424418
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
425419
latents = jax.device_put(latents, data_sharding)
@@ -430,6 +424,10 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
430424
guidance = jax.device_put(guidance, data_sharding)
431425
pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding)
432426

427+
if config.offload_encoders:
428+
cpus = jax.devices("cpu")
429+
t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0])
430+
433431
get_memory_allocations()
434432
# evaluate shapes
435433
transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=512, eval_only=True)
@@ -444,11 +442,11 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
444442
config=config,
445443
mesh=mesh,
446444
weights_init_fn=weights_init_fn,
447-
model_params=transformer_params,
448-
#model_params=None,
445+
model_params=None,
449446
training=False
450447
)
451-
#transformer_state = transformer_state.replace(params=transformer_params)
448+
transformer_state = transformer_state.replace(params=transformer_params)
449+
transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
452450
get_memory_allocations()
453451

454452
states = {}

0 commit comments

Comments (0)