diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index f2ae1b3f1..b248156ea 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -287,207 +287,207 @@ def run(config): global_batch_size = config.per_device_batch_size * jax.local_device_count() # LOAD VAE + with mesh: + vae, vae_params = FlaxAutoencoderKL.from_pretrained( + config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16" + ) - vae, vae_params = FlaxAutoencoderKL.from_pretrained( - config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16" - ) + weights_init_fn = functools.partial(vae.init_weights, rng=rng) + vae_state, vae_state_shardings = setup_initial_state( + model=vae, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + model_params=vae_params, + training=False, + ) - weights_init_fn = functools.partial(vae.init_weights, rng=rng) - vae_state, vae_state_shardings = setup_initial_state( - model=vae, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - model_params=vae_params, - training=False, - ) + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) - vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) - - # LOAD TRANSFORMER - flash_block_sizes = get_flash_block_sizes(config) - transformer = FluxTransformer2DModel.from_config( - config.pretrained_model_name_or_path, - subfolder="transformer", - mesh=mesh, - split_head_dim=config.split_head_dim, - attention_kernel=config.attention, - flash_block_sizes=flash_block_sizes, - dtype=config.activations_dtype, - weights_dtype=config.weights_dtype, - precision=get_precision(config), - ) + # LOAD TRANSFORMER + flash_block_sizes = get_flash_block_sizes(config) + transformer = FluxTransformer2DModel.from_config( + config.pretrained_model_name_or_path, + subfolder="transformer", + mesh=mesh, + split_head_dim=config.split_head_dim, + attention_kernel=config.attention, + flash_block_sizes=flash_block_sizes, + dtype=config.activations_dtype, + weights_dtype=config.weights_dtype, + precision=get_precision(config), + ) - num_channels_latents = transformer.in_channels // 4 - latents, latent_image_ids = prepare_latents( - batch_size=global_batch_size, - num_channels_latents=num_channels_latents, - height=config.resolution, - width=config.resolution, - dtype=jnp.bfloat16, - vae_scale_factor=vae_scale_factor, - rng=rng, - ) + num_channels_latents = transformer.in_channels // 4 + latents, latent_image_ids = prepare_latents( + batch_size=global_batch_size, + num_channels_latents=num_channels_latents, + height=config.resolution, + width=config.resolution, + dtype=jnp.bfloat16, + vae_scale_factor=vae_scale_factor, + rng=rng, + ) - # LOAD TEXT ENCODERS - clip_text_encoder = FlaxCLIPTextModel.from_pretrained( - config.pretrained_model_name_or_path, subfolder="text_encoder", from_pt=True, dtype=config.weights_dtype - ) - clip_tokenizer = CLIPTokenizer.from_pretrained( - config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype - ) + # LOAD TEXT ENCODERS + clip_text_encoder = FlaxCLIPTextModel.from_pretrained( + config.pretrained_model_name_or_path, subfolder="text_encoder", from_pt=True, dtype=config.weights_dtype + ) + clip_tokenizer = CLIPTokenizer.from_pretrained( + config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype + ) - t5_encoder = FlaxT5EncoderModel.from_pretrained(config.t5xxl_model_name_or_path, dtype=config.weights_dtype) - t5_tokenizer = AutoTokenizer.from_pretrained( - config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True - ) + t5_encoder = FlaxT5EncoderModel.from_pretrained(config.t5xxl_model_name_or_path, dtype=config.weights_dtype) + t5_tokenizer = AutoTokenizer.from_pretrained( + config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True + ) - encoders_sharding = NamedSharding(mesh, P()) - partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding) - clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params) - clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params) - t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params) - t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params) - - prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( - prompt=config.prompt, - prompt_2=config.prompt_2, - clip_tokenizer=clip_tokenizer, - clip_text_encoder=clip_text_encoder, - t5_tokenizer=t5_tokenizer, - t5_text_encoder=t5_encoder, - num_images_per_prompt=global_batch_size, - max_sequence_length=config.max_sequence_length, - ) + encoders_sharding = NamedSharding(mesh, P()) + partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding) + clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params) + clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params) + t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params) + t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params) + + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + prompt=config.prompt, + prompt_2=config.prompt_2, + clip_tokenizer=clip_tokenizer, + clip_text_encoder=clip_text_encoder, + t5_tokenizer=t5_tokenizer, + t5_text_encoder=t5_encoder, + num_images_per_prompt=global_batch_size, + max_sequence_length=config.max_sequence_length, + ) - def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds): - print("latents.shape: ", latents.shape, latents.dtype) - print("latent_image_ids.shape: ", latent_image_ids.shape, latent_image_ids.dtype) - print("text_ids.shape: ", text_ids.shape, text_ids.dtype) - print("prompt_embeds: ", prompt_embeds.shape, prompt_embeds.dtype) - print("timesteps.shape: ", timesteps.shape, timesteps.dtype) - print("guidance.shape: ", guidance.shape, guidance.dtype) - print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype) - - guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16) - - # move inputs to device and shard - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - latents = jax.device_put(latents, data_sharding) - latent_image_ids = jax.device_put(latent_image_ids) - prompt_embeds = jax.device_put(prompt_embeds, data_sharding) - text_ids = jax.device_put(text_ids) - guidance = jax.device_put(guidance, data_sharding) - pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding) - - if config.offload_encoders: - cpus = jax.devices("cpu") - t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0]) - - get_memory_allocations() - # evaluate shapes - transformer_eval_params = transformer.init_weights( - rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=True - ) + def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds): + print("latents.shape: ", latents.shape, latents.dtype) + print("latent_image_ids.shape: ", latent_image_ids.shape, latent_image_ids.dtype) + print("text_ids.shape: ", text_ids.shape, text_ids.dtype) + print("prompt_embeds: ", prompt_embeds.shape, prompt_embeds.dtype) + print("timesteps.shape: ", timesteps.shape, timesteps.dtype) + print("guidance.shape: ", guidance.shape, guidance.dtype) + print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype) + + guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16) + + # move inputs to device and shard + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(latents, data_sharding) + latent_image_ids = jax.device_put(latent_image_ids) + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + text_ids = jax.device_put(text_ids) + guidance = jax.device_put(guidance, data_sharding) + pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding) + + if config.offload_encoders: + cpus = jax.devices("cpu") + t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0]) + + get_memory_allocations() + # evaluate shapes + transformer_eval_params = transformer.init_weights( + rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=True + ) - # loads pretrained weights - transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu") - params = {} - params["transformer"] = transformer_params - # maybe load lora and create interceptor - lora_loader = FluxLoraLoaderMixin() - params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params) - transformer_params = params["transformer"] - # create transformer state - weights_init_fn = functools.partial( - transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False - ) - with ExitStack() as stack: - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - model_params=None, - training=False, + # loads pretrained weights + transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu") + params = {} + params["transformer"] = transformer_params + # maybe load lora and create interceptor + lora_loader = FluxLoraLoaderMixin() + params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params) + transformer_params = params["transformer"] + # create transformer state + weights_init_fn = functools.partial( + transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False ) - transformer_state = transformer_state.replace(params=transformer_params) - transformer_state = jax.device_put(transformer_state, transformer_state_shardings) - get_memory_allocations() - - states = {} - state_shardings = {} - - state_shardings["transformer"] = transformer_state_shardings - state_shardings["vae"] = vae_state_shardings - - states["transformer"] = transformer_state - states["vae"] = vae_state - - # Setup timesteps - timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) - # shifting the schedule to favor high timesteps for higher signal images - if config.time_shift: - # estimate mu based on linear estimation between two points - lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift) - mu = lin_function(latents.shape[1]) - timesteps = time_shift(mu, 1.0, timesteps) - c_ts = timesteps[:-1] - p_ts = timesteps[1:] - - validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) - - p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - vae=vae, + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, config=config, mesh=mesh, - latents=latents, - latent_image_ids=latent_image_ids, - prompt_embeds=prompt_embeds, - txt_ids=text_ids, - vec=pooled_prompt_embeds, - guidance_vec=guidance, - c_ts=c_ts, - p_ts=p_ts, - ), - in_shardings=(state_shardings,), - out_shardings=None, - ) - t0 = time.perf_counter() - with ExitStack() as stack: - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - p_run_inference(states).block_until_ready() - t1 = time.perf_counter() - max_logging.log(f"Compile time: {t1 - t0:.1f}s.") - - t0 = time.perf_counter() - with ExitStack() as stack, jax.profiler.trace("/tmp/trace/"): - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - imgs = p_run_inference(states).block_until_ready() - t1 = time.perf_counter() - max_logging.log(f"Inference time: {t1 - t0:.1f}s.") - - t0 = time.perf_counter() - with ExitStack() as stack: - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - imgs = p_run_inference(states).block_until_ready() - imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) - t1 = time.perf_counter() - max_logging.log(f"Inference time: {t1 - t0:.1f}s.") - imgs = np.array(imgs) - imgs = (imgs * 0.5 + 0.5).clip(0, 1) - imgs = np.transpose(imgs, (0, 2, 3, 1)) - imgs = np.uint8(imgs * 255) - for i, image in enumerate(imgs): - Image.fromarray(image).save(f"flux_{i}.png") - - return imgs + weights_init_fn=weights_init_fn, + model_params=None, + training=False, + ) + transformer_state = transformer_state.replace(params=transformer_params) + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + state_shardings["vae"] = vae_state_shardings + + states["transformer"] = transformer_state + states["vae"] = vae_state + + # Setup timesteps + timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) + # shifting the schedule to favor high timesteps for higher signal images + if config.time_shift: + # estimate mu based on linear estimation between two points + lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift) + mu = lin_function(latents.shape[1]) + timesteps = time_shift(mu, 1.0, timesteps) + c_ts = timesteps[:-1] + p_ts = timesteps[1:] + + validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) + + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + vae=vae, + config=config, + mesh=mesh, + latents=latents, + latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=text_ids, + vec=pooled_prompt_embeds, + guidance_vec=guidance, + c_ts=c_ts, + p_ts=p_ts, + ), + in_shardings=(state_shardings,), + out_shardings=None, + ) + t0 = time.perf_counter() + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + p_run_inference(states).block_until_ready() + t1 = time.perf_counter() + max_logging.log(f"Compile time: {t1 - t0:.1f}s.") + + t0 = time.perf_counter() + with ExitStack() as stack, jax.profiler.trace("/tmp/trace/"): + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + imgs = p_run_inference(states).block_until_ready() + t1 = time.perf_counter() + max_logging.log(f"Inference time: {t1 - t0:.1f}s.") + + t0 = time.perf_counter() + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + imgs = p_run_inference(states).block_until_ready() + imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) + t1 = time.perf_counter() + max_logging.log(f"Inference time: {t1 - t0:.1f}s.") + imgs = np.array(imgs) + imgs = (imgs * 0.5 + 0.5).clip(0, 1) + imgs = np.transpose(imgs, (0, 2, 3, 1)) + imgs = np.uint8(imgs * 255) + for i, image in enumerate(imgs): + Image.fromarray(image).save(f"flux_{i}.png") + + return imgs def main(argv: Sequence[str]) -> None: diff --git a/src/maxdiffusion/generate_sdxl.py b/src/maxdiffusion/generate_sdxl.py index 8ede965a7..9ad1022d9 100644 --- a/src/maxdiffusion/generate_sdxl.py +++ b/src/maxdiffusion/generate_sdxl.py @@ -217,103 +217,105 @@ def run_inference(states, pipeline, params, config, rng, mesh, batch_size): def run(config): checkpoint_loader = GenerateSDXL(config) - pipeline, params = checkpoint_loader.load_checkpoint() + mesh = checkpoint_loader.mesh + with mesh: + pipeline, params = checkpoint_loader.load_checkpoint() - noise_scheduler, noise_scheduler_state = create_scheduler(pipeline.scheduler.config, config) + noise_scheduler, noise_scheduler_state = create_scheduler(pipeline.scheduler.config, config) - weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng) - unboxed_abstract_state, _, _ = max_utils.get_abstract_state( - pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False - ) - - # load unet params from orbax checkpoint - unet_params = load_params_from_path( - config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "unet_state" - ) - if unet_params: - params["unet"] = unet_params - - # maybe load lora and create interceptor - params, lora_interceptors = maybe_load_sdxl_lora(config, pipeline, params) - - if config.lightning_repo: - pipeline, params = load_sdxllightning_unet(config, pipeline, params) - - # Don't restore the full train state, instead, just restore params - # and create an inference state. - with ExitStack() as stack: - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - unet_state, unet_state_shardings = max_utils.setup_initial_state( - model=pipeline.unet, - tx=None, - config=config, - mesh=checkpoint_loader.mesh, - weights_init_fn=weights_init_fn, - model_params=None, - training=False, + weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng) + unboxed_abstract_state, _, _ = max_utils.get_abstract_state( + pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False ) - unet_state = unet_state.replace(params=params.get("unet", None)) - unet_state = jax.device_put(unet_state, unet_state_shardings) - vae_state, vae_state_shardings = checkpoint_loader.create_vae_state( - pipeline, params, checkpoint_item_name="vae_state", is_training=False - ) - with ExitStack() as stack: - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state( - pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False + # load unet params from orbax checkpoint + unet_params = load_params_from_path( + config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "unet_state" ) - - text_encoder_2_state, text_encoder_2_state_shardings = checkpoint_loader.create_text_encoder_2_state( - pipeline, params, checkpoint_item_name="text_encoder_2_state", is_training=False - ) - states = {} - state_shardings = {} - - state_shardings["vae_state"] = vae_state_shardings - state_shardings["unet_state"] = unet_state_shardings - state_shardings["text_encoder_state"] = text_encoder_state_shardings - state_shardings["text_encoder_2_state"] = text_encoder_2_state_shardings - - states["unet_state"] = unet_state - states["vae_state"] = vae_state - states["text_encoder_state"] = text_encoder_state - states["text_encoder_2_state"] = text_encoder_2_state - - pipeline.scheduler = noise_scheduler - params["scheduler"] = noise_scheduler_state - - p_run_inference = jax.jit( - functools.partial( - run_inference, - pipeline=pipeline, - params=params, + if unet_params: + params["unet"] = unet_params + + # maybe load lora and create interceptor + params, lora_interceptors = maybe_load_sdxl_lora(config, pipeline, params) + + if config.lightning_repo: + pipeline, params = load_sdxllightning_unet(config, pipeline, params) + + # Don't restore the full train state, instead, just restore params + # and create an inference state. + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + unet_state, unet_state_shardings = max_utils.setup_initial_state( + model=pipeline.unet, + tx=None, config=config, - rng=checkpoint_loader.rng, mesh=checkpoint_loader.mesh, - batch_size=checkpoint_loader.total_train_batch_size, - ), - in_shardings=(state_shardings,), - out_shardings=None, - ) + weights_init_fn=weights_init_fn, + model_params=None, + training=False, + ) + unet_state = unet_state.replace(params=params.get("unet", None)) + unet_state = jax.device_put(unet_state, unet_state_shardings) + + vae_state, vae_state_shardings = checkpoint_loader.create_vae_state( + pipeline, params, checkpoint_item_name="vae_state", is_training=False + ) + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state( + pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False + ) + + text_encoder_2_state, text_encoder_2_state_shardings = checkpoint_loader.create_text_encoder_2_state( + pipeline, params, checkpoint_item_name="text_encoder_2_state", is_training=False + ) + states = {} + state_shardings = {} + + state_shardings["vae_state"] = vae_state_shardings + state_shardings["unet_state"] = unet_state_shardings + state_shardings["text_encoder_state"] = text_encoder_state_shardings + state_shardings["text_encoder_2_state"] = text_encoder_2_state_shardings + + states["unet_state"] = unet_state + states["vae_state"] = vae_state + states["text_encoder_state"] = text_encoder_state + states["text_encoder_2_state"] = text_encoder_2_state + + pipeline.scheduler = noise_scheduler + params["scheduler"] = noise_scheduler_state + + p_run_inference = jax.jit( + functools.partial( + run_inference, + pipeline=pipeline, + params=params, + config=config, + rng=checkpoint_loader.rng, + mesh=checkpoint_loader.mesh, + batch_size=checkpoint_loader.total_train_batch_size, + ), + in_shardings=(state_shardings,), + out_shardings=None, + ) - s = time.time() - with ExitStack() as stack: - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - p_run_inference(states).block_until_ready() - print("compile time: ", (time.time() - s)) - s = time.time() - with ExitStack() as stack: - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - images = p_run_inference(states).block_until_ready() - print("inference time: ", (time.time() - s)) - images = jax.experimental.multihost_utils.process_allgather(images, tiled=True) - numpy_images = np.array(images) - images = VaeImageProcessor.numpy_to_pil(numpy_images) - for i, image in enumerate(images): - image.save(f"image_sdxl_{i}.png") - - return images + s = time.time() + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + p_run_inference(states).block_until_ready() + print("compile time: ", (time.time() - s)) + s = time.time() + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + images = p_run_inference(states).block_until_ready() + print("inference time: ", (time.time() - s)) + images = jax.experimental.multihost_utils.process_allgather(images, tiled=True) + numpy_images = np.array(images) + images = VaeImageProcessor.numpy_to_pil(numpy_images) + for i, image in enumerate(images): + image.save(f"image_sdxl_{i}.png") + + return images def main(argv: Sequence[str]) -> None: diff --git a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py index 81fca7404..78d0e8b2e 100644 --- a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py +++ b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py @@ -19,6 +19,7 @@ from typing import Any, Callable import jax from maxdiffusion import (max_utils, maxdiffusion_utils, max_logging) +from flax.linen import partitioning as nn_partitioning from maxdiffusion.checkpointing.base_stable_diffusion_checkpointer import (BaseStableDiffusionCheckpointer) @@ -113,83 +114,84 @@ def calculate_tflops(self, pipeline, params): def start_training(self): # Hook - self.pre_training_steps() - # Load checkpoint - will load or create states - pipeline, params = self._time_and_log_call(self.load_checkpoint) - # create train states - train_states = {} - state_shardings = {} - vae_state, vae_state_mesh_shardings = self._time_and_log_call( - self.create_vae_state, - # Arguments for create_vae_state - pipeline=pipeline, - params=params, - checkpoint_item_name="vae_state", - is_training=False, - ) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + self.pre_training_steps() + # Load checkpoint - will load or create states + pipeline, params = self._time_and_log_call(self.load_checkpoint) + # create train states + train_states = {} + state_shardings = {} + vae_state, vae_state_mesh_shardings = self._time_and_log_call( + self.create_vae_state, + # Arguments for create_vae_state + pipeline=pipeline, + params=params, + checkpoint_item_name="vae_state", + is_training=False, + ) - train_states["vae_state"] = vae_state - state_shardings["vae_state_shardings"] = vae_state_mesh_shardings + train_states["vae_state"] = vae_state + state_shardings["vae_state_shardings"] = vae_state_mesh_shardings - text_encoder_state, text_encoder_state_mesh_shardings = self._time_and_log_call( - self.create_text_encoder_state, - # Arguments for create_text_encoder_state - pipeline=pipeline, - params=params, - checkpoint_item_name="text_encoder_state", - is_training=self.config.train_text_encoder, - ) - train_states["text_encoder_state"] = text_encoder_state - state_shardings["text_encoder_state_shardings"] = text_encoder_state_mesh_shardings - if hasattr(pipeline, "text_encoder_2"): - text_encoder_2_state, text_encoder_2_state_mesh_shardings = self._time_and_log_call( - self.create_text_encoder_2_state, - # Arguments for create_text_encoder_2_state + text_encoder_state, text_encoder_state_mesh_shardings = self._time_and_log_call( + self.create_text_encoder_state, + # Arguments for create_text_encoder_state pipeline=pipeline, params=params, - checkpoint_item_name="text_encoder_2_state", + checkpoint_item_name="text_encoder_state", is_training=self.config.train_text_encoder, ) - train_states["text_encoder_2_state"] = text_encoder_2_state - state_shardings["text_encoder_2_state_shardings"] = text_encoder_2_state_mesh_shardings - - # Create scheduler - noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params) - pipeline.scheduler = noise_scheduler - params["scheduler"] = noise_scheduler_state - - # Calculate tflops - per_device_tflops = self.calculate_tflops(pipeline, params) - self.per_device_tflops = per_device_tflops - - # Load dataset - data_iterator = self._time_and_log_call(self.load_dataset, pipeline, params, train_states) - if self.config.dataset_type == "grain": - data_iterator = self._time_and_log_call(self.restore_data_iterator_state, data_iterator=data_iterator) - - unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self._time_and_log_call( - self.create_unet_state, - # ambiguous here, but if self.params.get("unet") doesn't exist - # Then its 1 of 2 scenarios: - # 1. unet state will be loaded directly from orbax - # 2. a new unet is being trained from scratch. - pipeline=pipeline, - params=params, - checkpoint_item_name="unet_state", - is_training=True, - ) - train_states["unet_state"] = unet_state - state_shardings["unet_state_shardings"] = unet_state_mesh_shardings + train_states["text_encoder_state"] = text_encoder_state + state_shardings["text_encoder_state_shardings"] = text_encoder_state_mesh_shardings + if hasattr(pipeline, "text_encoder_2"): + text_encoder_2_state, text_encoder_2_state_mesh_shardings = self._time_and_log_call( + self.create_text_encoder_2_state, + # Arguments for create_text_encoder_2_state + pipeline=pipeline, + params=params, + checkpoint_item_name="text_encoder_2_state", + is_training=self.config.train_text_encoder, + ) + train_states["text_encoder_2_state"] = text_encoder_2_state + state_shardings["text_encoder_2_state_shardings"] = text_encoder_2_state_mesh_shardings + + # Create scheduler + noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params) + pipeline.scheduler = noise_scheduler + params["scheduler"] = noise_scheduler_state + + # Calculate tflops + per_device_tflops = self.calculate_tflops(pipeline, params) + self.per_device_tflops = per_device_tflops + + # Load dataset + data_iterator = self._time_and_log_call(self.load_dataset, pipeline, params, train_states) + if self.config.dataset_type == "grain": + data_iterator = self._time_and_log_call(self.restore_data_iterator_state, data_iterator=data_iterator) + + unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self._time_and_log_call( + self.create_unet_state, + # ambiguous here, but if self.params.get("unet") doesn't exist + # Then its 1 of 2 scenarios: + # 1. unet state will be loaded directly from orbax + # 2. a new unet is being trained from scratch. + pipeline=pipeline, + params=params, + checkpoint_item_name="unet_state", + is_training=True, + ) + train_states["unet_state"] = unet_state + state_shardings["unet_state_shardings"] = unet_state_mesh_shardings - data_shardings = self.get_data_shardings() - # Compile train_step - p_train_step = self._time_and_log_call( - self.compile_train_step, pipeline, params, train_states, state_shardings, data_shardings - ) - # Start training - train_states = self._time_and_log_call( - self.training_loop, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler - ) - # 6. save final checkpoint - # Hook - self._time_and_log_call(self.post_training_steps, pipeline, params, train_states) + data_shardings = self.get_data_shardings() + # Compile train_step + p_train_step = self._time_and_log_call( + self.compile_train_step, pipeline, params, train_states, state_shardings, data_shardings + ) + # Start training + train_states = self._time_and_log_call( + self.training_loop, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler + ) + # 6. save final checkpoint + # Hook + self._time_and_log_call(self.post_training_steps, pipeline, params, train_states)