@@ -25,15 +25,13 @@ class Transformer3DModel(nn.Module):
     only_cross_attention: bool = False
     double_self_attention: bool = False
     upcast_attention: bool = False
-    # 'single_scale_shift' or 'single_scale'
-    adaptive_norm: str = "single_scale_shift"
+    adaptive_norm: str = "single_scale_shift"  # 'single_scale_shift' or 'single_scale'
     standardization_norm: str = "layer_norm"  # 'layer_norm' or 'rms_norm'
     norm_elementwise_affine: bool = True
     norm_eps: float = 1e-5
     attention_type: str = "default"
     caption_channels: int = None
-    # if True uses the TPU attention offload ('flash attention')
-    use_tpu_flash_attention: bool = True
+    use_tpu_flash_attention: bool = True  # if True uses the TPU attention offload ('flash attention')
     qk_norm: Optional[str] = None
     positional_embedding_type: str = "rope"
     positional_embedding_theta: Optional[float] = None
@@ -98,7 +96,7 @@ def scale_shift_table_init(key):
         self.transformer_blocks = RepeatableLayer(
             RemattedBasicTransformerBlock,
             num_layers=self.num_layers,
-            module_init_kwargs=dict(
+            module_init_kwargs=dict(  # noqa C408
                 dim=self.inner_dim,
                 num_attention_heads=self.num_attention_heads,
                 attention_head_dim=self.attention_head_dim,
@@ -139,46 +137,30 @@ def scale_shift_table_init(key):
             matmul_precision=self.matmul_precision,
         )

-    def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True):
-
-        # bookkeeping, for convenient changes later
-        latents_shape = (batch_size, num_tokens, features)
-        fractional_cords_shape = (batch_size, 3, num_tokens)
-        prompt_embeds_shape = (batch_size, text_tokens, features)
-        noise_cond_shape = (batch_size, 1)
-        latents_dtype = jnp.bfloat16
-        fractional_coords_dtype = jnp.bfloat16
-        prompt_embeds_dtype = jnp.bfloat16
-        noise_cond_dtype = jnp.bfloat16
-
-        # initialize to random
-        key, split_key = jax.random.split(key)
-        prompt_embeds = jax.random.normal(split_key, shape=prompt_embeds_shape, dtype=latents_dtype)
-        key, split_key = jax.random.split(key)
-        fractional_coords = jax.random.normal(split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype)
-        key, split_key = jax.random.split(key)
-        latents = jax.random.normal(split_key, shape=latents_shape, dtype=prompt_embeds_dtype)
-        key, split_key = jax.random.split(key)
-        noise_cond = jax.random.normal(split_key, shape=noise_cond_shape, dtype=noise_cond_dtype)
-
-        key, split_key = jax.random.split(key)
+    def init_weights(self, in_channels, key, caption_channels, eval_only=True):
+        example_inputs = {}
+        batch_size, num_tokens = 4, 256
+        input_shapes = {
+            "hidden_states": (batch_size, num_tokens, in_channels),
+            "indices_grid": (batch_size, 3, num_tokens),
+            "encoder_hidden_states": (batch_size, 128, caption_channels),
+            "timestep": (batch_size, 256),
+            "segment_ids": (batch_size, 256),
+            "encoder_attention_segment_ids": (batch_size, 128),
+        }
+        for name, shape in input_shapes.items():
+            example_inputs[name] = jnp.ones(
+                shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool
+            )
+
         if eval_only:
             return jax.eval_shape(
                 self.init,
-                rngs={"params": split_key},
-                hidden_states=latents,
-                indices_grid=fractional_coords,
-                encoder_hidden_states=prompt_embeds,
-                timestep=noise_cond,
+                key,
+                **example_inputs,
             )["params"]
         else:
-            return self.init(
-                rngs={"params": split_key},
-                hidden_states=latents,
-                indices_grid=fractional_coords,
-                encoder_hidden_states=prompt_embeds,
-                timestep=noise_cond,
-            )["params"]
+            return self.init(key, **example_inputs)["params"]

     def __call__(
         self,
@@ -271,8 +253,7 @@ def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array:
     @nn.compact
     def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]:
         source_dtype = indices_grid.dtype
-        # We need full precision in the freqs_cis computation.
-        dtype = jnp.float32
+        dtype = jnp.float32  # We need full precision in the freqs_cis computation.
         dim = self.inner_dim
         theta = self.positional_embedding_theta

@@ -294,8 +275,7 @@ def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]:
         indices = indices * jnp.pi / 2

         freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2)
-        # Flatten along axis 2
-        freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1)
+        freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1)  # Flatten along axis 2

         cos_freq = jnp.cos(freqs).repeat(2, axis=-1)
         sin_freq = jnp.sin(freqs).repeat(2, axis=-1)
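The core of this refactor is that init_weights now fabricates its dummy inputs internally and, when eval_only=True, traces self.init through jax.eval_shape, so parameter shapes are recovered without allocating or initializing any real weights. A minimal standalone sketch of that pattern, assuming only Flax and JAX; the TinyBlock module and the free-standing init_weights helper below are illustrative stand-ins, not code from this repository:

import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyBlock(nn.Module):
    # Illustrative stand-in for Transformer3DModel; not from this repository.
    features: int = 32

    @nn.compact
    def __call__(self, hidden_states):
        return nn.Dense(self.features)(hidden_states)


def init_weights(model, key, in_channels, eval_only=True):
    # Same pattern as the diff: build dummy inputs from fixed shapes, then
    # either initialize for real or only trace shapes with jax.eval_shape.
    example_inputs = {"hidden_states": jnp.ones((4, 256, in_channels), dtype=jnp.float32)}
    if eval_only:
        # jax.eval_shape returns jax.ShapeDtypeStruct leaves: the shapes and
        # dtypes of the would-be parameters, with no device memory allocated.
        return jax.eval_shape(model.init, key, **example_inputs)["params"]
    return model.init(key, **example_inputs)["params"]


shapes = init_weights(TinyBlock(), jax.random.PRNGKey(0), in_channels=128)
print(jax.tree_util.tree_map(lambda x: x.shape, shapes))

Because eval_shape never materializes arrays, this also makes eval_only=True cheap enough to call on a host that could not hold the full parameter set, which is presumably why the diff routes the default path through it.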