
Commit eef5dd3

Fix flux inference and sdxl training
1 parent 1e1d2e1 commit eef5dd3

2 files changed: 264 additions & 262 deletions

File tree

src/maxdiffusion/generate_flux.py

Lines changed: 189 additions & 189 deletions
```diff
@@ -287,207 +287,207 @@ def run(config):
   global_batch_size = config.per_device_batch_size * jax.local_device_count()
 
   # LOAD VAE
-
-  vae, vae_params = FlaxAutoencoderKL.from_pretrained(
-      config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16"
-  )
-
-  weights_init_fn = functools.partial(vae.init_weights, rng=rng)
-  vae_state, vae_state_shardings = setup_initial_state(
-      model=vae,
-      tx=None,
-      config=config,
-      mesh=mesh,
-      weights_init_fn=weights_init_fn,
-      model_params=vae_params,
-      training=False,
-  )
-
-  vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
-
-  # LOAD TRANSFORMER
-  flash_block_sizes = get_flash_block_sizes(config)
-  transformer = FluxTransformer2DModel.from_config(
-      config.pretrained_model_name_or_path,
-      subfolder="transformer",
-      mesh=mesh,
-      split_head_dim=config.split_head_dim,
-      attention_kernel=config.attention,
-      flash_block_sizes=flash_block_sizes,
-      dtype=config.activations_dtype,
-      weights_dtype=config.weights_dtype,
-      precision=get_precision(config),
-  )
-
-  num_channels_latents = transformer.in_channels // 4
-  latents, latent_image_ids = prepare_latents(
-      batch_size=global_batch_size,
-      num_channels_latents=num_channels_latents,
-      height=config.resolution,
-      width=config.resolution,
-      dtype=jnp.bfloat16,
-      vae_scale_factor=vae_scale_factor,
-      rng=rng,
-  )
-
-  # LOAD TEXT ENCODERS
-  clip_text_encoder = FlaxCLIPTextModel.from_pretrained(
-      config.pretrained_model_name_or_path, subfolder="text_encoder", from_pt=True, dtype=config.weights_dtype
-  )
-  clip_tokenizer = CLIPTokenizer.from_pretrained(
-      config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype
-  )
-
-  t5_encoder = FlaxT5EncoderModel.from_pretrained(config.t5xxl_model_name_or_path, dtype=config.weights_dtype)
-  t5_tokenizer = AutoTokenizer.from_pretrained(
-      config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
-  )
-
-  encoders_sharding = NamedSharding(mesh, P())
-  partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
-  clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
-  clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
-  t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params)
-  t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params)
-
-  prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
-      prompt=config.prompt,
-      prompt_2=config.prompt_2,
-      clip_tokenizer=clip_tokenizer,
-      clip_text_encoder=clip_text_encoder,
-      t5_tokenizer=t5_tokenizer,
-      t5_text_encoder=t5_encoder,
-      num_images_per_prompt=global_batch_size,
-      max_sequence_length=config.max_sequence_length,
-  )
-
-  def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds):
-    print("latents.shape: ", latents.shape, latents.dtype)
-    print("latent_image_ids.shape: ", latent_image_ids.shape, latent_image_ids.dtype)
-    print("text_ids.shape: ", text_ids.shape, text_ids.dtype)
-    print("prompt_embeds: ", prompt_embeds.shape, prompt_embeds.dtype)
-    print("timesteps.shape: ", timesteps.shape, timesteps.dtype)
-    print("guidance.shape: ", guidance.shape, guidance.dtype)
-    print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype)
-
-  guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16)
-
-  # move inputs to device and shard
-  data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
-  latents = jax.device_put(latents, data_sharding)
-  latent_image_ids = jax.device_put(latent_image_ids)
-  prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
-  text_ids = jax.device_put(text_ids)
-  guidance = jax.device_put(guidance, data_sharding)
-  pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding)
-
-  if config.offload_encoders:
-    cpus = jax.devices("cpu")
-    t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0])
-
-  get_memory_allocations()
-  # evaluate shapes
-  transformer_eval_params = transformer.init_weights(
-      rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=True
-  )
-
-  # loads pretrained weights
-  transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu")
-  params = {}
-  params["transformer"] = transformer_params
-  # maybe load lora and create interceptor
-  lora_loader = FluxLoraLoaderMixin()
-  params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params)
-  transformer_params = params["transformer"]
-  # create transformer state
-  weights_init_fn = functools.partial(
-      transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False
-  )
-  with ExitStack() as stack:
-    _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
-    transformer_state, transformer_state_shardings = setup_initial_state(
-        model=transformer,
-        tx=None,
-        config=config,
-        mesh=mesh,
-        weights_init_fn=weights_init_fn,
-        model_params=None,
-        training=False,
-    )
-  transformer_state = transformer_state.replace(params=transformer_params)
-  transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
-  get_memory_allocations()
-
-  states = {}
-  state_shardings = {}
-
-  state_shardings["transformer"] = transformer_state_shardings
-  state_shardings["vae"] = vae_state_shardings
-
-  states["transformer"] = transformer_state
-  states["vae"] = vae_state
-
-  # Setup timesteps
-  timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1)
-  # shifting the schedule to favor high timesteps for higher signal images
-  if config.time_shift:
-    # estimate mu based on linear estimation between two points
-    lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift)
-    mu = lin_function(latents.shape[1])
-    timesteps = time_shift(mu, 1.0, timesteps)
-  c_ts = timesteps[:-1]
-  p_ts = timesteps[1:]
-
-  validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
-
-  p_run_inference = jax.jit(
-      functools.partial(
-          run_inference,
-          transformer=transformer,
-          vae=vae,
-          config=config,
-          mesh=mesh,
-          latents=latents,
-          latent_image_ids=latent_image_ids,
-          prompt_embeds=prompt_embeds,
-          txt_ids=text_ids,
-          vec=pooled_prompt_embeds,
-          guidance_vec=guidance,
-          c_ts=c_ts,
-          p_ts=p_ts,
-      ),
-      in_shardings=(state_shardings,),
-      out_shardings=None,
-  )
-  t0 = time.perf_counter()
-  with ExitStack() as stack:
-    _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
-    p_run_inference(states).block_until_ready()
-  t1 = time.perf_counter()
-  max_logging.log(f"Compile time: {t1 - t0:.1f}s.")
-
-  t0 = time.perf_counter()
-  with ExitStack() as stack, jax.profiler.trace("/tmp/trace/"):
-    _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
-    imgs = p_run_inference(states).block_until_ready()
-  t1 = time.perf_counter()
-  max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
-
-  t0 = time.perf_counter()
-  with ExitStack() as stack:
-    _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
-    imgs = p_run_inference(states).block_until_ready()
-    imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
-  t1 = time.perf_counter()
-  max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
-  imgs = np.array(imgs)
-  imgs = (imgs * 0.5 + 0.5).clip(0, 1)
-  imgs = np.transpose(imgs, (0, 2, 3, 1))
-  imgs = np.uint8(imgs * 255)
-  for i, image in enumerate(imgs):
-    Image.fromarray(image).save(f"flux_{i}.png")
-
-  return imgs
+  with mesh:
+    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+        config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16"
+    )
+
+    weights_init_fn = functools.partial(vae.init_weights, rng=rng)
+    vae_state, vae_state_shardings = setup_initial_state(
+        model=vae,
+        tx=None,
+        config=config,
+        mesh=mesh,
+        weights_init_fn=weights_init_fn,
+        model_params=vae_params,
+        training=False,
+    )
+
+    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
+
+    # LOAD TRANSFORMER
+    flash_block_sizes = get_flash_block_sizes(config)
+    transformer = FluxTransformer2DModel.from_config(
+        config.pretrained_model_name_or_path,
+        subfolder="transformer",
+        mesh=mesh,
+        split_head_dim=config.split_head_dim,
+        attention_kernel=config.attention,
+        flash_block_sizes=flash_block_sizes,
+        dtype=config.activations_dtype,
+        weights_dtype=config.weights_dtype,
+        precision=get_precision(config),
+    )
+
+    num_channels_latents = transformer.in_channels // 4
+    latents, latent_image_ids = prepare_latents(
+        batch_size=global_batch_size,
+        num_channels_latents=num_channels_latents,
+        height=config.resolution,
+        width=config.resolution,
+        dtype=jnp.bfloat16,
+        vae_scale_factor=vae_scale_factor,
+        rng=rng,
+    )
+
+    # LOAD TEXT ENCODERS
+    clip_text_encoder = FlaxCLIPTextModel.from_pretrained(
+        config.pretrained_model_name_or_path, subfolder="text_encoder", from_pt=True, dtype=config.weights_dtype
+    )
+    clip_tokenizer = CLIPTokenizer.from_pretrained(
+        config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype
+    )
+
+    t5_encoder = FlaxT5EncoderModel.from_pretrained(config.t5xxl_model_name_or_path, dtype=config.weights_dtype)
+    t5_tokenizer = AutoTokenizer.from_pretrained(
+        config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
+    )
+
+    encoders_sharding = NamedSharding(mesh, P())
+    partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
+    clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
+    clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
+    t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params)
+    t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params)
+
+    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+        prompt=config.prompt,
+        prompt_2=config.prompt_2,
+        clip_tokenizer=clip_tokenizer,
+        clip_text_encoder=clip_text_encoder,
+        t5_tokenizer=t5_tokenizer,
+        t5_text_encoder=t5_encoder,
+        num_images_per_prompt=global_batch_size,
+        max_sequence_length=config.max_sequence_length,
+    )
+
+    def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds):
+      print("latents.shape: ", latents.shape, latents.dtype)
+      print("latent_image_ids.shape: ", latent_image_ids.shape, latent_image_ids.dtype)
+      print("text_ids.shape: ", text_ids.shape, text_ids.dtype)
+      print("prompt_embeds: ", prompt_embeds.shape, prompt_embeds.dtype)
+      print("timesteps.shape: ", timesteps.shape, timesteps.dtype)
+      print("guidance.shape: ", guidance.shape, guidance.dtype)
+      print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype)
+
+    guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16)
+
+    # move inputs to device and shard
+    data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
+    latents = jax.device_put(latents, data_sharding)
+    latent_image_ids = jax.device_put(latent_image_ids)
+    prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
+    text_ids = jax.device_put(text_ids)
+    guidance = jax.device_put(guidance, data_sharding)
+    pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding)
+
+    if config.offload_encoders:
+      cpus = jax.devices("cpu")
+      t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0])
+
+    get_memory_allocations()
+    # evaluate shapes
+    transformer_eval_params = transformer.init_weights(
+        rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=True
+    )
+
+    # loads pretrained weights
+    transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu")
+    params = {}
+    params["transformer"] = transformer_params
+    # maybe load lora and create interceptor
+    lora_loader = FluxLoraLoaderMixin()
+    params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params)
+    transformer_params = params["transformer"]
+    # create transformer state
+    weights_init_fn = functools.partial(
+        transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False
+    )
+    with ExitStack() as stack:
+      _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
+      transformer_state, transformer_state_shardings = setup_initial_state(
+          model=transformer,
+          tx=None,
+          config=config,
+          mesh=mesh,
+          weights_init_fn=weights_init_fn,
+          model_params=None,
+          training=False,
+      )
+    transformer_state = transformer_state.replace(params=transformer_params)
+    transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
+    get_memory_allocations()
+
+    states = {}
+    state_shardings = {}
+
+    state_shardings["transformer"] = transformer_state_shardings
+    state_shardings["vae"] = vae_state_shardings
+
+    states["transformer"] = transformer_state
+    states["vae"] = vae_state
+
+    # Setup timesteps
+    timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1)
+    # shifting the schedule to favor high timesteps for higher signal images
+    if config.time_shift:
+      # estimate mu based on linear estimation between two points
+      lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift)
+      mu = lin_function(latents.shape[1])
+      timesteps = time_shift(mu, 1.0, timesteps)
+    c_ts = timesteps[:-1]
+    p_ts = timesteps[1:]
+
+    validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
+
+    p_run_inference = jax.jit(
+        functools.partial(
+            run_inference,
+            transformer=transformer,
+            vae=vae,
+            config=config,
+            mesh=mesh,
+            latents=latents,
+            latent_image_ids=latent_image_ids,
+            prompt_embeds=prompt_embeds,
+            txt_ids=text_ids,
+            vec=pooled_prompt_embeds,
+            guidance_vec=guidance,
+            c_ts=c_ts,
+            p_ts=p_ts,
+        ),
+        in_shardings=(state_shardings,),
+        out_shardings=None,
+    )
+    t0 = time.perf_counter()
+    with ExitStack() as stack:
+      _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
+      p_run_inference(states).block_until_ready()
+    t1 = time.perf_counter()
+    max_logging.log(f"Compile time: {t1 - t0:.1f}s.")
+
+    t0 = time.perf_counter()
+    with ExitStack() as stack, jax.profiler.trace("/tmp/trace/"):
+      _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
+      imgs = p_run_inference(states).block_until_ready()
+    t1 = time.perf_counter()
+    max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
+
+    t0 = time.perf_counter()
+    with ExitStack() as stack:
+      _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
+      imgs = p_run_inference(states).block_until_ready()
+      imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
+    t1 = time.perf_counter()
+    max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
+    imgs = np.array(imgs)
+    imgs = (imgs * 0.5 + 0.5).clip(0, 1)
+    imgs = np.transpose(imgs, (0, 2, 3, 1))
+    imgs = np.uint8(imgs * 255)
+    for i, image in enumerate(imgs):
+      Image.fromarray(image).save(f"flux_{i}.png")
+
+    return imgs
 
 
 def main(argv: Sequence[str]) -> None:
```
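The substance of the generate_flux.py change is that the VAE, transformer, and text-encoder setup now runs inside `with mesh:`, so the mesh is the ambient device mesh while states, shardings, and the jitted inference function are built. Below is a minimal standalone sketch of that JAX pattern, not maxdiffusion code; the single `"data"` axis name is an assumption for illustration:

```python
# Minimal sketch of entering a Mesh context (assumes at least one JAX device).
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

with mesh:
  # Inside the context this mesh is the ambient mesh that sharding
  # constraints and jit-compiled computations resolve against.
  sharding = NamedSharding(mesh, P("data"))
  x = jax.device_put(jnp.zeros((jax.device_count(), 8)), sharding)
  y = jax.jit(lambda a: a * 2)(x)

print(y.sharding)  # NamedSharding over the "data" axis
```

Running the setup under the context rather than after it means helpers like `setup_initial_state` and the jitted `run_inference` see the same mesh, which is presumably what the inference fix relies on.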

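For the `# Setup timesteps` portion of the hunk: `get_lin_function` and `time_shift` implement Flux-style schedule shifting. Here is a hedged sketch of the math, using the formulas from the public Flux reference implementation; maxdiffusion's actual signatures and defaults may differ:

```python
import math

def get_lin_function(x1=256, y1=0.5, x2=4096, y2=1.15):
  # Line through (x1, y1) and (x2, y2): maps image token count to mu.
  m = (y2 - y1) / (x2 - x1)
  b = y1 - m * x1
  return lambda x: m * x + b

def time_shift(mu, sigma, t):
  # Warps t in (0, 1]: larger mu keeps more of the schedule near t = 1,
  # i.e. favors high-noise timesteps, as the diff's comment says.
  return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

mu = get_lin_function()(1024)    # e.g. a 1024-token latent sequence
print(time_shift(mu, 1.0, 0.5))  # a mid-schedule timestep after shifting
```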
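The encoder handling in the hunk (`NamedSharding(mesh, P())` plus `device_put_replicated` mapped over the param trees) replicates the CLIP and T5 weights on every device. A self-contained sketch of that pattern follows; the `device_put_replicated` helper below is a stand-in, since maxdiffusion's own helper is not shown in this diff:

```python
import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# An empty PartitionSpec names no mesh axes, so the value is fully
# replicated across the mesh.
encoders_sharding = NamedSharding(mesh, P())

def device_put_replicated(x, sharding):
  # Stand-in for the helper used in generate_flux.py.
  return jax.device_put(x, sharding)

params = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}
# Same two-step dance as the diff: cast to bfloat16, then replicate.
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params = jax.tree_util.tree_map(
    functools.partial(device_put_replicated, sharding=encoders_sharding), params
)
print(jax.tree_util.tree_map(lambda x: x.sharding, params))
```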