diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py
index 96207468e..97a31ebe9 100644
--- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py
+++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py
@@ -24,7 +24,7 @@
 from ....configuration_utils import ConfigMixin, flax_register_to_config
 from ...modeling_flax_utils import FlaxModelMixin
 from ...normalization_flax import AdaLayerNormZeroSingle, AdaLayerNormContinuous, AdaLayerNormZero
-from ...attention_flax import FlaxFluxAttention
+from ...attention_flax import FlaxFluxAttention, apply_rope
 from ...embeddings_flax import (FluxPosEmbed, CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings)
 from .... import common_types
 from ....common_types import BlockSizes
@@ -131,7 +131,7 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None):
     # since this function returns image_rotary_emb and passes it between layers,
     # we do not want to modify it
     image_rotary_emb_reordered = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)
-    q, k = self.attn.apply_rope(q, k, image_rotary_emb_reordered)
+    q, k = apply_rope(q, k, image_rotary_emb_reordered)
     q = q.transpose(0, 2, 1, 3).reshape(q.shape[0], q.shape[2], -1)
     k = k.transpose(0, 2, 1, 3).reshape(k.shape[0], k.shape[2], -1)
 
diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py
index 1337f2329..e5a34a763 100644
--- a/src/maxdiffusion/train_utils.py
+++ b/src/maxdiffusion/train_utils.py
@@ -22,8 +22,7 @@
 
 
 def get_first_step(state):
-  with jax.spmd_mode("allow_all"):
-    return int(state.step)
+  return int(state.step)
 
 
 def load_next_batch(train_iter, example_batch, config):
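
Note (not part of the diff above): the first change swaps the method call `self.attn.apply_rope(...)` for the module-level `apply_rope` imported from `attention_flax`, so the block no longer reaches through an attention instance just to rotate q/k. As a rough sketch of what such a module-level helper typically looks like for Flux-style rotary embeddings in JAX; the real `attention_flax.apply_rope` may differ in signature and shape handling, and the shapes noted here are assumptions:

```python
# Hypothetical sketch only, assuming the common Flux rotary-embedding layout:
# freqs_cis carries a 2x2 rotation per channel pair, broadcastable against q/k.
import jax.numpy as jnp


def apply_rope(xq, xk, freqs_cis):
  # Group the last feature dimension into (even, odd) channel pairs.
  xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
  xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
  # Apply the 2x2 rotation carried by freqs_cis to each channel pair.
  xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
  xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
  # Restore the original layout and dtype of the inputs.
  return xq_out.reshape(xq.shape).astype(xq.dtype), xk_out.reshape(xk.shape).astype(xk.dtype)
```

Because the helper only touches its arguments, making it a free function (rather than a method on the attention module) lets callers such as the single-stream block rotate q/k without holding an attention instance.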