Skip to content

Commit 6f536d8

Browse files
tf processing on cpu
1 parent 296e956 commit 6f536d8

3 files changed

Lines changed: 29 additions & 25 deletions

File tree

src/maxdiffusion/max_utils.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -404,21 +404,18 @@ def setup_initial_state(
404404
state = state[checkpoint_item]
405405
if not state:
406406
max_logging.log(f"Could not find the item in orbax, creating state...")
407-
init_train_state_partial = functools.partial(
408-
init_train_state,
409-
model=model,
410-
tx=tx,
411-
weights_init_fn=weights_init_fn,
412-
params=model_params,
413-
training=training,
414-
eval_only=False,
407+
state = init_train_state(
408+
model=model,
409+
tx=tx,
410+
weights_init_fn=weights_init_fn,
411+
params=model_params,
412+
training=training,
413+
eval_only=False
415414
)
415+
if model_params:
416+
state = state.replace(params=model_params)
416417

417-
state = jax.jit(
418-
init_train_state_partial,
419-
in_shardings=None,
420-
out_shardings=state_mesh_shardings,
421-
)()
418+
state = jax.device_put(state, state_mesh_shardings)
422419

423420
state = unbox_logicallypartioned_trainstate(state)
424421

src/maxdiffusion/train_sdxl.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,10 @@ def main(argv: Sequence[str]) -> None:
4646

4747

4848
if __name__ == "__main__":
49+
import os
50+
import tensorflow as tf
51+
import torch
52+
tf.config.set_visible_devices([], "GPU")
53+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
54+
torch.set_default_device("cpu")
4955
app.run(main)

src/maxdiffusion/trainers/base_stable_diffusion_trainer.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,6 @@ def start_training(self):
8383
# create train states
8484
train_states = {}
8585
state_shardings = {}
86-
unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self.create_unet_state(
87-
# ambiguous here, but if self.params.get("unet") doesn't exist
88-
# Then its 1 of 2 scenarios:
89-
# 1. unet state will be loaded directly from orbax
90-
# 2. a new unet is being trained from scratch.
91-
pipeline=pipeline,
92-
params=params,
93-
checkpoint_item_name="unet_state",
94-
is_training=True,
95-
)
96-
train_states["unet_state"] = unet_state
97-
state_shardings["unet_state_shardings"] = unet_state_mesh_shardings
9886
vae_state, vae_state_mesh_shardings = self.create_vae_state(
9987
pipeline=pipeline, params=params, checkpoint_item_name="vae_state", is_training=False
10088
)
@@ -131,6 +119,19 @@ def start_training(self):
131119
if self.config.dataset_type == "grain":
132120
data_iterator = self.restore_data_iterator_state(data_iterator)
133121

122+
unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self.create_unet_state(
123+
# ambiguous here, but if self.params.get("unet") doesn't exist,
124+
# Then it's 1 of 2 scenarios:
125+
# 1. unet state will be loaded directly from orbax
126+
# 2. a new unet is being trained from scratch.
127+
pipeline=pipeline,
128+
params=params,
129+
checkpoint_item_name="unet_state",
130+
is_training=True,
131+
)
132+
train_states["unet_state"] = unet_state
133+
state_shardings["unet_state_shardings"] = unet_state_mesh_shardings
134+
134135
data_shardings = self.get_data_shardings()
135136
# Compile train_step
136137
p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)

0 commit comments

Comments
 (0)