Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxdiffusion/configs/base_flux_schnell.yml
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ enable_profiler: False
# the iteration time a chance to stabilize.
skip_first_n_steps_for_profiler: 5
profiler_steps: 10
profiler: ""

# Generation parameters
prompt: "A magical castle in the middle of a forest, artistic drawing"
Expand Down Expand Up @@ -284,3 +285,5 @@ quantization: ''
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

save_final_checkpoint: False
Comment thread
coolkp marked this conversation as resolved.
21 changes: 20 additions & 1 deletion src/maxdiffusion/train_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jax
from absl import app
from maxdiffusion import (max_logging, pyconfig)
from contextlib import contextmanager

from maxdiffusion.train_utils import (
validate_train_config,
Expand All @@ -39,6 +40,24 @@ def main(argv: Sequence[str]) -> None:
max_logging.log(f"Found {jax.device_count()} devices.")
train(config)

@contextmanager
def transformer_engine_context():
""" If TransformerEngine is available, this context manager will provide the library with MaxText-specific details needed for correcct operation. """
try:
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
# Inform TransformerEngine of MaxText's physical mesh resources.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maxdiffusion

mesh_resource = MeshResource(
dp_resource = "data",
tp_resource = "tensor",
fsdp_resource = "fsdp",
pp_resource = None,
cp_resource = None,
)
with global_shard_guard(mesh_resource):
yield
except ImportError:
yield

if __name__ == "__main__":
app.run(main)
with transformer_engine_context():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed for sdxl too?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it's needed for any model that uses TE to let TE know the physical mesh axes.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then can you add in the train_sdxl as well?

app.run(main)