diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 958cff916..6638e0f8f 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -359,12 +359,11 @@ def get_abstract_state(model, tx, config, mesh, weights_init_fn, training=True): state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, config.logical_axis_rules) abstract_sharded_state = jax.jit(init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings).eval_shape() - unboxed_sharded_abstract_state = unbox_logicallypartioned_trainstate(abstract_sharded_state) # Initialization with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) - return unboxed_sharded_abstract_state, state_mesh_annotations, state_mesh_shardings + return abstract_sharded_state, state_mesh_annotations, state_mesh_shardings def setup_initial_state(