Skip to content

Commit 37df8b9

Browse files
Fix SDXL generate smoke tests.
1 parent b87443f commit 37df8b9

5 files changed

Lines changed: 23 additions & 10 deletions

File tree

src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,14 @@ def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training)
8888
config=self.config,
8989
mesh=self.mesh,
9090
weights_init_fn=weights_init_fn,
91-
model_params=None if self.config.train_new_unet else params.get("unet", None),
91+
model_params=None,
9292
checkpoint_manager=self.checkpoint_manager,
9393
checkpoint_item=checkpoint_item_name,
9494
training=is_training,
9595
)
96+
if not self.config.train_new_unet:
97+
unet_state = unet_state.replace(params=params.get("unet", None))
98+
unet_state = jax.device_put(unet_state, state_mesh_shardings)
9699
return unet_state, state_mesh_shardings, learning_rate_scheduler
97100

98101
def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):
@@ -150,17 +153,20 @@ def create_text_encoder_2_state(self, pipeline, params, checkpoint_item_name, is
150153
input_shape=(self.total_train_batch_size, pipeline.tokenizer.model_max_length),
151154
)
152155

153-
return max_utils.setup_initial_state(
156+
state, state_mesh_shardings = max_utils.setup_initial_state(
154157
model=pipeline.text_encoder_2,
155158
tx=tx,
156159
config=self.config,
157160
mesh=self.mesh,
158161
weights_init_fn=weights_init_fn,
159-
model_params=params.get("text_encoder_2", None),
162+
model_params=None,
160163
checkpoint_manager=self.checkpoint_manager,
161164
checkpoint_item=checkpoint_item_name,
162165
training=is_training,
163166
)
167+
state = state.replace(params=params.get("text_encoder_2", None))
168+
state = jax.device_put(state, state_mesh_shardings)
169+
return state, state_mesh_shardings
164170

165171
def restore_data_iterator_state(self, data_iterator):
166172
if (

src/maxdiffusion/generate_flux.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
3333
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
34-
from max_utils import (
34+
from maxdiffusion.max_utils import (
3535
device_put_replicated,
3636
get_memory_allocations,
3737
create_device_mesh,
@@ -52,9 +52,6 @@ def unpack(x: Array, height: int, width: int) -> Array:
5252
)
5353

5454

55-
from einops import rearrange
56-
57-
5855
def vae_decode(latents, vae, state, config):
5956
img = unpack(x=latents, height=config.resolution, width=config.resolution)
6057
img = img / vae.config.scaling_factor + vae.config.shift_factor

src/maxdiffusion/generate_sdxl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,11 @@ def run(config):
249249
config=config,
250250
mesh=checkpoint_loader.mesh,
251251
weights_init_fn=weights_init_fn,
252-
model_params=params.get("unet", None),
252+
model_params=None,
253253
training=False,
254254
)
255+
unet_state = unet_state.replace(params=params.get("unet", None))
256+
unet_state = jax.device_put(unet_state, unet_state_shardings)
255257

256258
vae_state, vae_state_shardings = checkpoint_loader.create_vae_state(
257259
pipeline, params, checkpoint_item_name="vae_state", is_training=False

src/maxdiffusion/max_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@
4646
from flax.linen import partitioning as nn_partitioning
4747
from flax.training import train_state
4848
from jax.experimental import mesh_utils
49-
from jax.sharding import PositionalSharding
49+
from transformers import (
50+
FlaxCLIPTextModel,
51+
FlaxCLIPTextPreTrainedModel
52+
)
5053
from flax import struct
5154
from typing import (
5255
Callable,
@@ -315,7 +318,10 @@ def init_train_state(model, tx, weights_init_fn, params=None, training=True, eva
315318
Args: model_params, model, tx, training
316319
"""
317320
if not params:
318-
params = weights_init_fn(eval_only=eval_only)
321+
if isinstance(model, FlaxCLIPTextModel) or isinstance(model, FlaxCLIPTextPreTrainedModel):
322+
params = weights_init_fn()
323+
else:
324+
params = weights_init_fn(eval_only=eval_only)
319325
if training:
320326
state = train_state.TrainState.create(
321327
apply_fn=model.apply if hasattr(model, "apply") else model.__call__,

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ def create_flax_params_from_pytorch_state(
163163
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
164164
flax_key_list = [*pt_tuple_key]
165165
flax_tensor = pt_tensor
166+
if "lora" in flax_key_list:
167+
flax_key_list[flax_key_list.index("lora")] = f"lora-{adapter_name}"
166168
else:
167169
flax_key_list = [*pt_tuple_key]
168170
if "text_encoder" in pt_tuple_key or "text_encoder_2" in pt_tuple_key:

0 commit comments

Comments (0)