diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index e4b544f53..300ec0395 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -236,6 +236,7 @@ enable_profiler: False # the iteration time a chance to stabilize. skip_first_n_steps_for_profiler: 5 profiler_steps: 10 +profiler: "" # Generation parameters prompt: "A magical castle in the middle of a forest, artistic drawing" @@ -284,3 +285,5 @@ quantization: '' quantization_local_shard_count: -1 use_qwix_quantization: False compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. + +save_final_checkpoint: False \ No newline at end of file diff --git a/src/maxdiffusion/train_flux.py b/src/maxdiffusion/train_flux.py index 6f7e940f2..fef182062 100644 --- a/src/maxdiffusion/train_flux.py +++ b/src/maxdiffusion/train_flux.py @@ -22,6 +22,7 @@ from maxdiffusion.train_utils import ( validate_train_config, + transformer_engine_context, ) @@ -39,6 +40,6 @@ def main(argv: Sequence[str]) -> None: max_logging.log(f"Found {jax.device_count()} devices.") train(config) - if __name__ == "__main__": - app.run(main) + with transformer_engine_context(): + app.run(main) diff --git a/src/maxdiffusion/train_sdxl.py b/src/maxdiffusion/train_sdxl.py index 64b0cd3bc..60170a853 100644 --- a/src/maxdiffusion/train_sdxl.py +++ b/src/maxdiffusion/train_sdxl.py @@ -27,6 +27,7 @@ from maxdiffusion.train_utils import ( validate_train_config, + transformer_engine_context, ) @@ -51,4 +52,5 @@ def main(argv: Sequence[str]) -> None: tf.config.set_visible_devices([], "GPU") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" torch.set_default_device("cpu") - app.run(main) + with transformer_engine_context(): + app.run(main) diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index e3e75971c..8f986a90a 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ 
-20,7 +20,7 @@ import queue from maxdiffusion import max_utils, max_logging - +from contextlib import contextmanager def get_first_step(state): return int(state.step) @@ -196,3 +196,22 @@ def generate_timestep_weights(config, num_timesteps): weights[bias_indices] *= timestep_bias_config["multiplier"] weights /= weights.sum() return jnp.array(weights) + + +@contextmanager +def transformer_engine_context(): + """ If TransformerEngine is available, this context manager will provide the library with MaxDiffusion-specific details needed for correct operation. """ + try: + from transformer_engine.jax.sharding import global_shard_guard, MeshResource + # Inform TransformerEngine of MaxDiffusion's physical mesh resources. + mesh_resource = MeshResource( + dp_resource = "data", + tp_resource = "tensor", + fsdp_resource = "fsdp", + pp_resource = None, + cp_resource = None, + ) + with global_shard_guard(mesh_resource): + yield + except ImportError: + yield