File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -236,6 +236,7 @@ enable_profiler: False
236236# the iteration time a chance to stabilize.
237237skip_first_n_steps_for_profiler : 5
238238profiler_steps : 10
239+ profiler : " "
239240
240241# Generation parameters
241242prompt : " A magical castle in the middle of a forest, artistic drawing"
@@ -284,3 +285,5 @@ quantization: ''
284285quantization_local_shard_count : -1
285286use_qwix_quantization : False
286287compile_topology_num_slices : -1 # Number of target slices, set to a positive integer.
288+
289+ save_final_checkpoint : False
Original file line number Diff line number Diff line change 1919import jax
2020from absl import app
2121from maxdiffusion import (max_logging , pyconfig )
22+ from contextlib import contextmanager
2223
2324from maxdiffusion .train_utils import (
2425 validate_train_config ,
@@ -39,6 +40,24 @@ def main(argv: Sequence[str]) -> None:
3940 max_logging .log (f"Found { jax .device_count ()} devices." )
4041 train (config )
4142
@contextmanager
def transformer_engine_context():
    """Run the wrapped code under TransformerEngine's global shard guard, if available.

    When ``transformer_engine`` can be imported, a ``MeshResource`` is built
    that maps data parallelism to the "data" axis, tensor parallelism to the
    "tensor" axis and FSDP to the "fsdp" axis (no pipeline- or context-parallel
    axis), and the ``with`` body runs inside ``global_shard_guard`` for that
    resource. When it cannot be imported, the body runs unchanged.

    Exceptions raised by the ``with`` body, including ``ImportError``, always
    propagate to the caller.

    Yields:
        None.
    """
    # Keep the ``try`` limited to the optional import. If the ``yield`` were
    # inside it, an ImportError raised by the wrapped code would be caught
    # here, the generator would yield a second time, and contextlib would
    # replace the real error with "generator didn't stop after throw()".
    try:
        from transformer_engine.jax.sharding import global_shard_guard, MeshResource
    except ImportError:
        global_shard_guard = None

    if global_shard_guard is None:
        # TransformerEngine is optional: run the body without any guard.
        yield
        return

    # Tell TransformerEngine which mesh axis names this project uses.
    mesh_resource = MeshResource(
        dp_resource="data",
        tp_resource="tensor",
        fsdp_resource="fsdp",
        pp_resource=None,
        cp_resource=None,
    )
    with global_shard_guard(mesh_resource):
        yield
4260
4361if __name__ == "__main__" :
44- app .run (main )
62+ with transformer_engine_context ():
63+ app .run (main )
You can’t perform that action at this time.
0 commit comments