Skip to content

Commit e0a538f

Browse files
merge main
2 parents b11767b + 3b4f4d5 commit e0a538f

18 files changed

Lines changed: 379 additions & 15 deletions

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ tensorflow-datasets>=4.9.6
2424
ruff>=0.1.5,<=0.2
2525
git+https://github.com/mlperf/logging.git
2626
opencv-python-headless==4.10.0.84
27-
orbax-checkpoint==0.10.2
27+
orbax-checkpoint==0.10.3
2828
tokenizers==0.21.0
2929
huggingface_hub==0.24.7
3030
transformers==4.48.1
3131
einops==0.8.0
3232
sentencepiece
33+
aqtp

requirements_with_jax_stable_stack.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jaxlib>=0.4.30
1414
Jinja2
1515
opencv-python-headless==4.10.0.84
1616
optax>=0.2.3
17-
orbax-checkpoint==0.10.2
17+
orbax-checkpoint==0.10.3
1818
parameterized
1919
Pillow
2020
pyink

src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,11 @@ def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training)
8888
config=self.config,
8989
mesh=self.mesh,
9090
weights_init_fn=weights_init_fn,
91-
model_params=None,
91+
model_params=None if self.config.train_new_unet else params.get("unet", None),
9292
checkpoint_manager=self.checkpoint_manager,
9393
checkpoint_item=checkpoint_item_name,
9494
training=is_training,
9595
)
96-
if not self.config.train_new_unet:
97-
unet_state = unet_state.replace(params=params.get("unet", None))
98-
unet_state = jax.device_put(unet_state, state_mesh_shardings)
9996
return unet_state, state_mesh_shardings, learning_rate_scheduler
10097

10198
def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):
@@ -153,20 +150,18 @@ def create_text_encoder_2_state(self, pipeline, params, checkpoint_item_name, is
153150
input_shape=(self.total_train_batch_size, pipeline.tokenizer.model_max_length),
154151
)
155152

156-
state, state_mesh_shardings = max_utils.setup_initial_state(
153+
# state, state_mesh_shardings =
154+
return max_utils.setup_initial_state(
157155
model=pipeline.text_encoder_2,
158156
tx=tx,
159157
config=self.config,
160158
mesh=self.mesh,
161159
weights_init_fn=weights_init_fn,
162-
model_params=None,
160+
model_params=params.get("text_encoder_2", None),
163161
checkpoint_manager=self.checkpoint_manager,
164162
checkpoint_item=checkpoint_item_name,
165163
training=is_training,
166164
)
167-
state = state.replace(params=params.get("text_encoder_2", None))
168-
state = jax.device_put(state, state_mesh_shardings)
169-
return state, state_mesh_shardings
170165

171166
def restore_data_iterator_state(self, data_iterator):
172167
if (

src/maxdiffusion/configs/base14.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,7 @@ prior_loss_weight: 1.0
216216
num_class_images: 100
217217
# If true, set dataset_save_location.
218218
cache_dreambooth_dataset: False
219+
quantization: ''
220+
# Shard the range-finding operation for quantization. By default this is set to the number of slices.
221+
quantization_local_shard_count: -1
222+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base21.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,7 @@ prior_loss_weight: 1.0
217217
num_class_images: 100
218218
# If true, set dataset_save_location.
219219
cache_dreambooth_dataset: False
220+
quantization: ''
221+
# Shard the range-finding operation for quantization. By default this is set to the number of slices.
222+
quantization_local_shard_count: -1
223+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,8 @@ prior_loss_weight: 1.0
231231
num_class_images: 100
232232
# If true, set dataset_save_location.
233233
cache_dreambooth_dataset: False
234+
235+
quantization: ''
236+
# Shard the range-finding operation for quantization. By default this is set to the number of slices.
237+
quantization_local_shard_count: -1
238+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,8 @@ controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
260260
controlnet_from_pt: True
261261
controlnet_conditioning_scale: 0.5
262262
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
263+
quantization: ''
264+
# Shard the range-finding operation for quantization. By default this is set to the number of slices.
265+
quantization_local_shard_count: -1
266+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
267+

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,7 @@ controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
268268
controlnet_from_pt: True
269269
controlnet_conditioning_scale: 0.5
270270
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
271+
quantization: ''
272+
# Shard the range-finding operation for quantization. By default this is set to the number of slices.
273+
quantization_local_shard_count: -1
274+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base_xl.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,3 +233,9 @@ controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
233233
controlnet_from_pt: True
234234
controlnet_conditioning_scale: 0.5
235235
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
236+
enable_mllog: False
237+
238+
quantization: ''
239+
# Shard the range-finding operation for quantization. By default this is set to the number of slices.
240+
quantization_local_shard_count: -1
241+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base_xl_lightning.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,7 @@ lora_config: {
185185
# }
186186

187187
enable_mllog: False
188+
quantization: ''
189+
# Shard the range-finding operation for quantization. By default this is set to the number of slices.
190+
quantization_local_shard_count: -1
191+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

0 commit comments

Comments
 (0)