3333
3434STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
3535STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT"
36+ FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
3637
3738
3839def create_orbax_checkpoint_manager (
@@ -56,17 +57,20 @@ def create_orbax_checkpoint_manager(
5657 max_logging .log (f"checkpoint dir: { checkpoint_dir } " )
5758 p = epath .Path (checkpoint_dir )
5859
59- item_names = (
60- "unet_config" ,
61- "vae_config" ,
62- "text_encoder_config" ,
63- "scheduler_config" ,
64- "unet_state" ,
65- "vae_state" ,
66- "text_encoder_state" ,
67- "tokenizer_config" ,
68- )
69- if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT :
60+ if checkpoint_type == FLUX_CHECKPOINT :
61+ item_names = ("flux_state" , "flux_config" , "vae_state" , "vae_config" , "scheduler" , "scheduler_config" )
62+ else :
63+ item_names = (
64+ "unet_config" ,
65+ "vae_config" ,
66+ "text_encoder_config" ,
67+ "scheduler_config" ,
68+ "unet_state" ,
69+ "vae_state" ,
70+ "text_encoder_state" ,
71+ "tokenizer_config" ,
72+ )
73+ if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT :
7074 item_names += (
7175 "text_encoder_2_state" ,
7276 "text_encoder_2_config" ,
@@ -117,7 +121,7 @@ def load_stable_diffusion_configs(
117121 "tokenizer_config" : orbax .checkpoint .args .JsonRestore (),
118122 }
119123
120- if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT :
124+ if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT :
121125 restore_args ["text_encoder_2_config" ] = orbax .checkpoint .args .JsonRestore ()
122126
123127 return (checkpoint_manager .restore (step , args = orbax .checkpoint .args .Composite (** restore_args )), None )
@@ -139,6 +143,8 @@ def load_params_from_path(
139143
140144 ckpt_path = os .path .join (config .checkpoint_dir , str (step ), checkpoint_item )
141145 ckpt_path = epath .Path (ckpt_path )
146+ if not ckpt_path .as_uri ().startswith ("gs://" ):
147+ ckpt_path = os .path .abspath (ckpt_path )
142148
143149 restore_args = ocp .checkpoint_utils .construct_restore_args (unboxed_abstract_params )
144150 restored = ckptr .restore (
0 commit comments