24 | 24 | from ....configuration_utils import ConfigMixin, flax_register_to_config |
25 | 25 | from ...modeling_flax_utils import FlaxModelMixin |
26 | 26 | from ...normalization_flax import AdaLayerNormZeroSingle, AdaLayerNormContinuous, AdaLayerNormZero |
27 | | -from ...attention_flax import FlaxFluxAttention |
| 27 | +from ...attention_flax import FlaxFluxAttention, apply_rope |
28 | 28 | from ...embeddings_flax import (FluxPosEmbed, CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings) |
29 | 29 | from .... import common_types |
30 | 30 | from ....common_types import BlockSizes |
@@ -131,7 +131,7 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None): |
131 | 131 | # since this function returns image_rotary_emb and passes it between layers, |
132 | 132 | # we do not want to modify it |
133 | 133 | image_rotary_emb_reordered = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2) |
134 | | - q, k = self.attn.apply_rope(q, k, image_rotary_emb_reordered) |
| 134 | + q, k = apply_rope(q, k, image_rotary_emb_reordered) |
135 | 135 |
136 | 136 | q = q.transpose(0, 2, 1, 3).reshape(q.shape[0], q.shape[2], -1) |
137 | 137 | k = k.transpose(0, 2, 1, 3).reshape(k.shape[0], k.shape[2], -1) |
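The change promotes `apply_rope` from a method on `FlaxFluxAttention` to a module-level function exported by `attention_flax`, so the single-stream block can rotate its query/key heads without going through `self.attn`. For context, below is a minimal sketch of what such a free function could look like. It assumes the reference Flux layout, where the rotary table packs a 2x2 rotation matrix per (position, frequency) pair, which is exactly the shape the `rearrange` on line 133 produces; the shapes and the packed `(cos, -sin, sin, cos)` ordering in the demo are assumptions for illustration, not the repository's actual code.

```python
import jax.numpy as jnp
from einops import rearrange

def apply_rope(xq, xk, freqs_cis):
    # Sketch: `freqs_cis` is assumed to have trailing axes (..., 2, 2), i.e.
    # one 2x2 rotation matrix per (position, frequency) pair, as produced by
    # rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2).
    # Split the head dimension into 2-vectors so each pair can be rotated.
    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
    # Row i of each rotation matrix dotted with the (x0, x1) pair.
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return (xq_out.reshape(xq.shape).astype(xq.dtype),
            xk_out.reshape(xk.shape).astype(xk.dtype))

# Illustrative call site mirroring lines 133-137 of the diff
# (hypothetical shapes: batch, heads, sequence length, head dim).
B, H, L, D = 1, 4, 16, 64
q = jnp.ones((B, H, L, D))
k = jnp.ones((B, H, L, D))
image_rotary_emb = jnp.ones((L, D // 2, 4))  # assumed packed (cos, -sin, sin, cos)
image_rotary_emb_reordered = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)
q, k = apply_rope(q, k, image_rotary_emb_reordered)
# Fold the heads back into the hidden dimension, as lines 136-137 do:
q = q.transpose(0, 2, 1, 3).reshape(q.shape[0], q.shape[2], -1)  # (B, L, H * D)
k = k.transpose(0, 2, 1, 3).reshape(k.shape[0], k.shape[2], -1)
```

Since the rotation itself is stateless, a free function is a natural home for it; after this change the call site only needs the extra name in the import on line 27.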