Skip to content

Commit e0f8163

Browse files
formatting
1 parent c8b71f1 commit e0f8163

2 files changed

Lines changed: 53 additions & 45 deletions

File tree

src/maxdiffusion/generate_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,4 +496,4 @@ def main(argv: Sequence[str]) -> None:
496496

497497

498498
if __name__ == "__main__":
499-
app.run(main)
499+
app.run(main)

src/maxdiffusion/generate_flux_multi_res.py

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def unpack(x: Array, height: int, width: int, vae_scale_factor: int) -> Array:
7676

7777
return x
7878

79+
7980
def vae_decode(latents, vae, state, vae_scale_factor, resolution):
8081
img = unpack(x=latents.astype(jnp.float32), height=resolution[0], width=resolution[1], vae_scale_factor=vae_scale_factor)
8182
img = img / vae.config.scaling_factor + vae.config.shift_factor
@@ -127,18 +128,16 @@ def prepare_latent_image_ids(height, width):
127128
def time_shift(mu: float, sigma: float, t: Array):
128129
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
129130

131+
130132
def calculate_shift(
131-
image_seq_len,
132-
base_seq_len: int = 256,
133-
max_seq_len: int = 4096,
134-
base_shift: float = 0.5,
135-
max_shift: float = 1.16
133+
image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.16
136134
):
137135
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
138136
b = base_shift - m * base_seq_len
139137
mu = image_seq_len * m + b
140138
return mu
141139

140+
142141
def run_inference(
143142
states,
144143
transformer,
@@ -154,7 +153,7 @@ def run_inference(
154153
guidance_vec,
155154
c_ts,
156155
p_ts,
157-
vae_scale_factor
156+
vae_scale_factor,
158157
):
159158

160159
transformer_state = states["transformer"]
@@ -169,7 +168,9 @@ def run_inference(
169168
vec=vec,
170169
guidance_vec=guidance_vec,
171170
)
172-
vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, vae_scale_factor=vae_scale_factor, resolution=resolution)
171+
vae_decode_p = functools.partial(
172+
vae_decode, vae=vae, state=vae_state, vae_scale_factor=vae_scale_factor, resolution=resolution
173+
)
173174

