Skip to content

Commit 303e82a

Browse files
jfacevedo-googleksikiric
authored andcommitted
fix sdxl generate smoke tests.
1 parent 18250c5 commit 303e82a

2 files changed

Lines changed: 5 additions & 5 deletions

File tree

src/maxdiffusion/generate_flux.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
3333
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
34-
from max_utils import (
34+
from maxdiffusion.max_utils import (
3535
device_put_replicated,
3636
get_memory_allocations,
3737
create_device_mesh,
@@ -52,9 +52,6 @@ def unpack(x: Array, height: int, width: int) -> Array:
5252
)
5353

5454

55-
from einops import rearrange
56-
57-
5855
def vae_decode(latents, vae, state, config):
5956
img = unpack(x=latents, height=config.resolution, width=config.resolution)
6057
img = img / vae.config.scaling_factor + vae.config.shift_factor

src/maxdiffusion/max_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@
4646
from flax.linen import partitioning as nn_partitioning
4747
from flax.training import train_state
4848
from jax.experimental import mesh_utils
49-
from transformers import (FlaxCLIPTextModel, FlaxCLIPTextPreTrainedModel)
49+
from transformers import (
50+
FlaxCLIPTextModel,
51+
FlaxCLIPTextPreTrainedModel
52+
)
5053
from flax import struct
5154
from typing import (
5255
Callable,

0 commit comments

Comments
 (0)