|
16 | 16 |
|
17 | 17 | from typing import Callable, List, Union, Sequence |
18 | 18 | from absl import app |
| 19 | +from contextlib import ExitStack |
19 | 20 | import functools |
20 | 21 | import math |
21 | 22 | import time |
|
24 | 25 | import jax |
25 | 26 | from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P |
26 | 27 | import jax.numpy as jnp |
| 28 | +import flax.linen as nn |
27 | 29 | from chex import Array |
28 | 30 | from einops import rearrange |
29 | 31 | from flax.linen import partitioning as nn_partitioning |
|
39 | 41 | get_precision, |
40 | 42 | setup_initial_state, |
41 | 43 | ) |
| 44 | +from maxdiffusion.loaders.flux_lora_pipeline import FluxLoraLoaderMixin |
| 45 | + |
def maybe_load_flux_lora(config, lora_loader, params):
  """Load every LoRA adapter listed in ``config.lora_config`` and build interceptors.

  ``config.lora_config`` is read as a mapping with three parallel sequences:
  ``"lora_model_name_or_path"``, ``"weight_name"`` and ``"adapter_name"``.
  Entry ``i`` of each sequence describes the ``i``-th adapter. Adapters are
  loaded in order, and each load receives the params returned by the
  previous one.

  Args:
    config: Object exposing ``lora_config``; it is also forwarded unchanged to
      ``lora_loader.load_lora_weights``.
    lora_loader: Object providing ``load_lora_weights`` and
      ``make_lora_interceptor``.
    params: Parameter collection handed to the first ``load_lora_weights`` call.

  Returns:
    A ``(params, interceptors)`` tuple. When no adapter is configured,
    ``params`` is returned untouched and ``interceptors`` holds a single
    pass-through interceptor, so callers can always enter the list without
    special-casing the "no LoRA" path.

  Raises:
    IndexError: If ``"weight_name"`` or ``"adapter_name"`` has fewer entries
      than ``"lora_model_name_or_path"``. Raised before any weights are
      loaded, so a misconfiguration does not leave a half-applied set.
  """

  def _noop_interceptor(next_fn, args, kwargs, context):
    # Pass-through: invoke the wrapped method with its original arguments.
    return next_fn(*args, **kwargs)

  lora_config = config.lora_config
  model_paths = lora_config["lora_model_name_or_path"]
  if not model_paths:
    return params, [_noop_interceptor]

  weight_names = lora_config["weight_name"]
  adapter_names = lora_config["adapter_name"]
  num_loras = len(model_paths)
  if len(weight_names) < num_loras or len(adapter_names) < num_loras:
    raise IndexError(
        "lora_config lists are misaligned: "
        f"{num_loras} lora_model_name_or_path entries, "
        f"{len(weight_names)} weight_name entries, "
        f"{len(adapter_names)} adapter_name entries."
    )

  interceptors = []
  # zip() stops after `num_loras` items, so surplus weight/adapter entries
  # are ignored exactly as they were with index-based access.
  for model_path, weight_name, adapter_name in zip(model_paths, weight_names, adapter_names):
    params, rank, network_alphas = lora_loader.load_lora_weights(
        config,
        model_path,
        weight_name=weight_name,
        params=params,
        adapter_name=adapter_name,
    )
    # The interceptor is built from the params as they stand after this load.
    interceptors.append(lora_loader.make_lora_interceptor(params, rank, network_alphas, adapter_name))

  return params, interceptors
42 | 66 |
|
43 | 67 |
|
44 | 68 | def unpack(x: Array, height: int, width: int) -> Array: |
@@ -400,21 +424,29 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep |
400 | 424 |
|
401 | 425 | # loads pretrained weights |
402 | 426 | transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu") |
| 427 | + params = {} |
| 428 | + params["transformer"] = transformer_params |
| 429 | + # maybe load lora and create interceptor |
| 430 | + lora_loader = FluxLoraLoaderMixin() |
| 431 | + params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params) |
| 432 | + transformer_params = params["transformer"] |
403 | 433 | # create transformer state |
404 | 434 | weights_init_fn = functools.partial( |
405 | 435 | transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False |
406 | 436 | ) |
407 | | - transformer_state, transformer_state_shardings = setup_initial_state( |
408 | | - model=transformer, |
409 | | - tx=None, |
410 | | - config=config, |
411 | | - mesh=mesh, |
412 | | - weights_init_fn=weights_init_fn, |
413 | | - model_params=None, |
414 | | - training=False, |
415 | | - ) |
416 | | - transformer_state = transformer_state.replace(params=transformer_params) |
417 | | - transformer_state = jax.device_put(transformer_state, transformer_state_shardings) |
| 437 | + with ExitStack() as stack: |
| 438 | + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] |
| 439 | + transformer_state, transformer_state_shardings = setup_initial_state( |
| 440 | + model=transformer, |
| 441 | + tx=None, |
| 442 | + config=config, |
| 443 | + mesh=mesh, |
| 444 | + weights_init_fn=weights_init_fn, |
| 445 | + model_params=None, |
| 446 | + training=False, |
| 447 | + ) |
| 448 | + transformer_state = transformer_state.replace(params=transformer_params) |
| 449 | + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) |
418 | 450 | get_memory_allocations() |
419 | 451 |
|
420 | 452 | states = {} |
@@ -444,17 +476,23 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep |
444 | 476 | out_shardings=None, |
445 | 477 | ) |
446 | 478 | t0 = time.perf_counter() |
447 | | - p_run_inference(states).block_until_ready() |
| 479 | + with ExitStack() as stack: |
| 480 | + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] |
| 481 | + p_run_inference(states).block_until_ready() |
448 | 482 | t1 = time.perf_counter() |
449 | 483 | max_logging.log(f"Compile time: {t1 - t0:.1f}s.") |
450 | 484 |
|
451 | 485 | t0 = time.perf_counter() |
452 | | - imgs = p_run_inference(states).block_until_ready() |
| 486 | + with ExitStack() as stack, jax.profiler.trace("/home/jfacevedo/trace/"): |
| 487 | + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] |
| 488 | + imgs = p_run_inference(states).block_until_ready() |
453 | 489 | t1 = time.perf_counter() |
454 | 490 | max_logging.log(f"Inference time: {t1 - t0:.1f}s.") |
455 | 491 |
|
456 | 492 | t0 = time.perf_counter() |
457 | | - imgs = p_run_inference(states).block_until_ready() |
| 493 | + with ExitStack() as stack: |
| 494 | + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] |
| 495 | + imgs = p_run_inference(states).block_until_ready() |
458 | 496 | imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) |
459 | 497 | t1 = time.perf_counter() |
460 | 498 | max_logging.log(f"Inference time: {t1 - t0:.1f}s.") |
|
0 commit comments