Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_flux_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ enable_profiler: False
# the iteration time a chance to stabilize.
skip_first_n_steps_for_profiler: 5
profiler_steps: 10
profiler: ""

# Generation parameters
prompt: "A magical castle in the middle of a forest, artistic drawing"
Expand Down
159 changes: 80 additions & 79 deletions src/maxdiffusion/generate_flux_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,85 +33,86 @@ def run(config):
from maxdiffusion.checkpointing.flux_checkpointer import FluxCheckpointer

checkpoint_loader = FluxCheckpointer(config, "FLUX_CHECKPOINT")
pipeline, params = checkpoint_loader.load_checkpoint()

if not params:
## VAE
weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
pipeline.vae, None, config, checkpoint_loader.mesh, weights_init_fn, False
)
# load unet params from orbax checkpoint
vae_params = load_params_from_path(
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "vae_state"
)

vae_state = {"params": vae_params}

## Flux
weights_init_fn = functools.partial(
pipeline.flux.init_weights, rngs=checkpoint_loader.rng, max_sequence_length=config.max_sequence_length
)

unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
pipeline.flux, None, config, checkpoint_loader.mesh, weights_init_fn, False
)
# load unet params from orbax checkpoint
flux_params = load_params_from_path(
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "flux_state"
)
flux_state = {"params": flux_params}
else:
weights_init_fn = functools.partial(
pipeline.flux.init_weights,
rngs=checkpoint_loader.rng,
max_sequence_length=config.max_sequence_length,
eval_only=False,
)
transformer_state, flux_state_shardings = setup_initial_state(
model=pipeline.flux,
tx=None,
config=config,
mesh=checkpoint_loader.mesh,
weights_init_fn=weights_init_fn,
model_params=None,
training=False,
)
transformer_state = transformer_state.replace(params=params["flux_transformer_params"])
transformer_state = jax.device_put(transformer_state, flux_state_shardings)

weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
vae_state, _ = setup_initial_state(
model=pipeline.vae,
tx=None,
config=config,
mesh=checkpoint_loader.mesh,
weights_init_fn=weights_init_fn,
model_params=params["flux_vae"],
training=False,
)

vae_state = {"params": vae_state.params}
flux_state = {"params": transformer_state.params}

t0 = time.perf_counter()
with ExitStack():
imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
t1 = time.perf_counter()
max_logging.log(f"Compile time: {t1 - t0:.1f}s.")

t0 = time.perf_counter()
with ExitStack():
imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
t1 = time.perf_counter()
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
imgs = np.array(imgs)
imgs = (imgs * 0.5 + 0.5).clip(0, 1)
imgs = np.transpose(imgs, (0, 2, 3, 1))
imgs = np.uint8(imgs * 255)
for i, image in enumerate(imgs):
Image.fromarray(image).save(f"flux_{i}.png")
mesh = checkpoint_loader.mesh
with mesh:
pipeline, params = checkpoint_loader.load_checkpoint()

if not params:
## VAE
weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
pipeline.vae, None, config, checkpoint_loader.mesh, weights_init_fn, False
)
# load unet params from orbax checkpoint
vae_params = load_params_from_path(
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "vae_state"
)

vae_state = {"params": vae_params}

## Flux
weights_init_fn = functools.partial(
pipeline.flux.init_weights, rngs=checkpoint_loader.rng, max_sequence_length=config.max_sequence_length
)
unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
pipeline.flux, None, config, checkpoint_loader.mesh, weights_init_fn, False
)
# load unet params from orbax checkpoint
flux_params = load_params_from_path(
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "flux_state"
)
flux_state = {"params": flux_params}
else:
weights_init_fn = functools.partial(
pipeline.flux.init_weights,
rngs=checkpoint_loader.rng,
max_sequence_length=config.max_sequence_length,
eval_only=False,
)
transformer_state, flux_state_shardings = setup_initial_state(
model=pipeline.flux,
tx=None,
config=config,
mesh=checkpoint_loader.mesh,
weights_init_fn=weights_init_fn,
model_params=None,
training=False,
)
transformer_state = transformer_state.replace(params=params["flux_transformer_params"])
transformer_state = jax.device_put(transformer_state, flux_state_shardings)

weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
vae_state, _ = setup_initial_state(
model=pipeline.vae,
tx=None,
config=config,
mesh=checkpoint_loader.mesh,
weights_init_fn=weights_init_fn,
model_params=params["flux_vae"],
training=False,
)

vae_state = {"params": vae_state.params}
flux_state = {"params": transformer_state.params}

t0 = time.perf_counter()
with ExitStack():
imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
t1 = time.perf_counter()
max_logging.log(f"Compile time: {t1 - t0:.1f}s.")

t0 = time.perf_counter()
with ExitStack():
imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
t1 = time.perf_counter()
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
imgs = np.array(imgs)
imgs = (imgs * 0.5 + 0.5).clip(0, 1)
imgs = np.transpose(imgs, (0, 2, 3, 1))
imgs = np.uint8(imgs * 255)
for i, image in enumerate(imgs):
Image.fromarray(image).save(f"flux_{i}.png")

return imgs

Expand Down
5 changes: 1 addition & 4 deletions src/maxdiffusion/train_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@

