Skip to content

Commit 2228609

Browse files
authored
Fix device put error in unet state and text encoder 2 state (#155)
Signed-off-by: kunjan patel <kunjanp@google.com>
1 parent 296e956 commit 2228609

3 files changed

Lines changed: 6 additions & 11 deletions

File tree

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ tensorflow-datasets>=4.9.6
2424
ruff>=0.1.5,<=0.2
2525
git+https://github.com/mlperf/logging.git
2626
opencv-python-headless==4.10.0.84
27-
orbax-checkpoint==0.10.2
27+
orbax-checkpoint==0.10.3
2828
tokenizers==0.21.0
2929
huggingface_hub==0.24.7
3030
transformers==4.48.1

requirements_with_jax_stable_stack.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jaxlib>=0.4.30
1414
Jinja2
1515
opencv-python-headless==4.10.0.84
1616
optax>=0.2.3
17-
orbax-checkpoint==0.10.2
17+
orbax-checkpoint==0.10.3
1818
parameterized
1919
Pillow
2020
pyink

src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,11 @@ def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training)
8888
config=self.config,
8989
mesh=self.mesh,
9090
weights_init_fn=weights_init_fn,
91-
model_params=None,
91+
model_params=None if self.config.train_new_unet else params.get("unet", None),
9292
checkpoint_manager=self.checkpoint_manager,
9393
checkpoint_item=checkpoint_item_name,
9494
training=is_training,
9595
)
96-
if not self.config.train_new_unet:
97-
unet_state = unet_state.replace(params=params.get("unet", None))
98-
unet_state = jax.device_put(unet_state, state_mesh_shardings)
9996
return unet_state, state_mesh_shardings, learning_rate_scheduler
10097

10198
def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):
@@ -153,20 +150,18 @@ def create_text_encoder_2_state(self, pipeline, params, checkpoint_item_name, is
153150
input_shape=(self.total_train_batch_size, pipeline.tokenizer.model_max_length),
154151
)
155152

156-
state, state_mesh_shardings = max_utils.setup_initial_state(
153+
# state, state_mesh_shardings =
154+
return max_utils.setup_initial_state(
157155
model=pipeline.text_encoder_2,
158156
tx=tx,
159157
config=self.config,
160158
mesh=self.mesh,
161159
weights_init_fn=weights_init_fn,
162-
model_params=None,
160+
model_params=params.get("text_encoder_2", None),
163161
checkpoint_manager=self.checkpoint_manager,
164162
checkpoint_item=checkpoint_item_name,
165163
training=is_training,
166164
)
167-
state = state.replace(params=params.get("text_encoder_2", None))
168-
state = jax.device_put(state, state_mesh_shardings)
169-
return state, state_mesh_shardings
170165

171166
def restore_data_iterator_state(self, data_iterator):
172167
if (

0 commit comments

Comments (0)