Skip to content

Commit 544714b

Browse files
committed
Fix transformers compatibility - handle FlaxCLIPTextModel import for v5.0+
1 parent ff537a5 commit 544714b

1 file changed

Lines changed: 13 additions & 1 deletion

File tree

src/maxdiffusion/generate_flux.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change

@@ -29,7 +29,19 @@
 from chex import Array
 from einops import rearrange
 from flax.linen import partitioning as nn_partitioning
-from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)
+
+try:
+  from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)
+except ImportError:
+  # For transformers>=5.0, Flax models have different import paths
+  from transformers import CLIPTokenizer, T5EncoderModel, AutoTokenizer
+
+  try:
+    from transformers.models.clip.modeling_flax_clip import FlaxCLIPTextModel
+    from transformers.models.t5.modeling_flax_t5 import FlaxT5EncoderModel
+  except ImportError:
+    FlaxCLIPTextModel = None
+    FlaxT5EncoderModel = None

 from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
 from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel

(NOTE: reconstructed from a garbled page extraction — the per-line indentation inside the nested try/except blocks is inferred from the commit's structure and the project's 2-space style; verify against the actual diff.)

0 commit comments

Comments
 (0)