import jax
from absl import app
from maxdiffusion import (
max_logging,
pyconfig,
)
from maxdiffusion import (max_logging, pyconfig)

from maxdiffusion.train_utils import (
validate_train_config,
Expand Down
131 changes: 66 additions & 65 deletions src/maxdiffusion/trainers/flux_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,70 +80,71 @@ def start_training(self):
# Hook
# self.pre_training_steps()
# Load checkpoint - will load or create states
pipeline, params = self.load_checkpoint()

# create train states
train_states = {}
state_shardings = {}

# move params to accelerator
encoders_sharding = NamedSharding(self.mesh, P(None))
partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding)
pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params)
pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params)
pipeline.t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.t5_encoder.params)
pipeline.t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.t5_encoder.params)

vae_state, vae_state_mesh_shardings = self.create_vae_state(
pipeline=pipeline, params=params, checkpoint_item_name=VAE_STATE_KEY, is_training=False
)
train_states[VAE_STATE_KEY] = vae_state
state_shardings[VAE_STATE_SHARDINGS_KEY] = vae_state_mesh_shardings

# Load dataset
data_iterator = self.load_dataset(pipeline, params, train_states)
if self.config.dataset_type == "grain":
data_iterator = self.restore_data_iterator_state(data_iterator)

# don't need this anymore, clear some memory.
del pipeline.t5_encoder

# evaluate shapes

flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state(
# ambiguous here, but if params=None
# Then its 1 of 2 scenarios:
# 1. flux state will be loaded directly from orbax
# 2. a new flux is being trained from scratch.
pipeline=pipeline,
params=None, # Params are loaded inside create_flux_state
checkpoint_item_name=FLUX_STATE_KEY,
is_training=True,
)
flux_state = jax.device_put(flux_state, flux_state_mesh_shardings)
train_states[FLUX_STATE_KEY] = flux_state
state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings
# self.post_training_steps(pipeline, params, train_states, msg="before_training")

# Create scheduler
noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params)
pipeline.scheduler = noise_scheduler
train_states["scheduler"] = noise_scheduler_state

# Calculate tflops
per_device_tflops = self.calculate_tflops(pipeline)
self.per_device_tflops = per_device_tflops

data_shardings = self.get_data_shardings()
# Compile train_step
p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)
# Start training
train_states = self.training_loop(
p_train_step, pipeline, params, train_states, data_iterator, flux_learning_rate_scheduler
)
# 6. save final checkpoint
# Hook
self.post_training_steps(pipeline, params, train_states, "after_training")
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
pipeline, params = self.load_checkpoint()

# create train states
train_states = {}
state_shardings = {}

# move params to accelerator
encoders_sharding = NamedSharding(self.mesh, P(None))
partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding)
pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params)
pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params)
pipeline.t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.t5_encoder.params)
pipeline.t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.t5_encoder.params)

vae_state, vae_state_mesh_shardings = self.create_vae_state(
pipeline=pipeline, params=params, checkpoint_item_name=VAE_STATE_KEY, is_training=False
)
train_states[VAE_STATE_KEY] = vae_state
state_shardings[VAE_STATE_SHARDINGS_KEY] = vae_state_mesh_shardings

# Load dataset
data_iterator = self.load_dataset(pipeline, params, train_states)
if self.config.dataset_type == "grain":
data_iterator = self.restore_data_iterator_state(data_iterator)

# don't need this anymore, clear some memory.
del pipeline.t5_encoder

# evaluate shapes

flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state(
# ambiguous here, but if params=None
# Then its 1 of 2 scenarios:
# 1. flux state will be loaded directly from orbax
# 2. a new flux is being trained from scratch.
pipeline=pipeline,
params=None, # Params are loaded inside create_flux_state
checkpoint_item_name=FLUX_STATE_KEY,
is_training=True,
)
flux_state = jax.device_put(flux_state, flux_state_mesh_shardings)
train_states[FLUX_STATE_KEY] = flux_state
state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings
# self.post_training_steps(pipeline, params, train_states, msg="before_training")

# Create scheduler
noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params)
pipeline.scheduler = noise_scheduler
train_states["scheduler"] = noise_scheduler_state

# Calculate tflops
per_device_tflops = self.calculate_tflops(pipeline)
self.per_device_tflops = per_device_tflops

data_shardings = self.get_data_shardings()
# Compile train_step
p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)
# Start training
train_states = self.training_loop(
p_train_step, pipeline, params, train_states, data_iterator, flux_learning_rate_scheduler
)
# 6. save final checkpoint
# Hook
self.post_training_steps(pipeline, params, train_states, "after_training")

def get_shaped_batch(self, config, pipeline=None):
"""Return the shape of the batch - this is what eval_shape would return for the
Expand Down Expand Up @@ -349,7 +350,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
example_batch = load_next_batch(data_iterator, example_batch, self.config)
example_batch = {key: jnp.asarray(value, dtype=self.config.activations_dtype) for key, value in example_batch.items()}

if self.config.profiler == 'nsys':
if self.config.profiler == "nsys":
with self.mesh:
flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs)
else:
Expand Down
Loading