@@ -247,9 +247,19 @@ def __init__(
247247
248248 # Initialize video processor
249249 self .video_processor = VideoProcessor (vae_scale_factor = self .vae_spatial_compression_ratio )
250-
251250 self .tokenizer_max_length = getattr (self .tokenizer , "model_max_length" , 1024 )
252251
252+ @staticmethod
253+ def _init_dummy_shape (node ):
254+ if isinstance (node , jax .ShapeDtypeStruct ):
255+ if jax .dtypes .issubdtype (node .dtype , jax .dtypes .prng_key ):
256+ dummy_key = jax .random .key (0 )
257+ if node .shape == ():
258+ return dummy_key
259+ return jax .random .split (dummy_key , node .shape [0 ])
260+ return jnp .zeros (node .shape , dtype = node .dtype )
261+ return node
262+
253263 @classmethod
254264 def load_tokenizer (cls , config : HyperParameters ):
255265 max_logging .log ("Loading Gemma Tokenizer..." )
@@ -286,7 +296,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
286296
287297 p_model_factory = partial (create_model , config = config )
288298 connectors = nnx .eval_shape (p_model_factory , rngs = rngs )
289- graphdef , state = nnx .split (connectors , nnx .Param )
299+ graphdef , state , rest_of_state = nnx .split (connectors , nnx .Param , ...)
300+ rest_of_state = jax .tree_util .tree_map (cls ._init_dummy_shape , rest_of_state )
290301
291302 logical_state_spec = nnx .get_partition_spec (state )
292303 logical_state_sharding = nn .logical_to_mesh_sharding (logical_state_spec , mesh , config .logical_axis_rules )
@@ -307,7 +318,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
307318 state [path ].value = jax .device_put (val )
308319
309320 state = nnx .from_flat_state (state )
310- connectors = nnx .merge (graphdef , state )
321+ connectors = nnx .merge (graphdef , state , rest_of_state )
311322 return connectors
312323
313324 @classmethod
@@ -326,7 +337,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
326337
327338 p_model_factory = partial (create_model , config = config )
328339 vae = nnx .eval_shape (p_model_factory , rngs = rngs )
329- graphdef , state = nnx .split (vae , nnx .Param )
340+ graphdef , state , rest_of_state = nnx .split (vae , nnx .Param , ...)
341+ rest_of_state = jax .tree_util .tree_map (cls ._init_dummy_shape , rest_of_state )
330342
331343 logical_state_spec = nnx .get_partition_spec (state )
332344 logical_state_sharding = nn .logical_to_mesh_sharding (logical_state_spec , mesh , config .logical_axis_rules )
@@ -353,7 +365,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
353365 state [path ].value = jax .device_put (val )
354366
355367 state = nnx .from_flat_state (state )
356- vae = nnx .merge (graphdef , state )
368+ vae = nnx .merge (graphdef , state , rest_of_state )
357369 return vae
358370
359371 @classmethod
@@ -372,7 +384,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
372384
373385 p_model_factory = partial (create_model , config = config )
374386 audio_vae = nnx .eval_shape (p_model_factory , rngs = rngs )
375- graphdef , state = nnx .split (audio_vae , nnx .Param )
387+ graphdef , state , rest_of_state = nnx .split (audio_vae , nnx .Param , ...)
388+ rest_of_state = jax .tree_util .tree_map (cls ._init_dummy_shape , rest_of_state )
376389
377390 logical_state_spec = nnx .get_partition_spec (state )
378391 logical_state_sharding = nn .logical_to_mesh_sharding (logical_state_spec , mesh , config .logical_axis_rules )
@@ -399,7 +412,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
399412 state [path ].value = jax .device_put (val )
400413
401414 state = nnx .from_flat_state (state )
402- audio_vae = nnx .merge (graphdef , state )
415+ audio_vae = nnx .merge (graphdef , state , rest_of_state )
403416 return audio_vae
404417
405418 @classmethod
@@ -439,7 +452,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
439452
440453 p_model_factory = partial (create_model , config = config )
441454 vocoder = nnx .eval_shape (p_model_factory , rngs = rngs )
442- graphdef , state = nnx .split (vocoder , nnx .Param )
455+ graphdef , state , rest_of_state = nnx .split (vocoder , nnx .Param , ...)
456+ rest_of_state = jax .tree_util .tree_map (cls ._init_dummy_shape , rest_of_state )
443457
444458 logical_state_spec = nnx .get_partition_spec (state )
445459 logical_state_sharding = nn .logical_to_mesh_sharding (logical_state_spec , mesh , config .logical_axis_rules )
@@ -460,7 +474,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
460474 state [path ].value = jax .device_put (val )
461475
462476 state = nnx .from_flat_state (state )
463- vocoder = nnx .merge (graphdef , state )
477+ vocoder = nnx .merge (graphdef , state , rest_of_state )
464478 return vocoder
465479
466480 @classmethod
0 commit comments