Skip to content

Commit 3b40223

Browse files
iterate over different resolutions and store precompiled functions in dict.
1 parent 7d3dbd4 commit 3b40223

1 file changed

Lines changed: 103 additions & 78 deletions

File tree

src/maxdiffusion/generate_flux.py

Lines changed: 103 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def unpack(x: Array, height: int, width: int) -> Array:
7777

7878

7979
def vae_decode(latents, vae, state, config, resolution):
80-
img = unpack(x=latents.astype(jnp.float32), height=resolution, width=resolution)
80+
img = unpack(x=latents.astype(jnp.float32), height=resolution[1], width=resolution[0])
8181
img = img / vae.config.scaling_factor + vae.config.shift_factor
8282
img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample
8383
return img
@@ -322,15 +322,6 @@ def run(config):
322322
)
323323

324324
num_channels_latents = transformer.in_channels // 4
325-
latents, latent_image_ids = prepare_latents(
326-
batch_size=global_batch_size,
327-
num_channels_latents=num_channels_latents,
328-
height=config.resolution,
329-
width=config.resolution,
330-
dtype=jnp.bfloat16,
331-
vae_scale_factor=vae_scale_factor,
332-
rng=rng,
333-
)
334325

335326
# LOAD TEXT ENCODERS
336327
clip_text_encoder = FlaxCLIPTextModel.from_pretrained(
@@ -352,17 +343,6 @@ def run(config):
352343
t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params)
353344
t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params)
354345

355-
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
356-
prompt=config.prompt,
357-
prompt_2=config.prompt_2,
358-
clip_tokenizer=clip_tokenizer,
359-
clip_text_encoder=clip_text_encoder,
360-
t5_tokenizer=t5_tokenizer,
361-
t5_text_encoder=t5_encoder,
362-
num_images_per_prompt=global_batch_size,
363-
max_sequence_length=config.max_sequence_length,
364-
)
365-
366346
def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds):
367347
print("latents.shape: ", latents.shape, latents.dtype)
368348
print("latent_image_ids.shape: ", latent_image_ids.shape, latent_image_ids.dtype)
@@ -374,13 +354,6 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
374354

375355
guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16)
376356

377-
# move inputs to device and shard
378-
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
379-
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
380-
text_ids = jax.device_put(text_ids)
381-
guidance = jax.device_put(guidance, data_sharding)
382-
pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding)
383-
384357
if config.offload_encoders:
385358
cpus = jax.devices("cpu")
386359
t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0])
@@ -427,58 +400,110 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
427400
states["transformer"] = transformer_state
428401
states["vae"] = vae_state
429402

