@@ -113,16 +113,10 @@ def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False
113113 ]
114114
115115 # Vision layers mapping
116-   if scan_layers:
117-     for i in range(Nvis):
118-       for mx, hf in vision_params:
119-         key = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock-{mx}"
120-         mapping[key] = f"model.vision_tower.vision_model.encoder.layers.{i}.{hf}"
121-   else:
122-     for i in range(Nvis):
123-       for mx, hf in vision_params:
124-         key = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-{mx}"
125-         mapping[key] = f"model.vision_tower.vision_model.encoder.layers.{i}.{hf}"
116+   for i in range(Nvis):
117+     for mx, hf in vision_params:
118+       key = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-{mx}"
119+       mapping[key] = f"model.vision_tower.vision_model.encoder.layers.{i}.{hf}"
126120
127121 # Text decoder mapping
128122 text_params = [
@@ -142,9 +136,26 @@ def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False
142136 ]
143137
144138 if scan_layers :
145-     for mx, hf in text_params:
146-       key = f"params-decoder-layers-{mx}"
147-       mapping[key] = [f"model.language_model.layers.{i}.{hf}" for i in range(Ndec)]
139+     # Gemma3 repeats a 6-layer attention pattern (5 local + 1 global),
140+     # scanned as layers_0..layers_5 with leftovers in layers_remainder.
141+     attention_pattern_length = 6
142+     num_remaining = Ndec % attention_pattern_length
143+     num_scanned = Ndec - num_remaining
144+
145+     # Main scanned blocks: params-decoder-layers-layers_{block_idx}-{param}
146+     for block_idx in range(attention_pattern_length):
147+       hf_indices = list(range(block_idx, num_scanned, attention_pattern_length))
148+       for mx, hf in text_params:
149+         key = f"params-decoder-layers-layers_{block_idx}-{mx}"
150+         mapping[key] = [f"model.language_model.layers.{i}.{hf}" for i in hf_indices]
151+
152+     # Remainder layers (unscanned): params-decoder-layers_remainder-layers_{rem_idx}-{param}
153+     if num_remaining > 0:
154+       for rem_idx in range(num_remaining):
155+         hf_layer_idx = num_scanned + rem_idx
156+         for mx, hf in text_params:
157+           key = f"params-decoder-layers_remainder-layers_{rem_idx}-{mx}"
158+           mapping[key] = f"model.language_model.layers.{hf_layer_idx}.{hf}"
148159 else :
149160 for i in range (Ndec ):
150161 for mx , hf in text_params :
@@ -262,9 +273,17 @@ def pos_embed(x, target_shape):
262273 # Text layers
263274 tc = config .get ("text_config" , {})
264275 nlayers = tc .get ("num_hidden_layers" , 0 )
265-   layer_ids = [None] if scan_layers else list(range(nlayers))
266-   for i in layer_ids:
267-     pref = f"params-decoder-layers_{i}-" if i is not None else "params-decoder-layers-"
276+   if scan_layers:
277+     attention_pattern_length = 6
278+     num_remaining = nlayers % attention_pattern_length
279+     # Scanned sub-layer prefixes
280+     prefixes = [f"params-decoder-layers-layers_{block_idx}-" for block_idx in range(attention_pattern_length)]
281+     # Remainder sub-layer prefixes
282+     if num_remaining > 0:
283+       prefixes += [f"params-decoder-layers_remainder-layers_{rem_idx}-" for rem_idx in range(num_remaining)]
284+   else:
285+     prefixes = [f"params-decoder-layers_{i}-" for i in range(nlayers)]
286+   for pref in prefixes:
268287 # Attention Q/K/V/O
269288 hooks [pref + "self_attention-query-kernel" ] = reshape_kernel
270289 hooks [pref + "self_attention-key-kernel" ] = reshape_kernel
@@ -288,13 +307,8 @@ def pos_embed(x, target_shape):
288307 # Vision layers
289308 vc = config .get ("vision_config" , {})
290309 nvis = vc .get ("num_hidden_layers" , 0 )
291-   vision_layer_ids = [None] if scan_layers else list(range(nvis))
292-   for i in vision_layer_ids:
293-     base = (
294-       f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-"
295-       if i is not None
296-       else "params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock-"
297-     )
310+   for i in range(nvis):
311+     base = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-"
298312 # Attention kernels & biases
299313 for qkv in ["query" , "key" , "value" ]:
300314 hooks [base + f"MultiHeadDotProductAttention_0-{ qkv } -kernel" ] = reshape_kernel
0 commit comments