1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_flux_dev.yml
@@ -228,6 +228,7 @@ enable_profiler: False
# the iteration time a chance to stabilize.
skip_first_n_steps_for_profiler: 5
profiler_steps: 10
profiler: ""

# Generation parameters
prompt: "A magical castle in the middle of a forest, artistic drawing"
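The new `profiler` key defaults to an empty string and is compared against "nsys" in the flux_trainer.py hunk further down. A minimal sketch of how a string-valued flag like this could be read, assuming plain PyYAML loading for illustration rather than maxdiffusion's own pyconfig machinery:

import yaml  # assumes PyYAML is available; the real loader is maxdiffusion.pyconfig

# Hypothetical standalone read of the config file; path and key mirror the diff above.
with open("src/maxdiffusion/configs/base_flux_dev.yml") as f:
  cfg = yaml.safe_load(f)

profiler = cfg.get("profiler", "")  # "" means no profiler configured
if profiler == "nsys":
  # Mirrors the check added in flux_trainer.py: the train step runs normally
  # and an externally attached Nsight Systems session captures the trace.
  print("nsys branch selected")
else:
  print("profiling disabled")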
159 changes: 80 additions & 79 deletions src/maxdiffusion/generate_flux_pipeline.py
@@ -33,85 +33,86 @@ def run(config):
from maxdiffusion.checkpointing.flux_checkpointer import FluxCheckpointer

checkpoint_loader = FluxCheckpointer(config, "FLUX_CHECKPOINT")
pipeline, params = checkpoint_loader.load_checkpoint()

if not params:
## VAE
weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
pipeline.vae, None, config, checkpoint_loader.mesh, weights_init_fn, False
)
# load unet params from orbax checkpoint
vae_params = load_params_from_path(
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "vae_state"
)

vae_state = {"params": vae_params}

## Flux
weights_init_fn = functools.partial(
pipeline.flux.init_weights, rngs=checkpoint_loader.rng, max_sequence_length=config.max_sequence_length
)

unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
pipeline.flux, None, config, checkpoint_loader.mesh, weights_init_fn, False
)
# load unet params from orbax checkpoint
flux_params = load_params_from_path(
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "flux_state"
)
flux_state = {"params": flux_params}
else:
weights_init_fn = functools.partial(
pipeline.flux.init_weights,
rngs=checkpoint_loader.rng,
max_sequence_length=config.max_sequence_length,
eval_only=False,
)
transformer_state, flux_state_shardings = setup_initial_state(
model=pipeline.flux,
tx=None,
config=config,
mesh=checkpoint_loader.mesh,
weights_init_fn=weights_init_fn,
model_params=None,
training=False,
)
transformer_state = transformer_state.replace(params=params["flux_transformer_params"])
transformer_state = jax.device_put(transformer_state, flux_state_shardings)

weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
vae_state, _ = setup_initial_state(
model=pipeline.vae,
tx=None,
config=config,
mesh=checkpoint_loader.mesh,
weights_init_fn=weights_init_fn,
model_params=params["flux_vae"],
training=False,
)

vae_state = {"params": vae_state.params}
flux_state = {"params": transformer_state.params}

t0 = time.perf_counter()
with ExitStack():
imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
t1 = time.perf_counter()
max_logging.log(f"Compile time: {t1 - t0:.1f}s.")

t0 = time.perf_counter()
with ExitStack():
imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
t1 = time.perf_counter()
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
imgs = np.array(imgs)
imgs = (imgs * 0.5 + 0.5).clip(0, 1)
imgs = np.transpose(imgs, (0, 2, 3, 1))
imgs = np.uint8(imgs * 255)
for i, image in enumerate(imgs):
Image.fromarray(image).save(f"flux_{i}.png")
mesh = checkpoint_loader.mesh
with mesh:
pipeline, params = checkpoint_loader.load_checkpoint()

if not params:
## VAE
weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
pipeline.vae, None, config, checkpoint_loader.mesh, weights_init_fn, False
)
# load unet params from orbax checkpoint
vae_params = load_params_from_path(
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "vae_state"
)

vae_state = {"params": vae_params}

## Flux
weights_init_fn = functools.partial(
pipeline.flux.init_weights, rngs=checkpoint_loader.rng, max_sequence_length=config.max_sequence_length
)
unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
pipeline.flux, None, config, checkpoint_loader.mesh, weights_init_fn, False
)
# load unet params from orbax checkpoint
flux_params = load_params_from_path(
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "flux_state"
)
flux_state = {"params": flux_params}
else:
weights_init_fn = functools.partial(
pipeline.flux.init_weights,
rngs=checkpoint_loader.rng,
max_sequence_length=config.max_sequence_length,
eval_only=False,
)
transformer_state, flux_state_shardings = setup_initial_state(
model=pipeline.flux,
tx=None,
config=config,
mesh=checkpoint_loader.mesh,
weights_init_fn=weights_init_fn,
model_params=None,
training=False,
)
transformer_state = transformer_state.replace(params=params["flux_transformer_params"])
transformer_state = jax.device_put(transformer_state, flux_state_shardings)

weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
vae_state, _ = setup_initial_state(
model=pipeline.vae,
tx=None,
config=config,
mesh=checkpoint_loader.mesh,
weights_init_fn=weights_init_fn,
model_params=params["flux_vae"],
training=False,
)

vae_state = {"params": vae_state.params}
flux_state = {"params": transformer_state.params}

t0 = time.perf_counter()
with ExitStack():
imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
t1 = time.perf_counter()
max_logging.log(f"Compile time: {t1 - t0:.1f}s.")

t0 = time.perf_counter()
with ExitStack():
imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
t1 = time.perf_counter()
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
imgs = np.array(imgs)
imgs = (imgs * 0.5 + 0.5).clip(0, 1)
imgs = np.transpose(imgs, (0, 2, 3, 1))
imgs = np.uint8(imgs * 255)
for i, image in enumerate(imgs):
Image.fromarray(image).save(f"flux_{i}.png")

return imgs

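The rewrite above moves load_checkpoint() and all state construction under `with mesh:`, so everything that places or shards parameters runs while the checkpointer's mesh is active. A small, self-contained sketch of that pattern, assuming a single "data" mesh axis and a toy parameter tree rather than maxdiffusion's real layout:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh; the real code derives its mesh from the config.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def load_params():
  # Stand-in for checkpoint_loader.load_checkpoint(); returns a toy param tree.
  rows = 4 * jax.device_count()  # keep the leading dim divisible by the mesh size
  return {"kernel": jnp.ones((rows, 8)), "bias": jnp.zeros((8,))}

with mesh:
  params = load_params()
  # Code running inside the context sees this mesh as the ambient one, which is
  # why the PR wraps load_checkpoint() and setup_initial_state() in it; the
  # explicit device_put below just demonstrates sharded placement on that mesh.
  kernel = jax.device_put(params["kernel"], NamedSharding(mesh, P("data", None)))
  print(kernel.sharding)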
7 changes: 1 addition & 6 deletions src/maxdiffusion/train_flux.py
@@ -18,11 +18,7 @@

import jax
from absl import app
from maxdiffusion import (
max_logging,
pyconfig,
mllog_utils,
)
from maxdiffusion import (max_logging, pyconfig)

from maxdiffusion.train_utils import (
validate_train_config,
@@ -39,7 +35,6 @@ def train(config):
def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
config = pyconfig.config
mllog_utils.train_init_start(config)
validate_train_config(config)
max_logging.log(f"Found {jax.device_count()} devices.")
train(config)
131 changes: 66 additions & 65 deletions src/maxdiffusion/trainers/flux_trainer.py
@@ -80,70 +80,71 @@ def start_training(self):
# Hook
# self.pre_training_steps()
# Load checkpoint - will load or create states
pipeline, params = self.load_checkpoint()

# create train states
train_states = {}
state_shardings = {}

# move params to accelerator
encoders_sharding = NamedSharding(self.mesh, P(None))
partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding)
pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params)
pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params)
pipeline.t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.t5_encoder.params)
pipeline.t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.t5_encoder.params)

vae_state, vae_state_mesh_shardings = self.create_vae_state(
pipeline=pipeline, params=params, checkpoint_item_name=VAE_STATE_KEY, is_training=False
)
train_states[VAE_STATE_KEY] = vae_state
state_shardings[VAE_STATE_SHARDINGS_KEY] = vae_state_mesh_shardings

# Load dataset
data_iterator = self.load_dataset(pipeline, params, train_states)
if self.config.dataset_type == "grain":
data_iterator = self.restore_data_iterator_state(data_iterator)

# don't need this anymore, clear some memory.
del pipeline.t5_encoder

# evaluate shapes

flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state(
# ambiguous here, but if params=None
# Then its 1 of 2 scenarios:
# 1. flux state will be loaded directly from orbax
# 2. a new flux is being trained from scratch.
pipeline=pipeline,
params=None, # Params are loaded inside create_flux_state
checkpoint_item_name=FLUX_STATE_KEY,
is_training=True,
)
flux_state = jax.device_put(flux_state, flux_state_mesh_shardings)
train_states[FLUX_STATE_KEY] = flux_state
state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings
# self.post_training_steps(pipeline, params, train_states, msg="before_training")

# Create scheduler
noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params)
pipeline.scheduler = noise_scheduler
train_states["scheduler"] = noise_scheduler_state

# Calculate tflops
per_device_tflops = self.calculate_tflops(pipeline)
self.per_device_tflops = per_device_tflops

data_shardings = self.get_data_shardings()
# Compile train_step
p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)
# Start training
train_states = self.training_loop(
p_train_step, pipeline, params, train_states, data_iterator, flux_learning_rate_scheduler
)
# 6. save final checkpoint
# Hook
self.post_training_steps(pipeline, params, train_states, "after_training")
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
pipeline, params = self.load_checkpoint()

# create train states
train_states = {}
state_shardings = {}

# move params to accelerator
encoders_sharding = NamedSharding(self.mesh, P(None))
partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding)
pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params)
pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params)
pipeline.t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.t5_encoder.params)
pipeline.t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.t5_encoder.params)

vae_state, vae_state_mesh_shardings = self.create_vae_state(
pipeline=pipeline, params=params, checkpoint_item_name=VAE_STATE_KEY, is_training=False
)
train_states[VAE_STATE_KEY] = vae_state
state_shardings[VAE_STATE_SHARDINGS_KEY] = vae_state_mesh_shardings

# Load dataset
data_iterator = self.load_dataset(pipeline, params, train_states)
if self.config.dataset_type == "grain":
data_iterator = self.restore_data_iterator_state(data_iterator)

# don't need this anymore, clear some memory.
del pipeline.t5_encoder

# evaluate shapes

flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state(
# ambiguous here, but if params=None
# Then its 1 of 2 scenarios:
# 1. flux state will be loaded directly from orbax
# 2. a new flux is being trained from scratch.
pipeline=pipeline,
params=None, # Params are loaded inside create_flux_state
checkpoint_item_name=FLUX_STATE_KEY,
is_training=True,
)
flux_state = jax.device_put(flux_state, flux_state_mesh_shardings)
train_states[FLUX_STATE_KEY] = flux_state
state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings
# self.post_training_steps(pipeline, params, train_states, msg="before_training")

# Create scheduler
noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params)
pipeline.scheduler = noise_scheduler
train_states["scheduler"] = noise_scheduler_state

# Calculate tflops
per_device_tflops = self.calculate_tflops(pipeline)
self.per_device_tflops = per_device_tflops

data_shardings = self.get_data_shardings()
# Compile train_step
p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)
# Start training
train_states = self.training_loop(
p_train_step, pipeline, params, train_states, data_iterator, flux_learning_rate_scheduler
)
# 6. save final checkpoint
# Hook
self.post_training_steps(pipeline, params, train_states, "after_training")

def get_shaped_batch(self, config, pipeline=None):
"""Return the shape of the batch - this is what eval_shape would return for the
@@ -349,7 +350,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
example_batch = load_next_batch(data_iterator, example_batch, self.config)
example_batch = {key: jnp.asarray(value, dtype=self.config.activations_dtype) for key, value in example_batch.items()}

if self.config.profiler == 'nsys':
if self.config.profiler == "nsys":
with self.mesh:
flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs)
else:
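In flux_trainer.py the whole setup-and-train path now runs under `with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):`, so Flax logical axis names resolve to physical mesh axes while the train states are built. A small sketch of that pairing with made-up axis names and rules, since the real mapping comes from config.logical_axis_rules:

import jax
import numpy as np
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Illustrative logical-to-physical rules; maxdiffusion reads these from config.
logical_axis_rules = (("batch", "data"), ("embed", None))

with mesh, nn_partitioning.axis_rules(logical_axis_rules):
  # Inside this block, logical annotations such as
  # nn.with_logical_partitioning(init_fn, ("batch", "embed")) resolve against
  # the active rules and mesh, which is what lets load_checkpoint() and the
  # create_*_state() helpers produce correctly sharded train states.
  spec = nn_partitioning.logical_to_mesh_axes(("batch", "embed"))
  print(spec)  # expected: PartitionSpec('data', None) under these toy rules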