Skip to content

Commit 0aa463b

Browse files
committed
add TE context to sdxl
1 parent 99200be commit 0aa463b

3 files changed

Lines changed: 24 additions & 21 deletions

File tree

src/maxdiffusion/train_flux.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
import jax
2020
from absl import app
2121
from maxdiffusion import (max_logging, pyconfig)
22-
from contextlib import contextmanager
2322

2423
from maxdiffusion.train_utils import (
2524
validate_train_config,
25+
transformer_engine_context,
2626
)
2727

2828

@@ -40,24 +40,6 @@ def main(argv: Sequence[str]) -> None:
4040
max_logging.log(f"Found {jax.device_count()} devices.")
4141
train(config)
4242

43-
@contextmanager
44-
def transformer_engine_context():
45-
""" If TransformerEngine is available, this context manager will provide the library with MaxDiffusion-specific details needed for correcct operation. """
46-
try:
47-
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
48-
# Inform TransformerEngine of MaxDiffusion's physical mesh resources.
49-
mesh_resource = MeshResource(
50-
dp_resource = "data",
51-
tp_resource = "tensor",
52-
fsdp_resource = "fsdp",
53-
pp_resource = None,
54-
cp_resource = None,
55-
)
56-
with global_shard_guard(mesh_resource):
57-
yield
58-
except ImportError:
59-
yield
60-
6143
if __name__ == "__main__":
6244
with transformer_engine_context():
6345
app.run(main)

src/maxdiffusion/train_sdxl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from maxdiffusion.train_utils import (
2929
validate_train_config,
30+
transformer_engine_context,
3031
)
3132

3233

@@ -51,4 +52,5 @@ def main(argv: Sequence[str]) -> None:
5152
tf.config.set_visible_devices([], "GPU")
5253
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
5354
torch.set_default_device("cpu")
54-
app.run(main)
55+
with transformer_engine_context():
56+
app.run(main)

src/maxdiffusion/train_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import queue
2121

2222
from maxdiffusion import max_utils, max_logging
23-
23+
from contextlib import contextmanager
2424

2525
def get_first_step(state):
2626
return int(state.step)
@@ -196,3 +196,22 @@ def generate_timestep_weights(config, num_timesteps):
196196
weights[bias_indices] *= timestep_bias_config["multiplier"]
197197
weights /= weights.sum()
198198
return jnp.array(weights)
199+
200+
201+
@contextmanager
202+
def transformer_engine_context():
203+
""" If TransformerEngine is available, this context manager will provide the library with MaxDiffusion-specific details needed for correcct operation. """
204+
try:
205+
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
206+
# Inform TransformerEngine of MaxDiffusion's physical mesh resources.
207+
mesh_resource = MeshResource(
208+
dp_resource = "data",
209+
tp_resource = "tensor",
210+
fsdp_resource = "fsdp",
211+
pp_resource = None,
212+
cp_resource = None,
213+
)
214+
with global_shard_guard(mesh_resource):
215+
yield
216+
except ImportError:
217+
yield

0 commit comments

Comments
 (0)