|
32 | 32 | max_logging, |
33 | 33 | ) |
34 | 34 |
|
35 | | -from maxdiffusion.transformers import (CLIPTokenizer, FlaxCLIPTextModel, CLIPTextConfig, FlaxCLIPTextModelWithProjection) |
| 35 | +from transformers import (CLIPTokenizer, FlaxCLIPTextModel, CLIPTextConfig, FlaxCLIPTextModelWithProjection) |
36 | 36 |
|
37 | 37 | from maxdiffusion.checkpointing.checkpointing_utils import ( |
38 | 38 | create_orbax_checkpoint_manager, |
@@ -88,11 +88,14 @@ def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training) |
88 | 88 | config=self.config, |
89 | 89 | mesh=self.mesh, |
90 | 90 | weights_init_fn=weights_init_fn, |
91 | | - model_params=None if self.config.train_new_unet else params.get("unet", None), |
| 91 | + model_params=None, |
92 | 92 | checkpoint_manager=self.checkpoint_manager, |
93 | 93 | checkpoint_item=checkpoint_item_name, |
94 | 94 | training=is_training, |
95 | 95 | ) |
| 96 | + if not self.config.train_new_unet: |
| 97 | + unet_state = unet_state.replace(params=params.get("unet", None)) |
| 98 | + unet_state = jax.device_put(unet_state, state_mesh_shardings) |
96 | 99 | return unet_state, state_mesh_shardings, learning_rate_scheduler |
97 | 100 |
|
98 | 101 | def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False): |
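The hunk above changes how pretrained UNet weights reach the train state: instead of threading them through `model_params=` into `setup_initial_state`, the state is built first, the params are attached with `.replace`, and the whole pytree is placed with `jax.device_put` against the mesh shardings. A minimal sketch of that replace-then-place pattern on a toy `TrainState` (the mesh, params, and optimizer here are stand-ins, not maxdiffusion's):

```python
import numpy as np
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy state standing in for the UNet train state built by setup_initial_state.
state = train_state.TrainState.create(
    apply_fn=lambda params, x: x @ params["w"],
    params={"w": jnp.zeros((8, 8))},  # freshly initialized weights
    tx=optax.adam(1e-4),
)

# Pretrained weights restored elsewhere (standing in for params.get("unet")).
pretrained = {"w": jnp.ones((8, 8))}

# Attach the pretrained params to the already-built state...
state = state.replace(params=pretrained)

# ...then place every leaf of the state according to a sharding. A
# single-device mesh with a replicated spec keeps the sketch runnable anywhere.
mesh = Mesh(np.array(jax.devices()[:1]), ("data",))
state = jax.device_put(state, NamedSharding(mesh, P()))
```

Keeping weight injection separate from state construction is what lets the `train_new_unet` branch skip it cleanly, as the hunk does.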
@@ -150,17 +153,20 @@ def create_text_encoder_2_state(self, pipeline, params, checkpoint_item_name, is_training=False)
150 | 153 | input_shape=(self.total_train_batch_size, pipeline.tokenizer.model_max_length), |
151 | 154 | ) |
152 | 155 |
|
153 | | - return max_utils.setup_initial_state( |
| 156 | + state, state_mesh_shardings = max_utils.setup_initial_state( |
154 | 157 | model=pipeline.text_encoder_2, |
155 | 158 | tx=tx, |
156 | 159 | config=self.config, |
157 | 160 | mesh=self.mesh, |
158 | 161 | weights_init_fn=weights_init_fn, |
159 | | - model_params=params.get("text_encoder_2", None), |
| 162 | + model_params=None, |
160 | 163 | checkpoint_manager=self.checkpoint_manager, |
161 | 164 | checkpoint_item=checkpoint_item_name, |
162 | 165 | training=is_training, |
163 | 166 | ) |
| 167 | + state = state.replace(params=params.get("text_encoder_2", None)) |
| 168 | + state = jax.device_put(state, state_mesh_shardings) |
| 169 | + return state, state_mesh_shardings |
164 | 170 |
|
165 | 171 | def restore_data_iterator_state(self, data_iterator): |
166 | 172 | if ( |
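`create_text_encoder_2_state` now applies the same replace-then-place pattern and returns `state_mesh_shardings` alongside the state. That object is a pytree of shardings mirroring the state's structure; a sketch of how such a tree can be built and consumed by `jax.device_put`, with hypothetical mesh axis names (maxdiffusion derives its mesh and axes from config):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2-axis mesh; shape (1, n_devices) so it runs on any host.
mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "fsdp"))

# A toy params pytree standing in for the text-encoder state.
params = {"embed": jnp.zeros((16, 8)), "head": jnp.zeros((8, 4))}

# One sharding per leaf: shard dim 0 over "fsdp", replicate everything else.
shardings = jax.tree_util.tree_map(lambda _: NamedSharding(mesh, P("fsdp")), params)

# device_put accepts a pytree of shardings matching the value pytree.
params = jax.device_put(params, shardings)
```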
@@ -302,15 +308,16 @@ def load_checkpoint(self, step=None, scheduler_class=None): |
302 | 308 | tokenizer_path = os.path.join(tokenizer_path, "tokenizer") |
303 | 309 | tokenizer_path = max_utils.download_blobs(tokenizer_path, "/tmp") |
304 | 310 | tokenizer = CLIPTokenizer.from_pretrained( |
305 | | - tokenizer_path, subfolder="tokenizer", dtype=self.config.activations_dtype, weights_dtype=self.config.weights_dtype |
| 311 | + tokenizer_path, |
| 312 | + subfolder="tokenizer", |
| 313 | + dtype=self.config.activations_dtype, |
306 | 314 | ) |
307 | 315 |
|
308 | 316 | te_pretrained_config = CLIPTextConfig(**model_configs[0]["text_encoder_config"]) |
309 | 317 | text_encoder = FlaxCLIPTextModel( |
310 | 318 | te_pretrained_config, |
311 | 319 | seed=self.config.seed, |
312 | 320 | dtype=self.config.activations_dtype, |
313 | | - weights_dtype=self.config.weights_dtype, |
314 | 321 | _do_init=False, |
315 | 322 | ) |
316 | 323 |
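The last hunk drops `weights_dtype`, which the upstream `transformers` Flax CLIP classes do not accept (presumably only the vendored `maxdiffusion.transformers` copy did, matching the import fix at the top of the diff); the model takes `dtype` for computation precision, and `_do_init=False` skips eager weight initialization so the real parameters can be supplied later. A sketch of that transformers pattern, using a default `CLIPTextConfig` in place of the checkpoint's `text_encoder_config`:

```python
import jax
import jax.numpy as jnp
from transformers import CLIPTextConfig, FlaxCLIPTextModel

# Default config stands in for model_configs[0]["text_encoder_config"].
config = CLIPTextConfig()

# _do_init=False builds the model as a shell without allocating weights;
# params are attached or passed in explicitly later.
model = FlaxCLIPTextModel(config, seed=0, dtype=jnp.bfloat16, _do_init=False)

# If fresh weights are actually wanted, init_weights creates them lazily.
params = model.init_weights(
    jax.random.PRNGKey(0), input_shape=(1, config.max_position_embeddings)
)

# With _do_init=False, params must be passed on every forward call.
input_ids = jnp.zeros((1, config.max_position_embeddings), dtype=jnp.int32)
outputs = model(input_ids, params=params)
```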
|