378 changes: 189 additions & 189 deletions src/maxdiffusion/generate_flux.py
@@ -287,207 +287,207 @@ def run(config):
  global_batch_size = config.per_device_batch_size * jax.local_device_count()

  # LOAD VAE
  vae, vae_params = FlaxAutoencoderKL.from_pretrained(
      config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16"
  )
  weights_init_fn = functools.partial(vae.init_weights, rng=rng)
  vae_state, vae_state_shardings = setup_initial_state(
      model=vae,
      tx=None,
      config=config,
      mesh=mesh,
      weights_init_fn=weights_init_fn,
      model_params=vae_params,
      training=False,
  )
  vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
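For the usual four-entry `block_out_channels`, this yields a spatial downsampling factor of 8, so image height and width shrink 8x in latent space. A quick sanity check, assuming a typical four-block VAE config (hypothetical values, not read from the checkpoint):

  # Hypothetical check, assuming a typical four-block VAE config.
  block_out_channels = (128, 256, 512, 512)
  vae_scale_factor = 2 ** (len(block_out_channels) - 1)  # -> 8
  assert 1024 // vae_scale_factor == 128  # a 1024px image decodes from 128x128 latents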

  # LOAD TRANSFORMER
  flash_block_sizes = get_flash_block_sizes(config)
  transformer = FluxTransformer2DModel.from_config(
      config.pretrained_model_name_or_path,
      subfolder="transformer",
      mesh=mesh,
      split_head_dim=config.split_head_dim,
      attention_kernel=config.attention,
      flash_block_sizes=flash_block_sizes,
      dtype=config.activations_dtype,
      weights_dtype=config.weights_dtype,
      precision=get_precision(config),
  )

  num_channels_latents = transformer.in_channels // 4
  latents, latent_image_ids = prepare_latents(
      batch_size=global_batch_size,
      num_channels_latents=num_channels_latents,
      height=config.resolution,
      width=config.resolution,
      dtype=jnp.bfloat16,
      vae_scale_factor=vae_scale_factor,
      rng=rng,
  )
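The `// 4` reflects Flux's packing of 2x2 latent patches into single tokens. A shape sketch under assumed typical settings (16 VAE latent channels, 1024px resolution; these numbers are illustrative assumptions):

  # Sketch of the implied shapes; values are assumptions, not read from config.
  in_channels = 64                           # 16 latent channels x 2 x 2 packed patch
  num_channels_latents = in_channels // 4    # -> 16
  latent_side = 1024 // (8 * 2)              # vae_scale_factor * patch size -> 64
  seq_len = latent_side * latent_side        # -> 4096 packed latent tokens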

  # LOAD TEXT ENCODERS
  clip_text_encoder = FlaxCLIPTextModel.from_pretrained(
      config.pretrained_model_name_or_path, subfolder="text_encoder", from_pt=True, dtype=config.weights_dtype
  )
  clip_tokenizer = CLIPTokenizer.from_pretrained(
      config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype
  )

  t5_encoder = FlaxT5EncoderModel.from_pretrained(config.t5xxl_model_name_or_path, dtype=config.weights_dtype)
  t5_tokenizer = AutoTokenizer.from_pretrained(
      config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
  )

  encoders_sharding = NamedSharding(mesh, P())
  partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
  clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
  clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
  t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params)
  t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params)
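An empty `P()` names no mesh axes, so `NamedSharding(mesh, P())` fully replicates each parameter on every device. The helper assumed here likely reduces to the following sketch (hypothetical, not the repo's exact definition):

  # Hypothetical sketch of device_put_replicated as used above.
  def device_put_replicated(x, sharding):
      # An empty PartitionSpec splits no dimension, so every device in the
      # mesh receives a full copy of x.
      return jax.device_put(x, sharding)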

  prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
      prompt=config.prompt,
      prompt_2=config.prompt_2,
      clip_tokenizer=clip_tokenizer,
      clip_text_encoder=clip_text_encoder,
      t5_tokenizer=t5_tokenizer,
      t5_text_encoder=t5_encoder,
      num_images_per_prompt=global_batch_size,
      max_sequence_length=config.max_sequence_length,
  )
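Flux conditions on both encoders: CLIP contributes one pooled vector per prompt (fed to the transformer as `vec`), while T5 contributes one embedding per token up to `max_sequence_length`. Approximate shapes under assumed typical dimensions (CLIP-L pooled width 768, T5-XXL width 4096; all values are assumptions for illustration):

  # Assumed shapes for illustration only (not read from the checkpoints):
  # pooled_prompt_embeds: (global_batch_size, 768)        pooled CLIP vector
  # prompt_embeds:        (global_batch_size, 512, 4096)  T5 token embeddings
  # text_ids:             (global_batch_size, 512, 3)     text position ids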

  def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds):
    print("latents.shape: ", latents.shape, latents.dtype)
    print("latent_image_ids.shape: ", latent_image_ids.shape, latent_image_ids.dtype)
    print("text_ids.shape: ", text_ids.shape, text_ids.dtype)
    print("prompt_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
    print("timesteps.shape: ", timesteps.shape, timesteps.dtype)
    print("guidance.shape: ", guidance.shape, guidance.dtype)
    print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype)

  guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16)

  # move inputs to device and shard
  data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
  latents = jax.device_put(latents, data_sharding)
  latent_image_ids = jax.device_put(latent_image_ids)
  prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
  text_ids = jax.device_put(text_ids)
  guidance = jax.device_put(guidance, data_sharding)
  pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding)

  if config.offload_encoders:
    cpus = jax.devices("cpu")
    t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0])

  get_memory_allocations()
  # evaluate shapes
  transformer_eval_params = transformer.init_weights(
      rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=True
  )
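With `config.data_sharding` naming the mesh's data axis, these `device_put` calls leave the batch dimension split across devices, while the bare calls simply move their arrays to device memory and let the compiler pick a layout. A toy illustration on a hypothetical 1-D mesh:

  # Hypothetical illustration of batch sharding along a "data" mesh axis.
  import numpy as np
  import jax
  import jax.numpy as jnp
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  mesh = Mesh(np.array(jax.devices()), ("data",))
  data_sharding = NamedSharding(mesh, P("data"))
  batch = jax.device_put(jnp.zeros((8, 64, 64)), data_sharding)  # dim 0 split across devices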

  # loads pretrained weights
  transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu")
  params = {}
  params["transformer"] = transformer_params
  # maybe load lora and create interceptor
  lora_loader = FluxLoraLoaderMixin()
  params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params)
  transformer_params = params["transformer"]
  # create transformer state
  weights_init_fn = functools.partial(
      transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False
  )
  with ExitStack() as stack:
    _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
    transformer_state, transformer_state_shardings = setup_initial_state(
        model=transformer,
        tx=None,
        config=config,
        mesh=mesh,
        weights_init_fn=weights_init_fn,
        model_params=None,
        training=False,
    )
  transformer_state = transformer_state.replace(params=transformer_params)
  transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
  get_memory_allocations()
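Flax's `nn.intercept_methods` wraps every module call while its context is active, which is how the LoRA deltas are applied without rewriting the model; `ExitStack` enters one such context per loaded adapter. A minimal sketch of the pattern with a pass-through interceptor (hypothetical, not the repo's LoRA interceptor):

  # Minimal sketch: ExitStack + nn.intercept_methods with a no-op interceptor.
  from contextlib import ExitStack
  import flax.linen as nn

  def passthrough_interceptor(next_fn, args, kwargs, context):
      # Forward the intercepted module call unchanged; a LoRA interceptor
      # would adjust weights or outputs here instead.
      return next_fn(*args, **kwargs)

  interceptors = [passthrough_interceptor]
  with ExitStack() as stack:
      for icpt in interceptors:
          stack.enter_context(nn.intercept_methods(icpt))
      # module apply() calls made here pass through every interceptor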

  states = {}
  state_shardings = {}

  state_shardings["transformer"] = transformer_state_shardings
  state_shardings["vae"] = vae_state_shardings

  states["transformer"] = transformer_state
  states["vae"] = vae_state

  # Setup timesteps
  timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1)
  # shifting the schedule to favor high timesteps for higher signal images
  if config.time_shift:
    # estimate mu based on linear estimation between two points
    lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift)
    mu = lin_function(latents.shape[1])
    timesteps = time_shift(mu, 1.0, timesteps)
  c_ts = timesteps[:-1]
  p_ts = timesteps[1:]
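In the reference Flux implementation the shift warps each timestep as exp(mu) / (exp(mu) + (1/t - 1)^sigma), with mu obtained by linear interpolation in the latent sequence length; the helpers used above are assumed to match that reference, roughly:

  # Sketch consistent with the reference Flux schedule shift; assumed to
  # mirror the get_lin_function/time_shift helpers defined elsewhere in the repo.
  import math

  def get_lin_function(x1=256, y1=0.5, x2=4096, y2=1.15):
      m = (y2 - y1) / (x2 - x1)
      b = y1 - m * x1
      return lambda x: m * x + b

  def time_shift(mu, sigma, t):
      # Elementwise warp: t=1 stays 1, and larger mu pushes intermediate
      # timesteps toward 1, spending more of the schedule at high noise.
      return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)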

  validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)

  p_run_inference = jax.jit(
      functools.partial(
          run_inference,
          transformer=transformer,
          vae=vae,
          config=config,
          mesh=mesh,
          latents=latents,
          latent_image_ids=latent_image_ids,
          prompt_embeds=prompt_embeds,
          txt_ids=text_ids,
          vec=pooled_prompt_embeds,
          guidance_vec=guidance,
          c_ts=c_ts,
          p_ts=p_ts,
      ),
      in_shardings=(state_shardings,),
      out_shardings=None,
  )
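Because every tensor and model reference is baked in through `functools.partial`, the compiled function takes a single `states` pytree, and `in_shardings` is a one-tuple whose structure mirrors it. A self-contained analogue (hypothetical, for illustration):

  # Hypothetical analogue of jitting with explicit input shardings.
  import numpy as np
  import jax
  import jax.numpy as jnp
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  mesh = Mesh(np.array(jax.devices()), ("data",))
  repl = NamedSharding(mesh, P())
  f = jax.jit(lambda states: states["x"] * 2, in_shardings=({"x": repl},))
  out = f({"x": jnp.ones((4, 4))})  # the states dict matches the shardings pytree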
  t0 = time.perf_counter()
  with ExitStack() as stack:
    _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
    p_run_inference(states).block_until_ready()
  t1 = time.perf_counter()
  max_logging.log(f"Compile time: {t1 - t0:.1f}s.")

  t0 = time.perf_counter()
  with ExitStack() as stack, jax.profiler.trace("/tmp/trace/"):
    _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
    imgs = p_run_inference(states).block_until_ready()
  t1 = time.perf_counter()
  max_logging.log(f"Inference time: {t1 - t0:.1f}s.")

  t0 = time.perf_counter()
  with ExitStack() as stack:
    _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
    imgs = p_run_inference(states).block_until_ready()
  imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
  t1 = time.perf_counter()
  max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
  imgs = np.array(imgs)
  imgs = (imgs * 0.5 + 0.5).clip(0, 1)
  imgs = np.transpose(imgs, (0, 2, 3, 1))
  imgs = np.uint8(imgs * 255)
  for i, image in enumerate(imgs):
    Image.fromarray(image).save(f"flux_{i}.png")

  return imgs
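The VAE decodes to roughly [-1, 1] in NCHW layout, so the lines above rescale to [0, 1], reorder to NHWC, and quantize to uint8. A worked example of the rescaling:

  # Worked example: decoded values -1.0, 0.0, 1.0 map to pixels 0, 127, 255.
  import numpy as np
  x = np.array([-1.0, 0.0, 1.0])
  y = np.uint8((x * 0.5 + 0.5).clip(0, 1) * 255)  # -> array([  0, 127, 255], dtype=uint8)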


def main(argv: Sequence[str]) -> None: