
Commit f33b16c

Merge pull request #3197 from AI-Hypercomputer:igorts/to_maxtext
PiperOrigin-RevId: 874727435
2 parents: d340aa1 + b074bdb

2 files changed: 58 additions & 28 deletions


src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 28 additions & 23 deletions
@@ -28,10 +28,10 @@
     lazy_load: (bool) If True, uses an on-demand loading strategy to minimize RAM
       usage during conversion. Recommended if 2 * model_size (GB) >= system RAM.
       Defaults to False.
-    --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights.
-      If unspecified, we use the default Hugging Face repository ID
+    --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights.
+      If unspecified, we use the default Hugging Face repository ID
       (e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `utils/ckpt_conversion/utils`).
-      This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
+      This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
 
 Environment Variables:
   HF_AUTH_TOKEN: (Required) HuggingFace authentication token, needed to
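
The rule of thumb above is easy to sanity-check with arithmetic. A minimal sketch, assuming illustrative figures that are not part of this commit (a bf16 70B model on a 256 GB host):

    # Rough weight footprint of a 70B-parameter model in bfloat16 (2 bytes/param).
    model_size_gb = 70e9 * 2 / 1e9   # ~140 GB
    system_ram_gb = 256              # hypothetical host RAM

    # The docstring recommends lazy loading when 2 * model_size (GB) >= system RAM.
    needs_lazy_load = 2 * model_size_gb >= system_ram_gb
    print(needs_lazy_load)  # True -> pass --lazy_load_tensors=True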
@@ -51,8 +51,8 @@
 
   To convert a 70B model with minimal RAM usage:
 
-    /usr/bin/time -v python src/MaxText/utils/ckpt_conversion/to_maxtext.py \
-      maxtext/configs/base.yml model_name="meta-llama/Llama-3.1-70B" \
+    /usr/bin/time -v python src/MaxText/checkpoint_conversion/to_maxtext.py \
+      maxtext/configs/base.yml model_name="llama3.1-70b" \
       base_output_directory="gs://my-bucket/maxtext-checkpoints" \
       hf_access_token=$HF_TOKEN hardware=cpu skip_jax_distributed_system=True \
       --lazy_load_tensors=True
@@ -71,16 +71,14 @@
 from huggingface_hub import hf_hub_download, list_repo_files
 import jax
 from MaxText import pyconfig
+from MaxText.common_types import MODEL_MODE_TRAIN
 from maxtext.checkpoint_conversion.standalone_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint
 from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
 from maxtext.checkpoint_conversion.utils.utils import HF_IDS, MemoryMonitorTqdm, apply_hook_fns, get_hf_model, print_peak_memory, print_ram_usage, validate_and_filter_param_map_keys
-from MaxText.common_types import MODEL_MODE_TRAIN
 from maxtext.inference.inference_utils import str2bool
 from maxtext.layers import quantizations
 from maxtext.models import models
-from maxtext.utils import max_logging
-from maxtext.utils import max_utils
-from maxtext.utils import maxtext_utils
+from maxtext.utils import max_logging, max_utils, maxtext_utils
 import numpy as np
 from orbax.checkpoint import type_handlers
 from safetensors import safe_open
@@ -543,7 +541,13 @@ def _slicing_loader(base_loader, slice_idx):
   )
 
 
-def main(args: Sequence[str], test_args: Sequence[str]) -> None:
+def main(
+    args: Sequence[str],
+    hf_model_path: str | None = None,
+    revision: str | None = None,
+    lazy_load_tensors: bool = False,
+    simulated_cpu_devices_count: int = 16,
+) -> None:
   overall_start = time.time()
   # Check if the user is using an Instruct version. If so, use the base model architecture
   for i, arg in enumerate(args):
@@ -560,10 +564,7 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None:
   if model_name_original not in HF_IDS:
     raise ValueError(f"Unsupported model name: {model_name_original}. Supported models are: {list(HF_IDS.keys())}")
 
-  if not test_args.hf_model_path:
-    model_id = HF_IDS[model_name_original]
-  else:
-    model_id = test_args.hf_model_path
+  model_id = hf_model_path or HF_IDS[model_name_original]
 
   # Initialize maxtext config
   config = pyconfig.initialize(args)
@@ -575,17 +576,15 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None:
   output_directory = config.base_output_directory
 
   hf_token = config.hf_access_token
-  revision = test_args.revision
-  use_lazy_load = test_args.lazy_load_tensors
 
-  if use_lazy_load and config.use_multimodal:
+  if lazy_load_tensors and config.use_multimodal:
     raise ValueError("lazy loading of HF tensors is not supported for multimodal models yet.")
 
   hf_state_dict_numpy = None
   hf_loader = None
 
   # Define the appropriate tensor getter based on mode
-  if use_lazy_load:
+  if lazy_load_tensors:
     max_logging.log(f"Lazy loading ENABLED. Initializing LazyHFLoader for: {model_id}...")
     hf_loader = LazyHFLoader(model_id, hf_token, revision=revision)
     hf_config_obj = AutoConfig.from_pretrained(model_id, token=hf_token, revision=revision)
@@ -636,7 +635,7 @@ def _eager_getter(key):
   for mt_param_key_or_keys in MemoryMonitorTqdm(
       filtered_map_keys, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True
   ):
-    if not use_lazy_load:
+    if not lazy_load_tensors:
       max_logging.log(f"maxtext param: {mt_param_key_or_keys}")
 
     hf_source_keys_or_key = param_map_mt_to_hf.get(mt_param_key_or_keys)
@@ -663,7 +662,7 @@ def _eager_getter(key):
         mt_param_key_or_keys,
         final_mt_weights,
         config,
-        use_lazy_load,
+        lazy_load_tensors,
     )
 
   del hf_state_dict_numpy
@@ -676,7 +675,7 @@ def _eager_getter(key):
   del final_mt_weights, abstract_params_treedef
 
   print_ram_usage("Before saving")
-  if use_lazy_load:
+  if lazy_load_tensors:
     max_logging.log("Starting checkpoint save (loading weights just-in-time)...")
   else:
     max_logging.log("Starting checkpoint save...")
@@ -687,7 +686,7 @@ def _eager_getter(key):
   save_weights_to_checkpoint(
       output_directory,
       jax_weights,
-      test_args.simulated_cpu_devices_count,
+      simulated_cpu_devices_count,
       config.checkpoint_storage_use_ocdbt,
       config.checkpoint_storage_use_zarr3,
   )
@@ -743,4 +742,10 @@ def _eager_getter(key):
   # Set jax environment
   jax.config.update("jax_platforms", "cpu")
   os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"
-  main(model_args, local_args)
+  main(
+      args=model_args,
+      hf_model_path=local_args.hf_model_path,
+      revision=local_args.revision,
+      lazy_load_tensors=local_args.lazy_load_tensors,
+      simulated_cpu_devices_count=local_args.simulated_cpu_devices_count,
+  )
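
Together these hunks replace the old argparse namespace (`test_args`) with explicit keyword arguments, so `main` can now be driven programmatically as well as from the CLI. A minimal sketch of such a call, with hypothetical config and bucket paths, mirroring the CLI usage example in the docstring (the exact argv layout is whatever `pyconfig.initialize` accepts):

    from maxtext.checkpoint_conversion import to_maxtext

    to_maxtext.main(
        args=[
            "src/maxtext/configs/base.yml",  # hypothetical config path
            "model_name=llama3.1-70b",
            "base_output_directory=gs://my-bucket/maxtext-checkpoints",  # hypothetical bucket
            "hardware=cpu",
        ],
        hf_model_path=None,  # None -> fall back to HF_IDS[model_name]
        revision=None,
        lazy_load_tensors=True,  # load HF tensors just-in-time to save RAM
        simulated_cpu_devices_count=16,
    )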

src/maxtext/examples/demo_decoding.ipynb

Lines changed: 30 additions & 5 deletions
@@ -92,7 +92,10 @@
     "!uv pip install nest_asyncio\n",
     "\n",
     "# Install the PyTorch library\n",
-    "!uv pip install torch"
+    "!uv pip install torch\n",
+    "\n",
+    "# This is needed for the Hugging Face login to work\n",
+    "!uv pip install ipywidgets"
    ]
   },
   {
@@ -119,6 +122,22 @@
     "## Imports"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "915db240",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from absl import logging as absl_logging\n",
+    "\n",
+    "# This will make `max_logging.log(...)` outputs visible in the Notebook.\n",
+    "# This statement needs to be invoked before we import jax, otherwise it won't work.\n",
+    "absl_logging.set_verbosity(absl_logging.INFO)\n",
+    "\n",
+    "absl_logging.error(\"Test - you should see this message below\")"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -136,18 +155,24 @@
    "import MaxText as mt\n",
    "from MaxText import common_types\n",
    "from MaxText import pyconfig\n",
-   "from MaxText.input_pipeline import _input_pipeline_utils\n",
+   "from maxtext.input_pipeline import input_pipeline_utils\n",
    "from maxtext.checkpoint_conversion import to_maxtext\n",
    "from maxtext.inference import inference_utils\n",
    "from maxtext.utils import maxtext_utils\n",
    "from maxtext.utils import max_logging\n",
+   "from maxtext.utils import model_creation_utils\n",
+   "\n",
+   "try:\n",
+   "  from google.colab import userdata\n",
+   "except ModuleNotFoundError:\n",
+   "  userdata = os.environ\n",
    "\n",
-   "from google.colab import userdata\n",
    "from huggingface_hub import login\n",
    "\n",
    "MAXTEXT_PKG_DIR = os.path.dirname(mt.__file__)\n",
    "MAXTEXT_REPO_ROOT = os.path.dirname(os.path.dirname(MAXTEXT_PKG_DIR))\n",
    "MAXTEXT_ASSETS_ROOT = os.path.join(MAXTEXT_REPO_ROOT, \"src\", \"maxtext\", \"assets\")\n",
+   "MAXTEXT_CONFIGS_DIR = os.path.join(MAXTEXT_REPO_ROOT, \"src/maxtext/configs\")\n",
    "\n",
    "nest_asyncio.apply()"
   ]
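
The try/except import added here makes the notebook runnable outside Colab: `google.colab.userdata` and `os.environ` both expose a `.get(key)` lookup, so downstream cells stay agnostic about where secrets come from. A standalone sketch of the pattern ("HF_TOKEN" is an assumed secret/variable name, not from the commit):

    import os

    try:
        from google.colab import userdata  # available only inside Colab
    except ModuleNotFoundError:
        userdata = os.environ  # duck-typed fallback: also provides .get()

    hf_token = userdata.get("HF_TOKEN")  # assumed secret name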
@@ -291,7 +316,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model = mt.from_config(config)\n",
+    "model = model_creation_utils.from_config(config)\n",
     "mesh = model.mesh\n",
     "init_rng = jax.random.PRNGKey(config.init_weights_seed)\n",
     "state, _ = maxtext_utils.setup_decode_state(model, config, init_rng, mesh, None)\n",
@@ -313,7 +338,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "tokenizer = _input_pipeline_utils.get_tokenizer(\n",
+    "tokenizer = input_pipeline_utils.get_tokenizer(\n",
     "    f\"{MAXTEXT_ASSETS_ROOT}/tokenizers/qwen3-tokenizer\",\n",
     "    \"huggingface\",\n",
     "    add_bos=True,\n",
