From 48dfc10820fe123b87c2a96cb9e90393ac29863a Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 18 Jul 2025 22:07:41 +0000 Subject: [PATCH] fixes flux training. --- src/maxdiffusion/configs/base_flux_dev.yml | 1 + src/maxdiffusion/generate_flux_pipeline.py | 159 +++++++++++---------- src/maxdiffusion/train_flux.py | 7 +- src/maxdiffusion/trainers/flux_trainer.py | 131 ++++++++--------- 4 files changed, 148 insertions(+), 150 deletions(-) diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 220a5bb2c..d6d003391 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -228,6 +228,7 @@ enable_profiler: False # the iteration time a chance to stabilize. skip_first_n_steps_for_profiler: 5 profiler_steps: 10 +profiler: "" # Generation parameters prompt: "A magical castle in the middle of a forest, artistic drawing" diff --git a/src/maxdiffusion/generate_flux_pipeline.py b/src/maxdiffusion/generate_flux_pipeline.py index 6ee469728..e6b8d4e2c 100644 --- a/src/maxdiffusion/generate_flux_pipeline.py +++ b/src/maxdiffusion/generate_flux_pipeline.py @@ -33,85 +33,86 @@ def run(config): from maxdiffusion.checkpointing.flux_checkpointer import FluxCheckpointer checkpoint_loader = FluxCheckpointer(config, "FLUX_CHECKPOINT") - pipeline, params = checkpoint_loader.load_checkpoint() - - if not params: - ## VAE - weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng) - unboxed_abstract_state, _, _ = max_utils.get_abstract_state( - pipeline.vae, None, config, checkpoint_loader.mesh, weights_init_fn, False - ) - # load unet params from orbax checkpoint - vae_params = load_params_from_path( - config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "vae_state" - ) - - vae_state = {"params": vae_params} - - ## Flux - weights_init_fn = functools.partial( - pipeline.flux.init_weights, rngs=checkpoint_loader.rng, max_sequence_length=config.max_sequence_length - ) - - unboxed_abstract_state, _, _ = max_utils.get_abstract_state( - pipeline.flux, None, config, checkpoint_loader.mesh, weights_init_fn, False - ) - # load unet params from orbax checkpoint - flux_params = load_params_from_path( - config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "flux_state" - ) - flux_state = {"params": flux_params} - else: - weights_init_fn = functools.partial( - pipeline.flux.init_weights, - rngs=checkpoint_loader.rng, - max_sequence_length=config.max_sequence_length, - eval_only=False, - ) - transformer_state, flux_state_shardings = setup_initial_state( - model=pipeline.flux, - tx=None, - config=config, - mesh=checkpoint_loader.mesh, - weights_init_fn=weights_init_fn, - model_params=None, - training=False, - ) - transformer_state = transformer_state.replace(params=params["flux_transformer_params"]) - transformer_state = jax.device_put(transformer_state, flux_state_shardings) - - weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng) - vae_state, _ = setup_initial_state( - model=pipeline.vae, - tx=None, - config=config, - mesh=checkpoint_loader.mesh, - weights_init_fn=weights_init_fn, - model_params=params["flux_vae"], - training=False, - ) - - vae_state = {"params": vae_state.params} - flux_state = {"params": transformer_state.params} - - t0 = time.perf_counter() - with ExitStack(): - imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready() - t1 = time.perf_counter() - max_logging.log(f"Compile time: {t1 - t0:.1f}s.") - - t0 = time.perf_counter() - with ExitStack(): - imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready() - imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) - t1 = time.perf_counter() - max_logging.log(f"Inference time: {t1 - t0:.1f}s.") - imgs = np.array(imgs) - imgs = (imgs * 0.5 + 0.5).clip(0, 1) - imgs = np.transpose(imgs, (0, 2, 3, 1)) - imgs = np.uint8(imgs * 255) - for i, image in enumerate(imgs): - Image.fromarray(image).save(f"flux_{i}.png") + mesh = checkpoint_loader.mesh + with mesh: + pipeline, params = checkpoint_loader.load_checkpoint() + + if not params: + ## VAE + weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng) + unboxed_abstract_state, _, _ = max_utils.get_abstract_state( + pipeline.vae, None, config, checkpoint_loader.mesh, weights_init_fn, False + ) + # load unet params from orbax checkpoint + vae_params = load_params_from_path( + config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "vae_state" + ) + + vae_state = {"params": vae_params} + + ## Flux + weights_init_fn = functools.partial( + pipeline.flux.init_weights, rngs=checkpoint_loader.rng, max_sequence_length=config.max_sequence_length + ) + unboxed_abstract_state, _, _ = max_utils.get_abstract_state( + pipeline.flux, None, config, checkpoint_loader.mesh, weights_init_fn, False + ) + # load unet params from orbax checkpoint + flux_params = load_params_from_path( + config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "flux_state" + ) + flux_state = {"params": flux_params} + else: + weights_init_fn = functools.partial( + pipeline.flux.init_weights, + rngs=checkpoint_loader.rng, + max_sequence_length=config.max_sequence_length, + eval_only=False, + ) + transformer_state, flux_state_shardings = setup_initial_state( + model=pipeline.flux, + tx=None, + config=config, + mesh=checkpoint_loader.mesh, + weights_init_fn=weights_init_fn, + model_params=None, + training=False, + ) + transformer_state = transformer_state.replace(params=params["flux_transformer_params"]) + transformer_state = jax.device_put(transformer_state, flux_state_shardings) + + weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng) + vae_state, _ = setup_initial_state( + model=pipeline.vae, + tx=None, + config=config, + mesh=checkpoint_loader.mesh, + weights_init_fn=weights_init_fn, + model_params=params["flux_vae"], + training=False, + ) + + vae_state = {"params": vae_state.params} + flux_state = {"params": transformer_state.params} + + t0 = time.perf_counter() + with ExitStack(): + imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready() + t1 = time.perf_counter() + max_logging.log(f"Compile time: {t1 - t0:.1f}s.") + + t0 = time.perf_counter() + with ExitStack(): + imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready() + imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) + t1 = time.perf_counter() + max_logging.log(f"Inference time: {t1 - t0:.1f}s.") + imgs = np.array(imgs) + imgs = (imgs * 0.5 + 0.5).clip(0, 1) + imgs = np.transpose(imgs, (0, 2, 3, 1)) + imgs = np.uint8(imgs * 255) + for i, image in enumerate(imgs): + Image.fromarray(image).save(f"flux_{i}.png") return imgs diff --git a/src/maxdiffusion/train_flux.py b/src/maxdiffusion/train_flux.py index e3b161039..6f7e940f2 100644 --- a/src/maxdiffusion/train_flux.py +++ b/src/maxdiffusion/train_flux.py @@ -18,11 +18,7 @@ import jax from absl import app -from maxdiffusion import ( - max_logging, - pyconfig, - mllog_utils, -) +from maxdiffusion import (max_logging, pyconfig) from maxdiffusion.train_utils import ( validate_train_config, @@ -39,7 +35,6 @@ def train(config): def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) config = pyconfig.config - mllog_utils.train_init_start(config) validate_train_config(config) max_logging.log(f"Found {jax.device_count()} devices.") train(config) diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py index f3ca99cb2..74b4f259e 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -80,70 +80,71 @@ def start_training(self): # Hook # self.pre_training_steps() # Load checkpoint - will load or create states - pipeline, params = self.load_checkpoint() - - # create train states - train_states = {} - state_shardings = {} - - # move params to accelerator - encoders_sharding = NamedSharding(self.mesh, P(None)) - partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding) - pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params) - pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params) - pipeline.t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.t5_encoder.params) - pipeline.t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.t5_encoder.params) - - vae_state, vae_state_mesh_shardings = self.create_vae_state( - pipeline=pipeline, params=params, checkpoint_item_name=VAE_STATE_KEY, is_training=False - ) - train_states[VAE_STATE_KEY] = vae_state - state_shardings[VAE_STATE_SHARDINGS_KEY] = vae_state_mesh_shardings - - # Load dataset - data_iterator = self.load_dataset(pipeline, params, train_states) - if self.config.dataset_type == "grain": - data_iterator = self.restore_data_iterator_state(data_iterator) - - # don't need this anymore, clear some memory. - del pipeline.t5_encoder - - # evaluate shapes - - flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state( - # ambiguous here, but if params=None - # Then its 1 of 2 scenarios: - # 1. flux state will be loaded directly from orbax - # 2. a new flux is being trained from scratch. - pipeline=pipeline, - params=None, # Params are loaded inside create_flux_state - checkpoint_item_name=FLUX_STATE_KEY, - is_training=True, - ) - flux_state = jax.device_put(flux_state, flux_state_mesh_shardings) - train_states[FLUX_STATE_KEY] = flux_state - state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings - # self.post_training_steps(pipeline, params, train_states, msg="before_training") - - # Create scheduler - noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params) - pipeline.scheduler = noise_scheduler - train_states["scheduler"] = noise_scheduler_state - - # Calculate tflops - per_device_tflops = self.calculate_tflops(pipeline) - self.per_device_tflops = per_device_tflops - - data_shardings = self.get_data_shardings() - # Compile train_step - p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings) - # Start training - train_states = self.training_loop( - p_train_step, pipeline, params, train_states, data_iterator, flux_learning_rate_scheduler - ) - # 6. save final checkpoint - # Hook - self.post_training_steps(pipeline, params, train_states, "after_training") + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + pipeline, params = self.load_checkpoint() + + # create train states + train_states = {} + state_shardings = {} + + # move params to accelerator + encoders_sharding = NamedSharding(self.mesh, P(None)) + partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding) + pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params) + pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params) + pipeline.t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.t5_encoder.params) + pipeline.t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.t5_encoder.params) + + vae_state, vae_state_mesh_shardings = self.create_vae_state( + pipeline=pipeline, params=params, checkpoint_item_name=VAE_STATE_KEY, is_training=False + ) + train_states[VAE_STATE_KEY] = vae_state + state_shardings[VAE_STATE_SHARDINGS_KEY] = vae_state_mesh_shardings + + # Load dataset + data_iterator = self.load_dataset(pipeline, params, train_states) + if self.config.dataset_type == "grain": + data_iterator = self.restore_data_iterator_state(data_iterator) + + # don't need this anymore, clear some memory. + del pipeline.t5_encoder + + # evaluate shapes + + flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state( + # ambiguous here, but if params=None + # Then its 1 of 2 scenarios: + # 1. flux state will be loaded directly from orbax + # 2. a new flux is being trained from scratch. + pipeline=pipeline, + params=None, # Params are loaded inside create_flux_state + checkpoint_item_name=FLUX_STATE_KEY, + is_training=True, + ) + flux_state = jax.device_put(flux_state, flux_state_mesh_shardings) + train_states[FLUX_STATE_KEY] = flux_state + state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings + # self.post_training_steps(pipeline, params, train_states, msg="before_training") + + # Create scheduler + noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params) + pipeline.scheduler = noise_scheduler + train_states["scheduler"] = noise_scheduler_state + + # Calculate tflops + per_device_tflops = self.calculate_tflops(pipeline) + self.per_device_tflops = per_device_tflops + + data_shardings = self.get_data_shardings() + # Compile train_step + p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings) + # Start training + train_states = self.training_loop( + p_train_step, pipeline, params, train_states, data_iterator, flux_learning_rate_scheduler + ) + # 6. save final checkpoint + # Hook + self.post_training_steps(pipeline, params, train_states, "after_training") def get_shaped_batch(self, config, pipeline=None): """Return the shape of the batch - this is what eval_shape would return for the @@ -349,7 +350,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera example_batch = load_next_batch(data_iterator, example_batch, self.config) example_batch = {key: jnp.asarray(value, dtype=self.config.activations_dtype) for key, value in example_batch.items()} - if self.config.profiler == 'nsys': + if self.config.profiler == "nsys": with self.mesh: flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs) else: