@@ -286,7 +286,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
286286
287287 p_model_factory = partial (create_model , config = config )
288288 connectors = nnx .eval_shape (p_model_factory , rngs = rngs )
289- graphdef , state , rest_of_state = nnx .split (connectors , nnx .Param , ... )
289+ graphdef , state = nnx .split (connectors , nnx .Param )
290290
291291 logical_state_spec = nnx .get_partition_spec (state )
292292 logical_state_sharding = nn .logical_to_mesh_sharding (logical_state_spec , mesh , config .logical_axis_rules )
@@ -307,7 +307,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
307307 state [path ].value = jax .device_put (val )
308308
309309 state = nnx .from_flat_state (state )
310- connectors = nnx .merge (graphdef , state , rest_of_state )
310+ connectors = nnx .merge (graphdef , state )
311311 return connectors
312312
313313 @classmethod
@@ -326,7 +326,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
326326
327327 p_model_factory = partial (create_model , config = config )
328328 vae = nnx .eval_shape (p_model_factory , rngs = rngs )
329- graphdef , state , rest_of_state = nnx .split (vae , nnx .Param , ... )
329+ graphdef , state = nnx .split (vae , nnx .Param )
330330
331331 logical_state_spec = nnx .get_partition_spec (state )
332332 logical_state_sharding = nn .logical_to_mesh_sharding (logical_state_spec , mesh , config .logical_axis_rules )
@@ -353,7 +353,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
353353 state [path ].value = jax .device_put (val )
354354
355355 state = nnx .from_flat_state (state )
356- vae = nnx .merge (graphdef , state , rest_of_state )
356+ vae = nnx .merge (graphdef , state )
357357 return vae
358358
359359 @classmethod
@@ -372,7 +372,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
372372
373373 p_model_factory = partial (create_model , config = config )
374374 audio_vae = nnx .eval_shape (p_model_factory , rngs = rngs )
375- graphdef , state , rest_of_state = nnx .split (audio_vae , nnx .Param , ... )
375+ graphdef , state = nnx .split (audio_vae , nnx .Param )
376376
377377 logical_state_spec = nnx .get_partition_spec (state )
378378 logical_state_sharding = nn .logical_to_mesh_sharding (logical_state_spec , mesh , config .logical_axis_rules )
@@ -399,7 +399,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
399399 state [path ].value = jax .device_put (val )
400400
401401 state = nnx .from_flat_state (state )
402- audio_vae = nnx .merge (graphdef , state , rest_of_state )
402+ audio_vae = nnx .merge (graphdef , state )
403403 return audio_vae
404404
405405 @classmethod
@@ -439,7 +439,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
439439
440440 p_model_factory = partial (create_model , config = config )
441441 vocoder = nnx .eval_shape (p_model_factory , rngs = rngs )
442- graphdef , state , rest_of_state = nnx .split (vocoder , nnx .Param , ... )
442+ graphdef , state = nnx .split (vocoder , nnx .Param )
443443
444444 logical_state_spec = nnx .get_partition_spec (state )
445445 logical_state_sharding = nn .logical_to_mesh_sharding (logical_state_spec , mesh , config .logical_axis_rules )
@@ -460,7 +460,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
460460 state [path ].value = jax .device_put (val )
461461
462462 state = nnx .from_flat_state (state )
463- vocoder = nnx .merge (graphdef , state , rest_of_state )
463+ vocoder = nnx .merge (graphdef , state )
464464 return vocoder
465465
466466 @classmethod
0 commit comments