430-
#validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
431-
432-
resolutions = [1024, 768, 512]
403+
resolutions = [
404+
(768, 768),
405+
(768, 1024),
406+
(1024, 768),
407+
(1024, 1024),
408+
(896, 1152),
409+
(1152, 896),
410+
(1920, 1080),
411+
(1080, 1920)
412+
]
433413
p_jitted = {}
434414
for resolution in resolutions:
435-
latents, latent_image_ids = prepare_latents(
436-
batch_size=global_batch_size,
437-
num_channels_latents=num_channels_latents,
438-
height=resolution,
439-
width=resolution,
440-
dtype=jnp.bfloat16,
441-
vae_scale_factor=vae_scale_factor,
442-
rng=rng,
443-
)
444-
latents = jax.device_put(latents, data_sharding)
445-
latent_image_ids = jax.device_put(latent_image_ids)
446-
447-
# Setup timesteps
448-
timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1)
449-
# shifting the schedule to favor high timesteps for higher signal images
450-
if config.time_shift:
451-
# estimate mu based on linear estimation between two points
452-
lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift)
453-
mu = lin_function(latents.shape[1])
454-
timesteps = time_shift(mu, 1.0, timesteps)
455-
c_ts = timesteps[:-1]
456-
p_ts = timesteps[1:]
457-
458-
p_run_inference = jax.jit(
459-
functools.partial(
460-
run_inference,
461-
transformer=transformer,
462-
vae=vae,
463-
config=config,
464-
resolution=resolution,
465-
mesh=mesh,
466-
latents=latents,
467-
latent_image_ids=latent_image_ids,
468-
prompt_embeds=prompt_embeds,
469-
txt_ids=text_ids,
470-
vec=pooled_prompt_embeds,
471-
guidance_vec=guidance,
472-
c_ts=c_ts,
473-
p_ts=p_ts,
474-
),
475-
in_shardings=(state_shardings,),
476-
out_shardings=None,
477-
)
478-
with ExitStack() as stack:
479-
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
480-
p_run_inference(states).block_until_ready()
481-
p_jitted[resolution] = p_run_inference
415+
max_logging.log(f"Resolutions: {resolution}")
416+
for _ in range(5):
417+
s0 = time.perf_counter()
418+
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
419+
prompt=config.prompt,
420+
prompt_2=config.prompt_2,
421+
clip_tokenizer=clip_tokenizer,
422+
clip_text_encoder=clip_text_encoder,
423+
t5_tokenizer=t5_tokenizer,
424+
t5_text_encoder=t5_encoder,
425+
num_images_per_prompt=global_batch_size,
426+
max_sequence_length=config.max_sequence_length,
427+
)
428+
max_logging.log(f"text encoding time: {(time.perf_counter() - s0)}")
429+
latents, latent_image_ids = prepare_latents(
430+
batch_size=global_batch_size,
431+
num_channels_latents=num_channels_latents,
432+
height=resolution[1],
433+
width=resolution[0],
434+
dtype=jnp.bfloat16,
435+
vae_scale_factor=vae_scale_factor,
436+
rng=rng,
437+
)
438+
439+
# move inputs to device and shard
440+
s0 = time.perf_counter()
441+
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
442+
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
443+
text_ids = jax.device_put(text_ids)
444+
guidance = jax.device_put(guidance, data_sharding)
445+
pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding)
446+
latents = jax.device_put(latents, data_sharding)
447+
latent_image_ids = jax.device_put(latent_image_ids)
448+
max_logging.log(f"Moving to device time: {(time.perf_counter() - s0)}")
449+
450+
# Setup timesteps
451+
timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1)
452+
# shifting the schedule to favor high timesteps for higher signal images
453+
if config.time_shift:
454+
# estimate mu based on linear estimation between two points
455+
lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift)
456+
mu = lin_function(latents.shape[1])
457+
timesteps = time_shift(mu, 1.0, timesteps)
458+
c_ts = timesteps[:-1]
459+
p_ts = timesteps[1:]
460+
validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
461+
p_run_inference = p_jitted.get(resolution, None)
462+
if p_run_inference is None:
463+
print("FN not found, compiling...")
464+
p_run_inference = jax.jit(
465+
functools.partial(
466+
run_inference,
467+
transformer=transformer,
468+
vae=vae,
469+
config=config,
470+
resolution=resolution,
471+
mesh=mesh,
472+
latents=latents,
473+
latent_image_ids=latent_image_ids,
474+
prompt_embeds=prompt_embeds,
475+
txt_ids=text_ids,
476+
vec=pooled_prompt_embeds,
477+
guidance_vec=guidance,
478+
c_ts=c_ts,
479+
p_ts=p_ts,
480+
),
481+
)
482+
p_jitted[resolution] = p_run_inference
483+
with ExitStack() as stack:
484+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
485+
s0 = time.perf_counter()
486+
imgs = p_run_inference(
487+
states,
488+
latents = latents,
489+
latent_image_ids=latent_image_ids,
490+
prompt_embeds=prompt_embeds,
491+
txt_ids=text_ids,
492+
vec=pooled_prompt_embeds,
493+
).block_until_ready()
494+
max_logging.log(f"inference time: {(time.perf_counter() - s0)}")
495+
s0 = time.perf_counter()
496+
imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
497+
max_logging.log(f"Gathering all time: {(time.perf_counter() - s0)}")
498+
s0 = time.perf_counter()
499+
imgs = np.array(imgs)
500+
imgs = (imgs * 0.5 + 0.5).clip(0, 1)
501+
imgs = np.transpose(imgs, (0, 2, 3, 1))
502+
imgs = np.uint8(imgs * 255)
503+
for i, image in enumerate(imgs):
504+
Image.fromarray(image).save(f"flux_{resolution[0]}_{resolution[1]}_{i}.png")
505+
max_logging.log(f"Saving images time: {(time.perf_counter() - s0)}")
506+
get_memory_allocations()
482507
breakpoint()
483508
t0 = time.perf_counter()
484509
with ExitStack() as stack:

0 commit comments

Comments
 (0)