@@ -144,10 +144,10 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None):
144144 hidden_states = self .linear2 (attn_mlp )
145145 hidden_states = gate * hidden_states
146146 hidden_states = residual + hidden_states
147- if hidden_states .dtype == jnp .float16 :
147+ if hidden_states .dtype == jnp .float16 or hidden_states . dtype == jnp . bfloat16 :
148148 hidden_states = jnp .clip (hidden_states , - 65504 , 65504 )
149149
150- return hidden_states
150+ return hidden_states , temb , image_rotary_emb
151151
152152
153153class FluxTransformerBlock (nn .Module ):
@@ -294,9 +294,9 @@ def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=
294294
295295 context_ff_output = self .txt_mlp (norm_encoder_hidden_states )
296296 encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
297- if encoder_hidden_states .dtype == jnp .float16 :
297+ if encoder_hidden_states .dtype == jnp .float16 or encoder_hidden_states . dtype == jnp . bfloat16 :
298298 encoder_hidden_states = encoder_hidden_states .clip (- 65504 , 65504 )
299- return hidden_states , encoder_hidden_states
299+ return hidden_states , encoder_hidden_states , temb , image_rotary_emb
300300
301301
302302@flax_register_to_config
@@ -504,7 +504,7 @@ def __call__(
504504 image_rotary_emb = nn .with_logical_constraint (image_rotary_emb , ("activation_batch" , "activation_embed" ))
505505
506506 for double_block in self .double_blocks :
507- hidden_states , encoder_hidden_states = double_block (
507+ hidden_states , encoder_hidden_states , temb , image_rotary_emb = double_block (
508508 hidden_states = hidden_states ,
509509 encoder_hidden_states = encoder_hidden_states ,
510510 temb = temb ,
@@ -513,7 +513,7 @@ def __call__(
513513 hidden_states = jnp .concatenate ([encoder_hidden_states , hidden_states ], axis = 1 )
514514 hidden_states = nn .with_logical_constraint (hidden_states , ("activation_batch" , "activation_length" , "activation_embed" ))
515515 for single_block in self .single_blocks :
516- hidden_states = single_block (
516+ hidden_states , temb , image_rotary_emb = single_block (
517517 hidden_states = hidden_states , temb = temb , image_rotary_emb = image_rotary_emb
518518 )
519519 hidden_states = hidden_states [:, encoder_hidden_states .shape [1 ] :, ...]
0 commit comments