1515"""
1616
1717from typing import Tuple , Optional , Dict , Union , Any
18+ import math
1819import jax
1920import jax .numpy as jnp
2021from flax import nnx
21- from .... import common_types , max_logging
22+ from .... import common_types
2223from ...modeling_flax_utils import FlaxModelMixin , get_activation
2324from ....configuration_utils import ConfigMixin , register_to_config
2425from ...embeddings_flax import (
@@ -447,7 +448,7 @@ def __init__(
447448 rngs = rngs ,
448449 dim = inner_dim ,
449450 ffn_dim = ffn_dim ,
450- num_attention_heads = num_attention_heads ,
451+ num_heads = num_attention_heads ,
451452 qk_norm = qk_norm ,
452453 cross_attn_norm = cross_attn_norm ,
453454 eps = eps ,
@@ -462,6 +463,20 @@ def __init__(
462463 blocks .append (block )
463464 self .blocks = blocks
464465
466+ self .norm_out = FP32LayerNorm (rngs = rngs , dim = inner_dim , eps = eps , elementwise_affine = False )
467+ self .proj_out = nnx .Linear (
468+ rngs = rngs ,
469+ in_features = inner_dim ,
470+ out_features = out_channels * math .prod (patch_size ),
471+ dtype = dtype ,
472+ param_dtype = weights_dtype ,
473+ precision = precision ,
474+ kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), ("embed" , "mlp" ,)),
475+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("mlp" ,)),
476+ )
477+ key = rngs .params ()
478+ self .scale_shift_table = nnx .Param (jax .random .normal (key , (1 , 2 , inner_dim )) / inner_dim ** 0.5 )
479+
465480 def __call__ (
466481 self ,
467482 hidden_states : jax .Array ,
@@ -492,7 +507,14 @@ def __call__(
492507
493508 for block in self .blocks :
494509 hidden_states = block (hidden_states , encoder_hidden_states , timestep_proj , rotary_emb )
495- breakpoint ()
510+
511+ shift , scale = jnp .split (self .scale_shift_table + jnp .expand_dims (temb , axis = 1 ), 2 , axis = 1 )
512+
513+ hidden_states = (self .norm_out (hidden_states .astype (jnp .float32 )) * (1 + scale ) + shift ).astype (hidden_states .dtype )
514+ hidden_states = self .proj_out (hidden_states )
496515
516+ # TODO - can this reshape happen in a single command?
517+ hidden_states = hidden_states .reshape (batch_size , post_patch_num_frames , post_patch_height , post_patch_width , p_t , p_h , p_w , - 1 )
518+ hidden_states = hidden_states .reshape (batch_size , num_frames , height , width , num_channels )
497519
498520 return hidden_states
0 commit comments