From 18cc9d4cd86559af9f7f07b4eb59607731b76617 Mon Sep 17 00:00:00 2001 From: Kunjan patel Date: Mon, 19 May 2025 08:18:43 +0000 Subject: [PATCH] Move apply_rope out of the attention class to a module-level function Signed-off-by: Kunjan patel --- .../models/flux/transformers/transformer_flux_flax.py | 4 ++-- src/maxdiffusion/train_utils.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 96207468e..97a31ebe9 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -24,7 +24,7 @@ from ....configuration_utils import ConfigMixin, flax_register_to_config from ...modeling_flax_utils import FlaxModelMixin from ...normalization_flax import AdaLayerNormZeroSingle, AdaLayerNormContinuous, AdaLayerNormZero -from ...attention_flax import FlaxFluxAttention +from ...attention_flax import FlaxFluxAttention, apply_rope from ...embeddings_flax import (FluxPosEmbed, CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings) from .... 
import common_types from ....common_types import BlockSizes @@ -131,7 +131,7 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None): # since this function returns image_rotary_emb and passes it between layers, # we do not want to modify it image_rotary_emb_reordered = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2) - q, k = self.attn.apply_rope(q, k, image_rotary_emb_reordered) + q, k = apply_rope(q, k, image_rotary_emb_reordered) q = q.transpose(0, 2, 1, 3).reshape(q.shape[0], q.shape[2], -1) k = k.transpose(0, 2, 1, 3).reshape(k.shape[0], k.shape[2], -1) diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index 1337f2329..e5a34a763 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -22,8 +22,7 @@ def get_first_step(state): - with jax.spmd_mode("allow_all"): - return int(state.step) + return int(state.step) def load_next_batch(train_iter, example_batch, config):