2828 lazy_load: (bool) If True, uses an on-demand loading strategy to minimize RAM
2929 usage during conversion. Recommended if 2 * model_size (GB) >= system RAM.
3030 Defaults to False.
31- --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights.
32- If unspecified, we use the default Hugging Face repository ID
31+ --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights.
32+ If unspecified, we use the default Hugging Face repository ID
3333 (e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `utils/ckpt_conversion/utils`).
34- This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
34+ This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
3535
3636Environment Variables:
3737 HF_AUTH_TOKEN: (Required) HuggingFace authentication token, needed to
5151
5252 To convert a 70B model with minimal RAM usage:
5353
54- /usr/bin/time -v python src/MaxText/utils/ckpt_conversion /to_maxtext.py \
55- maxtext/configs/base.yml model_name="meta-llama/Llama-3 .1-70B " \
54+ /usr/bin/time -v python src/MaxText/checkpoint_conversion /to_maxtext.py \
55+ maxtext/configs/base.yml model_name="llama3 .1-70b " \
5656 base_output_directory="gs://my-bucket/maxtext-checkpoints" \
5757 hf_access_token=$HF_TOKEN hardware=cpu skip_jax_distributed_system=True \
5858 --lazy_load_tensors=True
7171from huggingface_hub import hf_hub_download , list_repo_files
7272import jax
7373from MaxText import pyconfig
74+ from MaxText .common_types import MODEL_MODE_TRAIN
7475from maxtext .checkpoint_conversion .standalone_scripts .llama_or_mistral_ckpt import save_weights_to_checkpoint
7576from maxtext .checkpoint_conversion .utils .param_mapping import HOOK_FNS , PARAM_MAPPING
7677from maxtext .checkpoint_conversion .utils .utils import HF_IDS , MemoryMonitorTqdm , apply_hook_fns , get_hf_model , print_peak_memory , print_ram_usage , validate_and_filter_param_map_keys
77- from MaxText .common_types import MODEL_MODE_TRAIN
7878from maxtext .inference .inference_utils import str2bool
7979from maxtext .layers import quantizations
8080from maxtext .models import models
81- from maxtext .utils import max_logging
82- from maxtext .utils import max_utils
83- from maxtext .utils import maxtext_utils
81+ from maxtext .utils import max_logging , max_utils , maxtext_utils
8482import numpy as np
8583from orbax .checkpoint import type_handlers
8684from safetensors import safe_open
@@ -543,7 +541,13 @@ def _slicing_loader(base_loader, slice_idx):
543541 )
544542
545543
546- def main (args : Sequence [str ], test_args : Sequence [str ]) -> None :
544+ def main (
545+ args : Sequence [str ],
546+ hf_model_path : str | None = None ,
547+ revision : str | None = None ,
548+ lazy_load_tensors : bool = False ,
549+ simulated_cpu_devices_count : int = 16 ,
550+ ) -> None :
547551 overall_start = time .time ()
548552 # Check if the user is using an Instruct version. If so, use the base model architecture
549553 for i , arg in enumerate (args ):
@@ -560,10 +564,7 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None:
560564 if model_name_original not in HF_IDS :
561565 raise ValueError (f"Unsupported model name: { model_name_original } . Supported models are: { list (HF_IDS .keys ())} " )
562566
563- if not test_args .hf_model_path :
564- model_id = HF_IDS [model_name_original ]
565- else :
566- model_id = test_args .hf_model_path
567+ model_id = hf_model_path or HF_IDS [model_name_original ]
567568
568569 # Initialize maxtext config
569570 config = pyconfig .initialize (args )
@@ -575,17 +576,15 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None:
575576 output_directory = config .base_output_directory
576577
577578 hf_token = config .hf_access_token
578- revision = test_args .revision
579- use_lazy_load = test_args .lazy_load_tensors
580579
581- if use_lazy_load and config .use_multimodal :
580+ if lazy_load_tensors and config .use_multimodal :
582581 raise ValueError ("lazy loading of HF tensors is not supported for multimodal models yet." )
583582
584583 hf_state_dict_numpy = None
585584 hf_loader = None
586585
587586 # Define the appropriate tensor getter based on mode
588- if use_lazy_load :
587+ if lazy_load_tensors :
589588 max_logging .log (f"Lazy loading ENABLED. Initializing LazyHFLoader for: { model_id } ..." )
590589 hf_loader = LazyHFLoader (model_id , hf_token , revision = revision )
591590 hf_config_obj = AutoConfig .from_pretrained (model_id , token = hf_token , revision = revision )
@@ -636,7 +635,7 @@ def _eager_getter(key):
636635 for mt_param_key_or_keys in MemoryMonitorTqdm (
637636 filtered_map_keys , desc = "Transforming weights" , unit = "param" , leave = True , dynamic_ncols = True
638637 ):
639- if not use_lazy_load :
638+ if not lazy_load_tensors :
640639 max_logging .log (f"maxtext param: { mt_param_key_or_keys } " )
641640
642641 hf_source_keys_or_key = param_map_mt_to_hf .get (mt_param_key_or_keys )
@@ -663,7 +662,7 @@ def _eager_getter(key):
663662 mt_param_key_or_keys ,
664663 final_mt_weights ,
665664 config ,
666- use_lazy_load ,
665+ lazy_load_tensors ,
667666 )
668667
669668 del hf_state_dict_numpy
@@ -676,7 +675,7 @@ def _eager_getter(key):
676675 del final_mt_weights , abstract_params_treedef
677676
678677 print_ram_usage ("Before saving" )
679- if use_lazy_load :
678+ if lazy_load_tensors :
680679 max_logging .log ("Starting checkpoint save (loading weights just-in-time)..." )
681680 else :
682681 max_logging .log ("Starting checkpoint save..." )
@@ -687,7 +686,7 @@ def _eager_getter(key):
687686 save_weights_to_checkpoint (
688687 output_directory ,
689688 jax_weights ,
690- test_args . simulated_cpu_devices_count ,
689+ simulated_cpu_devices_count ,
691690 config .checkpoint_storage_use_ocdbt ,
692691 config .checkpoint_storage_use_zarr3 ,
693692 )
@@ -743,4 +742,10 @@ def _eager_getter(key):
743742 # Set jax environment
744743 jax .config .update ("jax_platforms" , "cpu" )
745744 os .environ ["XLA_FLAGS" ] = f"--xla_force_host_platform_device_count={ local_args .simulated_cpu_devices_count } "
746- main (model_args , local_args )
745+ main (
746+ args = model_args ,
747+ hf_model_path = local_args .hf_model_path ,
748+ revision = local_args .revision ,
749+ lazy_load_tensors = local_args .lazy_load_tensors ,
750+ simulated_cpu_devices_count = local_args .simulated_cpu_devices_count ,
751+ )
0 commit comments