Skip to content

Commit 5453b3c

Browse files
committed
Fixed comments and rebased on main
1 parent 4600a72 commit 5453b3c

4 files changed

Lines changed: 8 additions & 12 deletions

File tree

src/maxdiffusion/checkpointing/flux_checkpointer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,8 @@ def __init__(self, config, checkpoint_type):
6161
self.mesh = Mesh(self.devices_array, self.config.mesh_axes)
6262
self.total_train_batch_size = self.config.total_train_batch_size
6363

64-
checkpoint_dir = os.path.abspath(self.config.checkpoint_dir)
65-
6664
self.checkpoint_manager = create_orbax_checkpoint_manager(
67-
checkpoint_dir,
65+
self.config.checkpoint_dir,
6866
enable_checkpointing=True,
6967
save_interval_steps=1,
7068
checkpoint_type=checkpoint_type,

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ flash_block_sizes: {}
7575
# GroupNorm groups
7676
norm_num_groups: 32
7777

78-
# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch
78+
# If train_new_flux, flux weights will be randomly initialized to train flux from scratch
7979
# else they will be loaded from pretrained_model_name_or_path
8080
train_new_flux: False
8181

@@ -223,8 +223,8 @@ skip_first_n_steps_for_profiler: 5
223223
profiler_steps: 10
224224

225225
# Generation parameters
226-
prompt: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet."
227-
prompt_2: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet."
226+
prompt: "A magical castle in the middle of a forest, artistic drawing"
227+
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
228228
negative_prompt: "purple, red"
229229
do_classifier_free_guidance: True
230230
guidance_scale: 3.5

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,9 @@ flash_block_sizes: {
8383
# GroupNorm groups
8484
norm_num_groups: 32
8585

86-
# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch
86+
# If train_new_flux, flux weights will be randomly initialized to train flux from scratch
8787
# else they will be loaded from pretrained_model_name_or_path
88-
train_new_unet: False
88+
train_new_flux: False
8989

9090
# train text_encoder - Currently not supported for SDXL
9191
train_text_encoder: False
@@ -123,7 +123,7 @@ diffusion_scheduler_config: {
123123
base_output_directory: ""
124124

125125
# Hardware
126-
hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu'
126+
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
127127

128128
# Parallelism
129129
mesh_axes: ['data', 'fsdp', 'tensor']
@@ -238,7 +238,7 @@ do_classifier_free_guidance: True
238238
guidance_scale: 0.0
239239
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
240240
guidance_rescale: 0.0
241-
num_inference_steps: 50
241+
num_inference_steps: 4
242242

243243
# SDXL Lightning parameters
244244
lightning_from_pt: True

src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import flax
1919
import jax.numpy as jnp
20-
import max_logging
2120

2221
from ..configuration_utils import ConfigMixin, register_to_config
2322
from .scheduling_utils_flax import (
@@ -170,7 +169,6 @@ def set_timesteps(
170169
timesteps = (jnp.arange(self.config.num_train_timesteps, 0, -step_ratio)).round()
171170
timesteps -= 1
172171
elif timestep_spacing == "flux":
173-
max_logging.log("Using flux timestep spacing")
174172
timesteps = jnp.linspace(1, 0, num_inference_steps + 1)
175173
else:
176174
raise ValueError(

0 commit comments

Comments
 (0)