Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxdiffusion/configs/base_flux_schnell.yml
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ enable_profiler: False
# the iteration time a chance to stabilize.
skip_first_n_steps_for_profiler: 5
profiler_steps: 10
profiler: ""

# Generation parameters
prompt: "A magical castle in the middle of a forest, artistic drawing"
Expand Down Expand Up @@ -284,3 +285,5 @@ quantization: ''
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

save_final_checkpoint: False
Comment thread
coolkp marked this conversation as resolved.
5 changes: 3 additions & 2 deletions src/maxdiffusion/train_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from maxdiffusion.train_utils import (
validate_train_config,
transformer_engine_context,
)


Expand All @@ -39,6 +40,6 @@ def main(argv: Sequence[str]) -> None:
max_logging.log(f"Found {jax.device_count()} devices.")
train(config)


if __name__ == "__main__":
app.run(main)
with transformer_engine_context():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed for sdxl too?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it's needed for any model that uses TE to let TE know the physical mesh axes.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then can you add in the train_sdxl as well?

app.run(main)
4 changes: 3 additions & 1 deletion src/maxdiffusion/train_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from maxdiffusion.train_utils import (
validate_train_config,
transformer_engine_context,
)


Expand All @@ -51,4 +52,5 @@ def main(argv: Sequence[str]) -> None:
tf.config.set_visible_devices([], "GPU")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
torch.set_default_device("cpu")
app.run(main)
with transformer_engine_context():
app.run(main)
21 changes: 20 additions & 1 deletion src/maxdiffusion/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import queue

from maxdiffusion import max_utils, max_logging

from contextlib import contextmanager

def get_first_step(state):
  """Return the step counter recorded on the train state as a plain Python int."""
  current_step = state.step
  return int(current_step)
Expand Down Expand Up @@ -196,3 +196,22 @@ def generate_timestep_weights(config, num_timesteps):
weights[bias_indices] *= timestep_bias_config["multiplier"]
weights /= weights.sum()
return jnp.array(weights)


@contextmanager
def transformer_engine_context():
  """If TransformerEngine is available, inform it of MaxDiffusion's physical mesh
  axis names ("data", "tensor", "fsdp") for correct sharded operation; otherwise
  act as a no-op context manager.

  Note: the import probe is kept separate from the yield. The original form
  wrapped the yield in the same try/except ImportError, so an ImportError raised
  by the caller's body re-entered the generator and made it yield twice,
  raising ``RuntimeError: generator didn't stop after throw()`` and masking the
  caller's real error.
  """
  try:
    from transformer_engine.jax.sharding import global_shard_guard, MeshResource
  except ImportError:
    # TransformerEngine not installed: run the caller's body unmodified.
    yield
    return
  # Inform TransformerEngine of MaxDiffusion's physical mesh resources.
  mesh_resource = MeshResource(
      dp_resource="data",
      tp_resource="tensor",
      fsdp_resource="fsdp",
      pp_resource=None,
      cp_resource=None,
  )
  with global_shard_guard(mesh_resource):
    yield
Loading