File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1919import jax
2020from absl import app
2121from maxdiffusion import (max_logging , pyconfig )
22- from contextlib import contextmanager
2322
2423from maxdiffusion .train_utils import (
2524 validate_train_config ,
25+ transformer_engine_context ,
2626)
2727
2828
@@ -40,24 +40,6 @@ def main(argv: Sequence[str]) -> None:
4040 max_logging .log (f"Found { jax .device_count ()} devices." )
4141 train (config )
4242
@contextmanager
def transformer_engine_context():
  """Run the enclosed block with TransformerEngine told about MaxDiffusion's mesh axes.

  If ``transformer_engine`` can be imported, a ``MeshResource`` naming the
  "data", "tensor" and "fsdp" axes (no pipeline/context-parallel axes) is
  installed through ``global_shard_guard`` for the duration of the ``with``
  block. If it cannot be imported, the block simply runs without any
  extra setup.

  Yields:
    None. The context manager exposes no value.
  """
  # Only the import itself is guarded. Keeping the ``yield`` outside of the
  # ``try`` matters: an ImportError raised by the caller's ``with`` body must
  # propagate unchanged instead of being caught here, which would make the
  # generator yield a second time and fail with
  # "RuntimeError: generator didn't stop after throw()".
  try:
    from transformer_engine.jax.sharding import global_shard_guard, MeshResource
  except ImportError:
    global_shard_guard = None

  if global_shard_guard is None:
    # TransformerEngine is not available; nothing to configure.
    yield
    return

  # Inform TransformerEngine of MaxDiffusion's physical mesh resources.
  mesh_resource = MeshResource(
      dp_resource="data",
      tp_resource="tensor",
      fsdp_resource="fsdp",
      pp_resource=None,
      cp_resource=None,
  )
  with global_shard_guard(mesh_resource):
    yield

if __name__ == "__main__":
  # Enter the TransformerEngine context before handing control to absl so
  # that everything executed by ``main`` runs inside it.
  with transformer_engine_context():
    app.run(main)
Original file line number Diff line number Diff line change 2727
2828from maxdiffusion .train_utils import (
2929 validate_train_config ,
30+ transformer_engine_context ,
3031)
3132
3233
@@ -51,4 +52,5 @@ def main(argv: Sequence[str]) -> None:
5152 tf .config .set_visible_devices ([], "GPU" )
5253 os .environ ["TF_CPP_MIN_LOG_LEVEL" ] = "0"
5354 torch .set_default_device ("cpu" )
54- app .run (main )
55+ with transformer_engine_context ():
56+ app .run (main )
Original file line number Diff line number Diff line change 2020import queue
2121
2222from maxdiffusion import max_utils , max_logging
23-
23+ from contextlib import contextmanager
2424
def get_first_step(state):
  """Return ``state.step`` converted to a built-in ``int``.

  Args:
    state: any object exposing a ``step`` attribute that ``int()`` accepts.

  Returns:
    The value of ``state.step`` as an ``int``.
  """
  return int(state.step)
@@ -196,3 +196,22 @@ def generate_timestep_weights(config, num_timesteps):
196196 weights [bias_indices ] *= timestep_bias_config ["multiplier" ]
197197 weights /= weights .sum ()
198198 return jnp .array (weights )
199+
200+
@contextmanager
def transformer_engine_context():
  """Run the enclosed block with TransformerEngine told about MaxDiffusion's mesh axes.

  If ``transformer_engine`` can be imported, a ``MeshResource`` naming the
  "data", "tensor" and "fsdp" axes (no pipeline/context-parallel axes) is
  installed through ``global_shard_guard`` for the duration of the ``with``
  block. If it cannot be imported, the block simply runs without any
  extra setup.

  Yields:
    None. The context manager exposes no value.
  """
  # Only the import itself is guarded. Keeping the ``yield`` outside of the
  # ``try`` matters: an ImportError raised by the caller's ``with`` body must
  # propagate unchanged instead of being caught here, which would make the
  # generator yield a second time and fail with
  # "RuntimeError: generator didn't stop after throw()".
  try:
    from transformer_engine.jax.sharding import global_shard_guard, MeshResource
  except ImportError:
    global_shard_guard = None

  if global_shard_guard is None:
    # TransformerEngine is not available; nothing to configure.
    yield
    return

  # Inform TransformerEngine of MaxDiffusion's physical mesh resources.
  mesh_resource = MeshResource(
      dp_resource="data",
      tp_resource="tensor",
      fsdp_resource="fsdp",
      pp_resource=None,
      cp_resource=None,
  )
  with global_shard_guard(mesh_resource):
    yield
You can’t perform that action at this time.
0 commit comments