@@ -332,33 +332,33 @@ def __call__(
332332 rngs : nnx .Rngs = None ,
333333 ):
334334 shift_msa , scale_msa , gate_msa , c_shift_msa , c_scale_msa , c_gate_msa = jnp .split (
335- (self .adaln_scale_shift_table + temb ), 6 , axis = 1
335+ (self .adaln_scale_shift_table + temb . astype ( jnp . float32 ) ), 6 , axis = 1
336336 )
337337 hidden_states = jax .lax .with_sharding_constraint (hidden_states , PartitionSpec ("data" , "fsdp" , "tensor" ))
338338 encoder_hidden_states = jax .lax .with_sharding_constraint (encoder_hidden_states , PartitionSpec ("data" , "fsdp" , None ))
339339
340340 # 1. Self-attention
341- norm_hidden_states = (self .norm1 (hidden_states ) * (1 + scale_msa ) + shift_msa ).astype (hidden_states .dtype )
341+ norm_hidden_states = (self .norm1 (hidden_states . astype ( jnp . float32 ) ) * (1 + scale_msa ) + shift_msa ).astype (hidden_states .dtype )
342342 attn_output = self .attn1 (
343343 hidden_states = norm_hidden_states ,
344344 encoder_hidden_states = norm_hidden_states ,
345345 rotary_emb = rotary_emb ,
346346 deterministic = deterministic ,
347347 rngs = rngs ,
348348 )
349- hidden_states = (hidden_states + attn_output * gate_msa ).astype (hidden_states .dtype )
349+ hidden_states = (hidden_states . astype ( jnp . float32 ) + attn_output * gate_msa ).astype (hidden_states .dtype )
350350
351351 # 2. Cross-attention
352- norm_hidden_states = self .norm2 (hidden_states )
352+ norm_hidden_states = self .norm2 (hidden_states . astype ( jnp . float32 )). astype ( hidden_states . dtype )
353353 attn_output = self .attn2 (
354354 hidden_states = norm_hidden_states , encoder_hidden_states = encoder_hidden_states , deterministic = deterministic , rngs = rngs
355355 )
356356 hidden_states = hidden_states + attn_output
357357
358358 # 3. Feed-forward
359- norm_hidden_states = (self .norm3 (hidden_states ) * (1 + c_scale_msa ) + c_shift_msa ).astype (hidden_states .dtype )
359+ norm_hidden_states = (self .norm3 (hidden_states . astype ( jnp . float32 ) ) * (1 + c_scale_msa ) + c_shift_msa ).astype (hidden_states .dtype )
360360 ff_output = self .ffn (norm_hidden_states , deterministic = deterministic , rngs = rngs )
361- hidden_states = (hidden_states + ff_output * c_gate_msa ).astype (hidden_states .dtype )
361+ hidden_states = (hidden_states . astype ( jnp . float32 ) + ff_output . astype ( jnp . float32 ) * c_gate_msa ).astype (hidden_states .dtype )
362362 return hidden_states
363363
364364
@@ -526,7 +526,7 @@ def scan_fn(carry, block):
526526
527527 shift , scale = jnp .split (self .scale_shift_table + jnp .expand_dims (temb , axis = 1 ), 2 , axis = 1 )
528528
529- hidden_states = (self .norm_out (hidden_states ) * (1 + scale ) + shift ).astype (hidden_states .dtype )
529+ hidden_states = (self .norm_out (hidden_states . astype ( jnp . float32 ) ) * (1 + scale ) + shift ).astype (hidden_states .dtype )
530530 hidden_states = self .proj_out (hidden_states )
531531
532532 hidden_states = hidden_states .reshape (
0 commit comments