@@ -83,18 +83,6 @@ def start_training(self):
8383 # create train states
8484 train_states = {}
8585 state_shardings = {}
86- unet_state , unet_state_mesh_shardings , unet_learning_rate_scheduler = self .create_unet_state (
87- # ambiguous here, but if self.params.get("unet") doesn't exist
88- # Then its 1 of 2 scenarios:
89- # 1. unet state will be loaded directly from orbax
90- # 2. a new unet is being trained from scratch.
91- pipeline = pipeline ,
92- params = params ,
93- checkpoint_item_name = "unet_state" ,
94- is_training = True ,
95- )
96- train_states ["unet_state" ] = unet_state
97- state_shardings ["unet_state_shardings" ] = unet_state_mesh_shardings
9886 vae_state , vae_state_mesh_shardings = self .create_vae_state (
9987 pipeline = pipeline , params = params , checkpoint_item_name = "vae_state" , is_training = False
10088 )
@@ -131,6 +119,19 @@ def start_training(self):
131119 if self .config .dataset_type == "grain" :
132120 data_iterator = self .restore_data_iterator_state (data_iterator )
133121
122+ unet_state , unet_state_mesh_shardings , unet_learning_rate_scheduler = self .create_unet_state (
123+ # ambiguous here, but if self.params.get("unet") doesn't exist
124+ # Then its 1 of 2 scenarios:
125+ # 1. unet state will be loaded directly from orbax
126+ # 2. a new unet is being trained from scratch.
127+ pipeline = pipeline ,
128+ params = params ,
129+ checkpoint_item_name = "unet_state" ,
130+ is_training = True ,
131+ )
132+ train_states ["unet_state" ] = unet_state
133+ state_shardings ["unet_state_shardings" ] = unet_state_mesh_shardings
134+
134135 data_shardings = self .get_data_shardings ()
135136 # Compile train_step
136137 p_train_step = self .compile_train_step (pipeline , params , train_states , state_shardings , data_shardings )
0 commit comments