Commit c9f493e

Make tokenizer_path to be non-mandatory

1 parent 4ea3f7a

13 files changed: 161 additions & 84 deletions

docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 1 addition & 1 deletion

@@ -221,7 +221,7 @@ To extend conversion support to a new model architecture, you must define its sp
   - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.

 2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion.
-3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py), add the new model key in `HF_IDS`.
+3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`.
 4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in [`src/maxtext/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture.

 Here is an example [PR to add support for gemma3 multi-modal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983)
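Step 3 above (registering the model key) is a one-line dictionary addition. A minimal sketch of the pattern, where the seed entries are copied from the `HF_IDS` table shown later in this diff and `my-new-model` plus its repo id are hypothetical:

```python
# Sketch of the HF_IDS registry pattern used by the checkpoint converter.
# The two seed entries are real; "my-new-model" is hypothetical.
HF_IDS = {
    "gemma2-2b": "google/gemma-2-2b",
    "llama3.1-8b": "meta-llama/Llama-3.1-8B",
}

# Registering a new model architecture adds one key:
HF_IDS["my-new-model"] = "my-org/my-new-model"


def resolve_hf_repo(model_name: str) -> str:
  """Look up the Hugging Face repo id, failing fast with the supported keys."""
  if model_name not in HF_IDS:
    raise ValueError(f"Unsupported model name: {model_name}. Supported models are: {list(HF_IDS.keys())}")
  return HF_IDS[model_name]
```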

docs/tutorials/posttraining/rl.md

Lines changed: 7 additions & 8 deletions

@@ -86,14 +86,17 @@ install_maxtext_tpu_post_train_extra_deps

 ## Setup environment variables

+Follow the instructions [here](https://huggingface.co/docs/huggingface_hub/v0.21.2/guides/cli) to log in to Hugging Face with your access token:
+
+```bash
+huggingface-cli login
+```
+
 Setup following environment variables before running GRPO/GSPO:

 ```bash
 # -- Model configuration --
-export HF_MODEL=<Hugging Face Model> # e.g. 'llama3.1-8b-Instruct'
-export MODEL=<MaxText Model> # e.g. 'llama3.1-8b'
-export TOKENIZER=<Tokenizer> # e.g. 'meta-llama/Llama-3.1-8B-Instruct'
-export HF_TOKEN=<Hugging Face access token>
+export MODEL=<MaxText Model> # e.g. 'llama3.1-8b-Instruct'

 # -- MaxText configuration --
 export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory

@@ -135,11 +138,9 @@ Run the following command for GRPO:
 ```
 python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
   model_name=${MODEL?} \
-  tokenizer_path=${TOKENIZER?} \
   load_parameters_path=${MAXTEXT_CKPT_PATH?} \
   run_name=${RUN_NAME?} \
   base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
-  hf_access_token=${HF_TOKEN?} \
   chips_per_vm=${CHIPS_PER_VM?}
 ```

@@ -159,11 +160,9 @@ Run the following command for GSPO:
 ```
 python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
   model_name=${MODEL?} \
-  tokenizer_path=${TOKENIZER?} \
   load_parameters_path=${MAXTEXT_CKPT_PATH?} \
   run_name=${RUN_NAME?} \
   base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
-  hf_access_token=${HF_TOKEN?} \
   loss_algo=gspo-token \
   chips_per_vm=${CHIPS_PER_VM?}
 ```

docs/tutorials/posttraining/rl_on_multi_host.md

Lines changed: 1 addition & 5 deletions

@@ -69,9 +69,7 @@ actual values.

 ```bash
 # -- Model configuration --
-export HF_MODEL=<Hugging Face Model> # e.g. 'llama3.1-70b-Instruct'
-export MODEL=<MaxText Model> # e.g. 'llama3.1-70b'
-export TOKENIZER=<Tokenizer> # e.g. 'meta-llama/Llama-3.1-70B-Instruct'
+export MODEL=<MaxText Model> # e.g. 'llama3.1-70b-Instruct'
 export HF_TOKEN=<Hugging Face access token>

 # -- MaxText configuration --

@@ -200,7 +198,6 @@ xpk workload create-pathways --workload ${WORKLOAD?} \
 --command "HF_TOKEN=${HF_TOKEN?} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
 python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
   model_name=${MODEL?} \
-  tokenizer_path=${TOKENIZER?} \
   load_parameters_path=${MAXTEXT_CKPT_PATH?} \
   run_name=${WORKLOAD?} \
   base_output_directory=${BASE_OUTPUT_DIRECTORY?} \

@@ -218,7 +215,6 @@ xpk workload create-pathways --workload ${WORKLOAD?} \
 --command "HF_TOKEN=${HF_TOKEN?} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
 python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
   model_name=${MODEL?} \
-  tokenizer_path=${TOKENIZER?} \
   load_parameters_path=${MAXTEXT_CKPT_PATH?} \
   run_name=${WORKLOAD?} \
   base_output_directory=${BASE_OUTPUT_DIRECTORY?} \

docs/tutorials/posttraining/sft.md

Lines changed: 7 additions & 5 deletions

@@ -43,13 +43,17 @@ install_maxtext_tpu_post_train_extra_deps

 ## Setup environment variables

+Follow the instructions [here](https://huggingface.co/docs/huggingface_hub/v0.21.2/guides/cli) to log in to Hugging Face with your access token:
+
+```bash
+huggingface-cli login
+```
+
 Set the following environment variables before running SFT.

 ```sh
 # -- Model configuration --
-export PRE_TRAINED_MODEL=<model name> # e.g., 'llama3.1-8b'
-export PRE_TRAINED_MODEL_TOKENIZER=<tokenizer path> # e.g., 'meta-llama/Llama-3.1-8B-Instruct'
-export HF_TOKEN=<Hugging Face access token>
+export PRE_TRAINED_MODEL=<model name> # e.g., 'llama3.1-8b-Instruct'

 # -- MaxText configuration --
 export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory

@@ -93,8 +97,6 @@ python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_tr
   base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
   model_name=${PRE_TRAINED_MODEL?} \
   load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH?} \
-  hf_access_token=${HF_TOKEN?} \
-  tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER?} \
   per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} \
   steps=${STEPS?} \
   hf_path=${DATASET_NAME?} \

docs/tutorials/posttraining/sft_on_multi_host.md

Lines changed: 2 additions & 3 deletions

@@ -95,7 +95,6 @@ export HF_TOKEN=<Hugging Face Access Token>

 # -- Model Configuration --
 export MODEL_NAME=<Model Name> # e.g., deepseek3-671b
-export TOKENIZER_PATH=<Model Tokenizer> # e.g., deepseek-ai/DeepSeek-V3

 # -- Dataset configuration --
 export DATASET_NAME=<Hugging Face Dataset Name> # e.g., HuggingFaceH4/ultrachat_200k

@@ -143,7 +142,7 @@ xpk workload create \
 --workload=${WORKLOAD_NAME?} \
 --tpu-type=${TPU_TYPE?} \
 --num-slices=${TPU_SLICE?} \
---command "python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}"
+--command "python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}"
 ```

 Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.

@@ -159,7 +158,7 @@ xpk workload create-pathways \
 --workload=${WORKLOAD_NAME?} \
 --tpu-type=${TPU_TYPE?} \
 --num-slices=${TPU_SLICE?} \
---command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True"
+--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True"
 ```

 Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.

src/maxtext/checkpoint_conversion/compare_hf_ckpt.py

Lines changed: 2 additions & 1 deletion

@@ -48,8 +48,9 @@
 from safetensors import safe_open

 from maxtext.configs import pyconfig
-from maxtext.checkpoint_conversion.utils.utils import HF_IDS, print_ram_usage, get_hf_model
+from maxtext.checkpoint_conversion.utils.utils import print_ram_usage, get_hf_model
 from maxtext.utils import max_logging
+from maxtext.utils.globals import HF_IDS


 jax.config.update("jax_platform_name", "cpu")

src/maxtext/checkpoint_conversion/to_huggingface.py

Lines changed: 3 additions & 2 deletions

@@ -19,7 +19,7 @@

 Key Parameters (to be set in the config file or as command-line overrides):
   model_name: (Required) The name of the model to convert (e.g., "gemma2-2b").
-    Must be a key in `maxtext.checkpoint_conversion.utils.utils.HF_IDS`.
+    Must be a key in `maxtext.utils.globals.HF_IDS`.
   load_parameters_path: (Required) Path to the MaxText checkpoint directory
     containing the parameter-only checkpoint.
   base_output_directory: (Optional) The directory where the converted HuggingFace

@@ -79,12 +79,13 @@
     save_model_files,
     load_orbax_checkpoint,
     detect_and_extract_checkpoint,
-    HF_IDS,
     MemoryMonitorTqdm,
     print_peak_memory,
 )
 from maxtext.utils import max_logging
 from maxtext.utils import max_utils
+from maxtext.utils.globals import HF_IDS
+

 flags.DEFINE_bool(
     "override_model_architecture",

src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 63 additions & 14 deletions

@@ -18,7 +18,7 @@

 Key Parameters (to be set in the config file or as command-line overrides):
   model_name: (Required) The name of the model to convert (e.g., "gemma2-2b").
-    Must be a key in `maxtext.checkpoint_conversion.utils.utils.HF_IDS`.
+    Must be a key in `maxtext.utils.globals.HF_IDS`.
   base_output_directory: (Optional) The directory where the converted HuggingFace
     checkpoint will be saved. Can be a local path, a GCS
     path (gs://...), or a HuggingFace Hub repo ID (hf://...).

@@ -30,7 +30,7 @@
     Defaults to False.
   --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights.
     If unspecified, we use the default Hugging Face repository ID
-    (e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `utils/ckpt_conversion/utils`).
+    (e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `maxtext.utils.globals`).
     This is necessary for locally dequantized models like GPT-OSS or DeepSeek.

 Environment Variables:

@@ -74,11 +74,12 @@
 from maxtext.common.common_types import MODEL_MODE_TRAIN
 from maxtext.checkpoint_conversion.standalone_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint
 from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
-from maxtext.checkpoint_conversion.utils.utils import HF_IDS, MemoryMonitorTqdm, apply_hook_fns, get_hf_model, print_peak_memory, print_ram_usage, validate_and_filter_param_map_keys
+from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, get_hf_model, print_peak_memory, print_ram_usage, validate_and_filter_param_map_keys
 from maxtext.inference.inference_utils import str2bool
 from maxtext.layers import quantizations
 from maxtext.models import models
 from maxtext.utils import max_logging, max_utils, maxtext_utils
+from maxtext.utils.globals import HF_IDS
 import numpy as np
 from orbax.checkpoint import type_handlers
 from safetensors import safe_open
@@ -155,7 +156,12 @@ def _initialize_index(self):
 if self.is_local:
   index_path = os.path.join(self.model_id, index_file)
 else:
-  index_path = hf_hub_download(repo_id=self.model_id, filename=index_file, token=self.token, revision=self.revision)
+  index_path = hf_hub_download(
+      repo_id=self.model_id,
+      filename=index_file,
+      token=self.token,
+      revision=self.revision,
+  )
 with open(index_path, "r", encoding="utf-8") as f:
   index_data = json.load(f)
 self.shard_map = index_data["weight_map"]

@@ -185,7 +191,12 @@ def get_tensor(self, key: str) -> np.ndarray:
 else:
   # STEP 1: Download outside the lock.
   # multiple threads can download different shards at the same time.
-  local_path = hf_hub_download(repo_id=self.model_id, filename=shard_name, token=self.token, revision=self.revision)
+  local_path = hf_hub_download(
+      repo_id=self.model_id,
+      filename=shard_name,
+      token=self.token,
+      revision=self.revision,
+  )

   # STEP 2: Lock ONLY the reading into RAM.
   # This prevents multiple threads from simultaneously allocating large chunks of RAM.
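The STEP 1 / STEP 2 comments above describe a general pattern: do the slow network download without holding the lock, and serialize only the RAM-heavy read. A self-contained sketch of that pattern (`fetch_shard`, `download`, and `read_into_ram` are stand-in names, not the repo's API):

```python
import threading

_read_lock = threading.Lock()


def fetch_shard(shard_name, download, read_into_ram):
  """Download concurrently, but let only one thread at a time load into RAM."""
  # STEP 1: download outside the lock, so multiple threads can fetch
  # different shards from the network at the same time.
  local_path = download(shard_name)

  # STEP 2: lock only the read into RAM, so at most one thread
  # allocates a large tensor buffer at any moment.
  with _read_lock:
    return read_into_ram(local_path)
```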
@@ -200,7 +211,13 @@ class LazyTensor:
 and transformation until __array__ is called (e.g., by Orbax during save).
 """

-def __init__(self, load_fn: Callable[[], np.ndarray], shape: tuple, dtype, name: str = "unknown"):
+def __init__(
+    self,
+    load_fn: Callable[[], np.ndarray],
+    shape: tuple,
+    dtype,
+    name: str = "unknown",
+):
   self._load_fn = load_fn
   self.shape = shape
   self.dtype = np.dtype(dtype)
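Per its docstring, `LazyTensor` defers loading until `__array__` is called. A minimal self-contained sketch of that idea (a stand-in `LazyArray`, not the class's full implementation):

```python
import numpy as np


class LazyArray:
  """Carries shape/dtype metadata; materializes data only when NumPy asks."""

  def __init__(self, load_fn, shape, dtype):
    self._load_fn = load_fn  # zero-arg callable returning the real ndarray
    self.shape = shape
    self.dtype = np.dtype(dtype)

  def __array__(self, dtype=None, copy=None):
    # Called by np.asarray (or an Orbax-style consumer during save);
    # only at this point is the underlying data actually loaded.
    data = np.asarray(self._load_fn(), dtype=self.dtype)
    return data if dtype is None else data.astype(dtype)
```

Coercing with `np.asarray(lazy)` triggers the single deferred load; until then only the metadata lives in memory.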
@@ -421,7 +438,13 @@ def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_t
 def _loader(getter, key, shape, hook):
   return apply_hook_fns(getter(key), shape, hook)

-load_fn = partial(_loader, tensor_getter, hf_source_keys_or_key, mt_target_shape_or_shapes, hook_fn)
+load_fn = partial(
+    _loader,
+    tensor_getter,
+    hf_source_keys_or_key,
+    mt_target_shape_or_shapes,
+    hook_fn,
+)
 # Stacked mapping
 elif not isinstance(hf_source_keys_or_key[0], list):
   # Case 2 or 3: Single-Axis Stacked hf keys (un-nested list)

@@ -516,7 +539,12 @@ def _get_maxtext_weight(
 # to load the tensor later (the `load_fn`, shape, dtype).
 # The actual data will only be loaded when Orbax calls `__array__`
 # on this object during the saving process.
-final_mt_tensor_numpy = LazyTensor(load_fn, mt_target_shape_or_shapes, config.weight_dtype, name=mt_param_key_or_keys)
+final_mt_tensor_numpy = LazyTensor(
+    load_fn,
+    mt_target_shape_or_shapes,
+    config.weight_dtype,
+    name=mt_param_key_or_keys,
+)
 if not is_composite_mt_key:
   # Case 2.1: Lazy mode, `atomic_mt_key`
   final_mt_weights[mt_target_idx_or_indices] = final_mt_tensor_numpy

@@ -562,7 +590,10 @@ def main(

 # check the supported model ids
 if model_name_original not in HF_IDS:
-  raise ValueError(f"Unsupported model name: {model_name_original}. Supported models are: {list(HF_IDS.keys())}")
+  raise ValueError(
+      f"Unsupported model name: {model_name_original}.\
+      Supported models are: {list(HF_IDS.keys())}"
+  )

 model_id = hf_model_path or HF_IDS[model_name_original]

@@ -633,7 +664,11 @@ def _eager_getter(key):
 filtered_map_keys = validate_and_filter_param_map_keys(param_map_mt_to_hf.keys(), maxtext_abstract_dict.keys())

 for mt_param_key_or_keys in MemoryMonitorTqdm(
-    filtered_map_keys, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True
+    filtered_map_keys,
+    desc="Transforming weights",
+    unit="param",
+    leave=True,
+    dynamic_ncols=True,
 ):
   if not lazy_load_tensors:
     max_logging.log(f"maxtext param: {mt_param_key_or_keys}")

@@ -651,7 +686,13 @@ def _eager_getter(key):

 # Step 2: Determine the loading function for hf key
 # based on hf_key form (unscanned, scanned, unscanned with expert stacking, or scanned with expert stacking)
-load_fn = _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config)
+load_fn = _get_hf_loading_function(
+    hf_source_keys_or_key,
+    tensor_getter,
+    hook_fn,
+    mt_target_shape_or_shapes,
+    config,
+)

 # Step 3: Load hf keys and convert to maxtext keys
 # based on tensor load mode (lazy, eager) and MaxText key form (`atomic_mt_key` or `composite_mt_key`)

@@ -710,9 +751,13 @@ def _eager_getter(key):
     default=False,
     help="Whether to use lazy loading of HF tensors.",
 )
-# If not specified, default to maxtext.checkpoint_conversion.utils.utils.HF_IDS[model_name]
+# If not specified, default to maxtext.utils.globals.HF_IDS[model_name]
 parser.add_argument(
-    "--hf_model_path", type=str, required=False, default="", help="local path to hf model, or custom remote hf repo"
+    "--hf_model_path",
+    type=str,
+    required=False,
+    default="",
+    help="local path to hf model, or custom remote hf repo",
 )
 # Determines the logical sharding of the output checkpoint by partitioning
 # weights across virtual XLA devices.

@@ -730,7 +775,11 @@ def _eager_getter(key):
 parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16)

 parser.add_argument(
-    "--revision", type=str, required=False, default=None, help="Specific Hugging Face revision (branch/tag/commit)"
+    "--revision",
+    type=str,
+    required=False,
+    default=None,
+    help="Specific Hugging Face revision (branch/tag/commit)",
 )

 # Parse local arguments
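The reflowed `parser.add_argument` calls in this file can be checked in isolation; a compact, self-contained sketch using the same flags and defaults as the diff (the parser itself is a stand-alone sketch, not the script's full CLI):

```python
import argparse

# Flags and defaults copied from the hunks above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--hf_model_path",
    type=str,
    required=False,
    default="",
    help="local path to hf model, or custom remote hf repo",
)
parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16)
parser.add_argument(
    "--revision",
    type=str,
    required=False,
    default=None,
    help="Specific Hugging Face revision (branch/tag/commit)",
)

args = parser.parse_args([])  # no CLI args: every flag takes its default
```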

src/maxtext/checkpoint_conversion/utils/utils.py

Lines changed: 0 additions & 35 deletions

@@ -56,41 +56,6 @@
 DEFAULT_MAX_SHARD_SIZE = 1024 * 1024 * 1024 * 3  # 3GB default


-# Mapping from MaxText model key to Hugging Face tokenizer identifiers
-HF_IDS = {
-    "gemma2-2b": "google/gemma-2-2b",
-    "gemma2-9b": "google/gemma-2-9b",
-    "gemma2-27b": "google/gemma-2-27b",
-    "gemma3-4b": "google/gemma-3-4b-it",  # hf multi-modal should also support the pure-text
-    "gemma3-12b": "google/gemma-3-12b-it",
-    "gemma3-27b": "google/gemma-3-27b-it",
-    "qwen3-0.6b": "Qwen/Qwen3-0.6B",
-    "qwen3-4b": "Qwen/Qwen3-4B",
-    "qwen3-4b-thinking-2507": "Qwen/Qwen3-4B-Thinking-2507",
-    "qwen3-8b": "Qwen/Qwen3-8B",
-    "qwen3-14b": "Qwen/Qwen3-14B",
-    "qwen3-32b": "Qwen/Qwen3-32B",
-    "llama3.1-8b": "meta-llama/Llama-3.1-8B",
-    "llama3.1-8b-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
-    "llama3.1-70b-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
-    "llama3.1-70b": "meta-llama/Llama-3.1-70B",
-    "llama3.1-405b": "meta-llama/Llama-3.1-405B",
-    "qwen3-30b-a3b": "Qwen/Qwen3-30B-A3B-Thinking-2507",
-    "qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B-Thinking-2507",
-    "qwen3-480b-a35b": "Qwen/Qwen3-Coder-480B-A35B-Instruct",
-    "deepseek3-671b": "deepseek-ai/DeepSeek-V3",
-    "gpt-oss-20b": "openai/gpt-oss-20b",
-    "gpt-oss-120b": "openai/gpt-oss-120b",
-    "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
-    "qwen3-next-80b-a3b": "Qwen/Qwen3-Next-80B-A3B-Instruct",
-    "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
-    "mixtral-8x22b": "mistralai/Mixtral-8x22B-Instruct-v0.1",
-    "olmo3-7b": "allenai/Olmo-3-7B-Instruct",
-    "olmo3-7b-pt": "allenai/Olmo-3-1025-7B",
-    "olmo3-32b": "allenai/Olmo-3-32B-Think",
-}
-
-
 def _get_local_directory(output_dir: str) -> str:
   """Determines the local directory for saving files."""
   if output_dir.startswith("gs://") or output_dir.startswith("hf://"):

src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion

@@ -573,7 +573,7 @@ num_vocab_tiling: 1

 # Tokenizer
 vocab_size: 32_000 # powers of 2 for sharding
-tokenizer_path: "src/maxtext/assets/tokenizers/tokenizer.llama2"
+tokenizer_path: ""
 # tfds pipeline supports tokenizer_type: sentencepiece, huggingface, tiktoken
 # grain pipeline supports tokenizer_type: sentencepiece, huggingface
 # hf pipeline only supports huggingface type, and will ignore tokenizer_type flag
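With `tokenizer_path` now defaulting to an empty string, consumers need a fallback. One plausible resolution order, written here as a hypothetical helper (not code from this commit): use an explicit `tokenizer_path` when given, otherwise fall back to the model's registered `HF_IDS` repo id.

```python
# Hypothetical helper illustrating how an empty tokenizer_path default can be
# non-mandatory. The HF_IDS excerpt is copied from the table removed from
# utils/utils.py in this commit; the resolution logic is an assumption.
HF_IDS = {"llama3.1-8b": "meta-llama/Llama-3.1-8B"}


def resolve_tokenizer_path(tokenizer_path: str, model_name: str) -> str:
  if tokenizer_path:  # an explicit path always wins
    return tokenizer_path
  if model_name in HF_IDS:  # otherwise fall back to the registered repo id
    return HF_IDS[model_name]
  raise ValueError(f"tokenizer_path is empty and {model_name!r} has no HF_IDS entry")
```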
