Commit 9f33cb4

fix: Gemma3 scan_layers=True checkpoint conversion param mapping
Parent: f70f5c8

3 files changed: 47 additions & 34 deletions


docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -9,7 +9,7 @@ The following models are supported:
 | Model Family             | Sizes                  | HF $\\to$ Orbax (scan) | HF $\\to$ Orbax (unscan) | Orbax (scan) $\\to$ HF | Orbax (unscan) $\\to$ HF |
 | :----------------------- | :--------------------- | :--------------------: | :----------------------: | :--------------------: | :----------------------: |
 | **Gemma2**               | 2B, 9B, 27B            | ✅ | ✅ | ✅ | ✅ |
-| **Gemma3** (Multimodal)  | 4B, 12B, 27B           | - | ✅ | - | ✅ |
+| **Gemma3** (Multimodal)  | 4B, 12B, 27B           | ✅ | ✅ | ✅ | ✅ |
 | **Llama3.1**             | 8B, 70B, 405B          | ✅ | ✅ | ✅ | ✅ |
 | **Qwen3**                | 0.6B, 4B, 8B, 14B, 32B | ✅ | ✅ | ✅ | ✅ |
 | **Qwen3 MoE**            | 30B, 235B, 480B        | ✅ | ✅ | ✅ | ✅ |
```

src/maxtext/checkpoint_conversion/utils/param_mapping.py

Lines changed: 37 additions & 23 deletions
```diff
@@ -113,16 +113,10 @@ def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False
   ]
 
   # Vision layers mapping
-  if scan_layers:
-    for i in range(Nvis):
-      for mx, hf in vision_params:
-        key = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock-{mx}"
-        mapping[key] = f"model.vision_tower.vision_model.encoder.layers.{i}.{hf}"
-  else:
-    for i in range(Nvis):
-      for mx, hf in vision_params:
-        key = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-{mx}"
-        mapping[key] = f"model.vision_tower.vision_model.encoder.layers.{i}.{hf}"
+  for i in range(Nvis):
+    for mx, hf in vision_params:
+      key = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-{mx}"
+      mapping[key] = f"model.vision_tower.vision_model.encoder.layers.{i}.{hf}"
 
   # Text decoder mapping
   text_params = [
@@ -142,9 +136,26 @@ def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False
   ]
 
   if scan_layers:
-    for mx, hf in text_params:
-      key = f"params-decoder-layers-{mx}"
-      mapping[key] = [f"model.language_model.layers.{i}.{hf}" for i in range(Ndec)]
+    # Gemma3 repeats a 6-layer attention pattern (5 local + 1 global),
+    # scanned as layers_0..layers_5 with leftovers in layers_remainder.
+    attention_pattern_length = 6
+    num_remaining = Ndec % attention_pattern_length
+    num_scanned = Ndec - num_remaining
+
+    # Main scanned blocks: params-decoder-layers-layers_{block_idx}-{param}
+    for block_idx in range(attention_pattern_length):
+      hf_indices = list(range(block_idx, num_scanned, attention_pattern_length))
+      for mx, hf in text_params:
+        key = f"params-decoder-layers-layers_{block_idx}-{mx}"
+        mapping[key] = [f"model.language_model.layers.{i}.{hf}" for i in hf_indices]
+
+    # Remainder layers (unscanned): params-decoder-layers_remainder-layers_{rem_idx}-{param}
+    if num_remaining > 0:
+      for rem_idx in range(num_remaining):
+        hf_layer_idx = num_scanned + rem_idx
+        for mx, hf in text_params:
+          key = f"params-decoder-layers_remainder-layers_{rem_idx}-{mx}"
+          mapping[key] = f"model.language_model.layers.{hf_layer_idx}.{hf}"
   else:
     for i in range(Ndec):
       for mx, hf in text_params:
```
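To make the new interleaving concrete, here is a minimal, self-contained sketch of the index arithmetic the scanned branch performs. The helper name and the 34-layer depth (Gemma3-4B's text decoder) are illustrative assumptions, not part of the commit:

```python
# Standalone sketch of the scanned-block index arithmetic above.
# `scanned_block_indices` is a hypothetical helper, not defined in
# param_mapping.py; 34 is the Gemma3-4B text-decoder depth, used
# purely for illustration.
ATTENTION_PATTERN_LENGTH = 6  # 5 local-attention layers + 1 global-attention layer


def scanned_block_indices(num_layers, pattern=ATTENTION_PATTERN_LENGTH):
  """Map each scanned block index to the HF layer indices stacked along its scan axis."""
  num_remaining = num_layers % pattern
  num_scanned = num_layers - num_remaining
  blocks = {b: list(range(b, num_scanned, pattern)) for b in range(pattern)}
  remainder = list(range(num_scanned, num_layers))  # stored unscanned under layers_remainder
  return blocks, remainder


blocks, remainder = scanned_block_indices(34)
assert blocks[0] == [0, 6, 12, 18, 24]   # every 6th HF layer stacks into layers_0
assert blocks[5] == [5, 11, 17, 23, 29]
assert remainder == [30, 31, 32, 33]     # 34 % 6 == 4 leftover layers
```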
```diff
@@ -262,9 +273,17 @@ def pos_embed(x, target_shape):
   # Text layers
   tc = config.get("text_config", {})
   nlayers = tc.get("num_hidden_layers", 0)
-  layer_ids = [None] if scan_layers else list(range(nlayers))
-  for i in layer_ids:
-    pref = f"params-decoder-layers_{i}-" if i is not None else "params-decoder-layers-"
+  if scan_layers:
+    attention_pattern_length = 6
+    num_remaining = nlayers % attention_pattern_length
+    # Scanned sub-layer prefixes
+    prefixes = [f"params-decoder-layers-layers_{block_idx}-" for block_idx in range(attention_pattern_length)]
+    # Remainder sub-layer prefixes
+    if num_remaining > 0:
+      prefixes += [f"params-decoder-layers_remainder-layers_{rem_idx}-" for rem_idx in range(num_remaining)]
+  else:
+    prefixes = [f"params-decoder-layers_{i}-" for i in range(nlayers)]
+  for pref in prefixes:
     # Attention Q/K/V/O
     hooks[pref + "self_attention-query-kernel"] = reshape_kernel
     hooks[pref + "self_attention-key-kernel"] = reshape_kernel
@@ -288,13 +307,8 @@
   # Vision layers
   vc = config.get("vision_config", {})
   nvis = vc.get("num_hidden_layers", 0)
-  vision_layer_ids = [None] if scan_layers else list(range(nvis))
-  for i in vision_layer_ids:
-    base = (
-        f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-"
-        if i is not None
-        else "params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock-"
-    )
+  for i in range(nvis):
+    base = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-"
     # Attention kernels & biases
     for qkv in ["query", "key", "value"]:
       hooks[base + f"MultiHeadDotProductAttention_0-{qkv}-kernel"] = reshape_kernel
```

tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh

Lines changed: 9 additions & 10 deletions
```diff
@@ -40,16 +40,15 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MA
 
 export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0/items
 
-# # To get scanned ckpt, flip the scan_layers.
-# ToDo: gemma3 multimodal scanned ckpt
-# python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
-#   model_name=${MODEL_NAME} \
-#   hf_access_token=${HF_TOKEN} \
-#   base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/scanned/${idx} \
-#   use_multimodal=${USE_MULTIMODAL} \
-#   scan_layers=true
-
-# export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/scanned/${idx}/0/items
+# To get a scanned ckpt, flip scan_layers.
+python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
+  model_name=${MODEL_NAME} \
+  hf_access_token=${HF_TOKEN} \
+  base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/scanned/${idx} \
+  use_multimodal=${USE_MULTIMODAL} \
+  scan_layers=true
+
+export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/scanned/${idx}/0/items
 
 # We also test whether the forward pass logits match the original HF model
 # to get higher precision (e.g. float32) run on CPU with `JAX_PLATFORMS=cpu`
```
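Once the scanned conversion above runs, the resulting checkpoint can be spot-checked. The sketch below is hedged throughout: the Orbax restore call is standard, but the tree nesting is inferred from the `-`-joined keys in param_mapping.py, and the bucket path is a placeholder for `$SCANNED_CKPT_PATH`.

```python
# Hedged sketch: inspect the converted checkpoint to confirm the scanned layout.
# The nesting is inferred from the "-"-joined keys in param_mapping.py
# (params-decoder-layers-layers_0-... => params/decoder/layers/layers_0/...);
# it is an assumption, not something this commit asserts.
import orbax.checkpoint as ocp

restored = ocp.PyTreeCheckpointer().restore("gs://my-bucket/gemma3-4b/scanned/0/0/items")  # placeholder path

decoder = restored["params"]["decoder"]
print(sorted(decoder["layers"].keys()))            # expect layers_0 .. layers_5
print(sorted(decoder["layers_remainder"].keys()))  # expect layers_0 .. layers_3 for a 34-layer decoder
```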