174175
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
175176
latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body_p, (latents, transformer_state, c_ts, p_ts))
@@ -211,6 +212,7 @@ def prepare_latents(
211212

212213
return latents, latent_image_ids
213214

215+
214216
def tokenize_clip(prompt: Union[str, List[str]], tokenizer: CLIPTokenizer):
215217
prompt = [prompt] if isinstance(prompt, str) else prompt
216218
text_inputs = tokenizer(
@@ -224,6 +226,7 @@ def tokenize_clip(prompt: Union[str, List[str]], tokenizer: CLIPTokenizer):
224226
)
225227
return text_inputs.input_ids
226228

229+
227230
def get_clip_prompt_embeds(
228231
prompt: Union[str, List[str]], num_images_per_prompt: int, tokenizer: CLIPTokenizer, text_encoder: FlaxCLIPTextModel
229232
):
@@ -246,6 +249,7 @@ def get_clip_prompt_embeds(
246249
prompt_embeds = jnp.tile(prompt_embeds, (batch_size * num_images_per_prompt, 1))
247250
return prompt_embeds
248251

252+
249253
def tokenize_t5(prompt: Union[str, List[str]], tokenizer: AutoTokenizer, max_sequence_length: int = 512):
250254
prompt = [prompt] if isinstance(prompt, str) else prompt
251255
text_inputs = tokenizer(
@@ -259,6 +263,7 @@ def tokenize_t5(prompt: Union[str, List[str]], tokenizer: AutoTokenizer, max_seq
259263
)
260264
return text_inputs.input_ids
261265

266+
262267
def get_t5_prompt_embeds(
263268
prompt: Union[str, List[str]],
264269
num_images_per_prompt: int,
@@ -288,6 +293,7 @@ def get_t5_prompt_embeds(
288293
prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1))
289294
return prompt_embeds
290295

296+
291297
def encode_prompt(
292298
prompt: Union[str, List[str]],
293299
prompt_2: Union[str, List[str]],
@@ -318,6 +324,7 @@ def encode_prompt(
318324
text_ids = jnp.zeros((prompt_embeds.shape[1], 3)).astype(jnp.bfloat16)
319325
return prompt_embeds, pooled_prompt_embeds, text_ids
320326

327+
321328
def run(config):
322329
from maxdiffusion.models.flux.util import load_flow_model
323330

@@ -436,19 +443,19 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
436443
states["vae"] = vae_state
437444
# some resolutions from https://www.reddit.com/r/StableDiffusion/comments/1enxdga/flux_recommended_resolutions_from_01_to_20/
438445
resolutions = [
439-
(768, 768),
440-
(768, 1024),
441-
(1024, 768),
442-
(1024, 1024),
443-
(1408, 1408),
444-
(1728, 1152),
445-
(1152, 1728),
446-
(1664, 1216),
447-
(1216, 1664),
448-
(1920, 1088),
449-
(1088, 1920),
450-
(2176, 960),
451-
(960, 2176)
446+
(768, 768),
447+
(768, 1024),
448+
(1024, 768),
449+
(1024, 1024),
450+
(1408, 1408),
451+
(1728, 1152),
452+
(1152, 1728),
453+
(1664, 1216),
454+
(1216, 1664),
455+
(1920, 1088),
456+
(1088, 1920),
457+
(2176, 960),
458+
(960, 2176),
452459
]
453460
p_jitted = {}
454461
recorded_times = {}
@@ -461,14 +468,14 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
461468
t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params)
462469
max_logging.log(f"Moving encoder to TPU time: {(time.perf_counter() - s0)}")
463470
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
464-
prompt=config.prompt,
465-
prompt_2=config.prompt_2,
466-
clip_tokenizer=clip_tokenizer,
467-
clip_text_encoder=clip_text_encoder,
468-
t5_tokenizer=t5_tokenizer,
469-
t5_text_encoder=t5_encoder,
470-
num_images_per_prompt=global_batch_size,
471-
max_sequence_length=config.max_sequence_length,
471+
prompt=config.prompt,
472+
prompt_2=config.prompt_2,
473+
clip_tokenizer=clip_tokenizer,
474+
clip_text_encoder=clip_text_encoder,
475+
t5_tokenizer=t5_tokenizer,
476+
t5_text_encoder=t5_encoder,
477+
num_images_per_prompt=global_batch_size,
478+
max_sequence_length=config.max_sequence_length,
472479
)
473480
if config.offload_encoders:
474481
s1 = time.perf_counter()
@@ -478,15 +485,15 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
478485
text_encoding_time_final = time.perf_counter() - s0
479486
max_logging.log(f"text encoding time: {text_encoding_time_final}")
480487
latents, latent_image_ids = prepare_latents(
481-
batch_size=global_batch_size,
482-
num_channels_latents=num_channels_latents,
483-
height=resolution[0],
484-
width=resolution[1],
485-
dtype=jnp.bfloat16,
486-
vae_scale_factor=vae_scale_factor,
487-
rng=rng,
488+
batch_size=global_batch_size,
489+
num_channels_latents=num_channels_latents,
490+
height=resolution[0],
491+
width=resolution[1],
492+
dtype=jnp.bfloat16,
493+
vae_scale_factor=vae_scale_factor,
494+
rng=rng,
488495
)
489-
496+
490497
# move inputs to device and shard
491498
s0 = time.perf_counter()
492499
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
@@ -509,7 +516,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
509516
timesteps = time_shift(mu, 1.0, timesteps)
510517
c_ts = timesteps[:-1]
511518
p_ts = timesteps[1:]
512-
#validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
519+
# validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
513520
p_run_inference = p_jitted.get(resolution, None)
514521
if p_run_inference is None:
515522
print("FN not found, compiling...")
@@ -537,14 +544,14 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
537544
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
538545
s0 = time.perf_counter()
539546
imgs = p_run_inference(
540-
states,
541-
latents = latents,
542-
latent_image_ids=latent_image_ids,
543-
prompt_embeds=prompt_embeds,
544-
txt_ids=text_ids,
545-
vec=pooled_prompt_embeds,
547+
states,
548+
latents=latents,
549+
latent_image_ids=latent_image_ids,
550+
prompt_embeds=prompt_embeds,
551+
txt_ids=text_ids,
552+
vec=pooled_prompt_embeds,
546553
).block_until_ready()
547-
recorded_times[resolution] = (time.perf_counter() - s0)
554+
recorded_times[resolution] = time.perf_counter() - s0
548555
max_logging.log(f"inference time: {recorded_times[resolution]}")
549556
s0 = time.perf_counter()
550557
imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
@@ -566,6 +573,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
566573

567574
return imgs
568575

576+
569577
def main(argv: Sequence[str]) -> None:
570578
pyconfig.initialize(argv)
571579
run(pyconfig.config)

0 commit comments

Comments
 (0)