Skip to content

Commit e38675b

Browse files
committed
fix for gpu
1 parent 3a9d12b commit e38675b

2 files changed

Lines changed: 23 additions & 1 deletion

File tree

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ enable_profiler: False
236236
# the iteration time a chance to stabilize.
237237
skip_first_n_steps_for_profiler: 5
238238
profiler_steps: 10
239+
profiler: ""
239240

240241
# Generation parameters
241242
prompt: "A magical castle in the middle of a forest, artistic drawing"
@@ -284,3 +285,5 @@ quantization: ''
284285
quantization_local_shard_count: -1
285286
use_qwix_quantization: False
286287
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
288+
289+
save_final_checkpoint: False

src/maxdiffusion/train_flux.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import jax
2020
from absl import app
2121
from maxdiffusion import (max_logging, pyconfig)
22+
from contextlib import contextmanager
2223

2324
from maxdiffusion.train_utils import (
2425
validate_train_config,
@@ -39,6 +40,24 @@ def main(argv: Sequence[str]) -> None:
3940
max_logging.log(f"Found {jax.device_count()} devices.")
4041
train(config)
4142

43+
@contextmanager
def transformer_engine_context():
  """Context manager that configures TransformerEngine's sharding, if available.

  If the optional `transformer_engine` package is importable, inform it of the
  physical mesh axis names used here ("data", "tensor", "fsdp") so its JAX
  sharding logic operates correctly. If the package is absent, this is a
  no-op context manager.
  """
  try:
    # Only the import lives in the try: an ImportError raised later by the
    # caller's body (inside the `with` block) must propagate, not be
    # swallowed here — re-yielding after a thrown-in exception would raise
    # "generator didn't stop after throw()".
    from transformer_engine.jax.sharding import global_shard_guard, MeshResource
  except ImportError:
    # TransformerEngine not installed — run without it.
    yield
    return

  # Map TransformerEngine's logical parallelism resources onto the mesh
  # axes used by this trainer; pipeline/context parallelism are unused.
  mesh_resource = MeshResource(
      dp_resource="data",
      tp_resource="tensor",
      fsdp_resource="fsdp",
      pp_resource=None,
      cp_resource=None,
  )
  with global_shard_guard(mesh_resource):
    yield
4260

4361
if __name__ == "__main__":
  # Wrap the whole training run in the TransformerEngine context (a no-op
  # when the library is not installed) before handing control to absl.
  with transformer_engine_context():
    app.run(main)

0 commit comments

Comments
 (0)