
Commit fcafb4a

flag to jit initializers
1 parent 0db81e3 commit fcafb4a

9 files changed: 43 additions & 6 deletions

File tree

- README.md
- src/maxdiffusion/configs/base14.yml
- src/maxdiffusion/configs/base21.yml
- src/maxdiffusion/configs/base_2_base.yml
- src/maxdiffusion/configs/base_flux_dev.yml
- src/maxdiffusion/configs/base_flux_schnell.yml
- src/maxdiffusion/configs/base_xl.yml
- src/maxdiffusion/configs/base_xl_lightning.yml
- src/maxdiffusion/max_utils.py
README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -88,7 +88,7 @@ After installation completes, run the training script.
 First install Transformer Engine by following the [instructions here](#fused-attention-for-gpu).

 ```bash
-NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=float16 activations_dtype=float16 per_device_batch_size=1 attention="cudnn_flash_te"
+NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_new_unet=true train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=bfloat16 resolution=512 per_device_batch_size=1 attention="cudnn_flash_te" jit_initializers=False
 ```

 To generate images with a trained checkpoint, run:
````

src/maxdiffusion/configs/base14.yml

Lines changed: 5 additions & 0 deletions
```diff
@@ -37,6 +37,11 @@ activations_dtype: 'bfloat16'
 # fp32 activations and fp32 weights with HIGHEST will provide the best precision
 # at the cost of time.
 precision: "DEFAULT"
+
+# If False, the state is not jitted; it is initialized eagerly and placed with device_put instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: False
 split_head_dim: True
```
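For context on what this flag toggles: when `jit_initializers` is True, the train state is created inside a `jax.jit` call whose `out_shardings` place the weights directly on the mesh; when False, the initializer runs as plain Python and the result is placed with `jax.device_put` (see the max_utils.py hunk below). Here is a minimal standalone sketch of the two paths, not the repository's code: `init_state`, `shardings`, and the mesh are simplified stand-ins for `init_train_state_partial` and `state_mesh_shardings`.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Stand-in for init_train_state_partial: builds the initial parameter pytree.
def init_state():
    return {"w": jnp.ones((8, 8)), "b": jnp.zeros((8,))}

# Replicate every leaf across all local devices; a stand-in for the
# repository's state_mesh_shardings.
mesh = Mesh(jax.devices(), axis_names=("data",))
shardings = jax.tree_util.tree_map(
    lambda _: NamedSharding(mesh, P()), jax.eval_shape(init_state)
)

jit_initializers = True  # the new config flag

if jit_initializers:
    # Jitted path: outputs are materialized directly with the target sharding.
    state = jax.jit(init_state, in_shardings=None, out_shardings=shardings)()
else:
    # Eager path: plain Python that a debugger can step through on one host;
    # the result is placed onto devices afterwards.
    state = init_state()
    state = jax.device_put(state, shardings)
```

The eager branch keeps initialization easy to step through in a debugger, at the cost of first materializing the full state in host memory.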

src/maxdiffusion/configs/base21.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -37,6 +37,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"

+# If False, the state is not jitted; it is initialized eagerly and placed with device_put instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: False
 split_head_dim: True
```

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -38,6 +38,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"

+# If False, the state is not jitted; it is initialized eagerly and placed with device_put instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
```

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -51,6 +51,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"

+# If False, the state is not jitted; it is initialized eagerly and placed with device_put instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
```

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -50,6 +50,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"

+# If False, the state is not jitted; it is initialized eagerly and placed with device_put instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
```

src/maxdiffusion/configs/base_xl.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -38,6 +38,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"

+# If False, the state is not jitted; it is initialized eagerly and placed with device_put instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: False
 split_head_dim: True
```

src/maxdiffusion/configs/base_xl_lightning.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -36,6 +36,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"

+# If False, the state is not jitted; it is initialized eagerly and placed with device_put instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: False
 split_head_dim: True
```

src/maxdiffusion/max_utils.py

Lines changed: 13 additions & 5 deletions
```diff
@@ -413,11 +413,19 @@ def setup_initial_state(
       training=training,
       eval_only=False,
   )
-  state = jax.jit(
-      init_train_state_partial,
-      in_shardings=None,
-      out_shardings=state_mesh_shardings,
-  )()
+  if config.jit_initializers:
+    state = jax.jit(
+        init_train_state_partial,
+        in_shardings=None,
+        out_shardings=state_mesh_shardings,
+    )()
+  else:
+    state = init_train_state_partial()
+    if model_params:
+      state = state.replace(params=model_params)
+    state = jax.device_put(state, state_mesh_shardings)
+    if model_params:
+      state = state.replace(params=model_params)

   state = unbox_logicallypartioned_trainstate(state)
```
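A note on the design, as I read this hunk: in the eager branch, `model_params` (when provided) are spliced into the state before `jax.device_put`, so they get distributed along with the rest of the state according to `state_mesh_shardings`. Under `jax.jit` with `out_shardings`, by contrast, the initializer's outputs are materialized directly as sharded arrays across all participating processes, which is consistent with the config comment requiring `jit_initializers: True` for multi-host runs, where each host builds only its own shards rather than the full state.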
