Skip to content

Commit 7d3dbd4

Browse files
Precompile generate functions with different dimensions.
1 parent e177935 commit 7d3dbd4

1 file changed

Lines changed: 58 additions & 39 deletions

File tree

src/maxdiffusion/generate_flux.py

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def unpack(x: Array, height: int, width: int) -> Array:
7676
)
7777

7878

79-
def vae_decode(latents, vae, state, config):
80-
img = unpack(x=latents.astype(jnp.float32), height=config.resolution, width=config.resolution)
79+
def vae_decode(latents, vae, state, config, resolution):
80+
img = unpack(x=latents.astype(jnp.float32), height=resolution, width=resolution)
8181
img = img / vae.config.scaling_factor + vae.config.shift_factor
8282
img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample
8383
return img
@@ -135,7 +135,7 @@ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: flo
135135

136136

137137
def run_inference(
138-
states, transformer, vae, config, mesh, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts
138+
states, transformer, vae, config, resolution, mesh, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts
139139
):
140140

141141
transformer_state = states["transformer"]
@@ -150,7 +150,7 @@ def run_inference(
150150
vec=vec,
151151
guidance_vec=guidance_vec,
152152
)
153-
vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config)
153+
vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config, resolution=resolution)
154154

155155
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
156156
latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body_p, (latents, transformer_state, c_ts, p_ts))
@@ -376,8 +376,6 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
376376

377377
# move inputs to device and shard
378378
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
379-
latents = jax.device_put(latents, data_sharding)
380-
latent_image_ids = jax.device_put(latent_image_ids)
381379
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
382380
text_ids = jax.device_put(text_ids)
383381
guidance = jax.device_put(guidance, data_sharding)
@@ -429,45 +427,66 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
429427
states["transformer"] = transformer_state
430428
states["vae"] = vae_state
431429

432-
# Setup timesteps
433-
timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1)
434-
# shifting the schedule to favor high timesteps for higher signal images
435-
if config.time_shift:
436-
# estimate mu based on linear estimation between two points
437-
lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift)
438-
mu = lin_function(latents.shape[1])
439-
timesteps = time_shift(mu, 1.0, timesteps)
440-
c_ts = timesteps[:-1]
441-
p_ts = timesteps[1:]
442-
443-
validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
444-
445-
p_run_inference = jax.jit(
446-
functools.partial(
447-
run_inference,
448-
transformer=transformer,
449-
vae=vae,
450-
config=config,
451-
mesh=mesh,
452-
latents=latents,
453-
latent_image_ids=latent_image_ids,
454-
prompt_embeds=prompt_embeds,
455-
txt_ids=text_ids,
456-
vec=pooled_prompt_embeds,
457-
guidance_vec=guidance,
458-
c_ts=c_ts,
459-
p_ts=p_ts,
460-
),
461-
in_shardings=(state_shardings,),
462-
out_shardings=None,
463-
)
430+
#validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
431+
432+
resolutions = [1024, 768, 512]
433+
p_jitted = {}
434+
for resolution in resolutions:
435+
latents, latent_image_ids = prepare_latents(
436+
batch_size=global_batch_size,
437+
num_channels_latents=num_channels_latents,
438+
height=resolution,
439+
width=resolution,
440+
dtype=jnp.bfloat16,
441+
vae_scale_factor=vae_scale_factor,
442+
rng=rng,
443+
)
444+
latents = jax.device_put(latents, data_sharding)
445+
latent_image_ids = jax.device_put(latent_image_ids)
446+
447+
# Setup timesteps
448+
timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1)
449+
# shifting the schedule to favor high timesteps for higher signal images
450+
if config.time_shift:
451+
# estimate mu based on linear estimation between two points
452+
lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift)
453+
mu = lin_function(latents.shape[1])
454+
timesteps = time_shift(mu, 1.0, timesteps)
455+
c_ts = timesteps[:-1]
456+
p_ts = timesteps[1:]
457+
458+
p_run_inference = jax.jit(
459+
functools.partial(
460+
run_inference,
461+
transformer=transformer,
462+
vae=vae,
463+
config=config,
464+
resolution=resolution,
465+
mesh=mesh,
466+
latents=latents,
467+
latent_image_ids=latent_image_ids,
468+
prompt_embeds=prompt_embeds,
469+
txt_ids=text_ids,
470+
vec=pooled_prompt_embeds,
471+
guidance_vec=guidance,
472+
c_ts=c_ts,
473+
p_ts=p_ts,
474+
),
475+
in_shardings=(state_shardings,),
476+
out_shardings=None,
477+
)
478+
with ExitStack() as stack:
479+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
480+
p_run_inference(states).block_until_ready()
481+
p_jitted[resolution] = p_run_inference
482+
breakpoint()
464483
t0 = time.perf_counter()
465484
with ExitStack() as stack:
466485
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
467486
p_run_inference(states).block_until_ready()
468487
t1 = time.perf_counter()
469488
max_logging.log(f"Compile time: {t1 - t0:.1f}s.")
470-
489+
breakpoint()
471490
t0 = time.perf_counter()
472491
with ExitStack() as stack, jax.profiler.trace("/tmp/trace/"):
473492
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]

0 commit comments

Comments (0)