@@ -170,8 +170,8 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
170170 for path , val in flax .traverse_util .flatten_dict (params ).items ():
171171 if restored_checkpoint :
172172 path = path [:- 1 ]
173- sharding = logical_state_sharding [path ].value
174- state [path ].value = device_put_replicated (val , sharding )
173+ sharding = logical_state_sharding [path ].get_value ()
174+ state [path ].set_value ( device_put_replicated (val , sharding ) )
175175 state = nnx .from_flat_state (state )
176176
177177 transformer = nnx .merge (graphdef , state , rest_of_state )
@@ -351,10 +351,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
351351 for path , val in flax .traverse_util .flatten_dict (params ).items ():
352352 sharding = logical_state_sharding .get (path )
353353 if sharding is not None :
354- sharding = sharding .value
355- state [path ].value = device_put_replicated (val , sharding )
354+ sharding = sharding .get_value ()
355+ state [path ].set_value ( device_put_replicated (val , sharding ) )
356356 else :
357- state [path ].value = jax .device_put (val )
357+ state [path ].set_value ( jax .device_put (val ) )
358358
359359 state = nnx .from_flat_state (state )
360360 connectors = nnx .merge (graphdef , state , rest_of_state )
@@ -393,16 +393,16 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
393393 for path , val in flax .traverse_util .flatten_dict (params ).items ():
394394 sharding = logical_state_sharding .get (path )
395395 if sharding is not None :
396- sharding = sharding .value
396+ sharding = sharding .get_value ()
397397 try :
398398 replicate_vae = config .replicate_vae
399399 except ValueError :
400400 replicate_vae = False
401401 if replicate_vae :
402402 sharding = NamedSharding (mesh , P ())
403- state [path ].value = device_put_replicated (val , sharding )
403+ state [path ].set_value ( device_put_replicated (val , sharding ) )
404404 else :
405- state [path ].value = jax .device_put (val )
405+ state [path ].set_value ( jax .device_put (val ) )
406406
407407 state = nnx .from_flat_state (state )
408408 vae = nnx .merge (graphdef , state , rest_of_state )
@@ -441,16 +441,16 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
441441 for path , val in flax .traverse_util .flatten_dict (params ).items ():
442442 sharding = logical_state_sharding .get (path )
443443 if sharding is not None :
444- sharding = sharding .value
444+ sharding = sharding .get_value ()
445445 try :
446446 replicate_vae = config .replicate_vae
447447 except ValueError :
448448 replicate_vae = False
449449 if replicate_vae :
450450 sharding = NamedSharding (mesh , P ())
451- state [path ].value = device_put_replicated (val , sharding )
451+ state [path ].set_value ( device_put_replicated (val , sharding ) )
452452 else :
453- state [path ].value = jax .device_put (val )
453+ state [path ].set_value ( jax .device_put (val ) )
454454
455455 state = nnx .from_flat_state (state )
456456 audio_vae = nnx .merge (graphdef , state , rest_of_state )
@@ -510,10 +510,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
510510 for path , val in flax .traverse_util .flatten_dict (params ).items ():
511511 sharding = logical_state_sharding .get (path )
512512 if sharding is not None :
513- sharding = sharding .value
514- state [path ].value = device_put_replicated (val , sharding )
513+ sharding = sharding .get_value ()
514+ state [path ].set_value ( device_put_replicated (val , sharding ) )
515515 else :
516- state [path ].value = jax .device_put (val )
516+ state [path ].set_value ( jax .device_put (val ) )
517517
518518 state = nnx .from_flat_state (state )
519519 vocoder = nnx .merge (graphdef , state , rest_of_state )
0 commit comments