|
| 1 | +# Copyright 2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +# This logical rule is designed to optimize pipeline parallelism for large-scale jobs. |
| 16 | +# Key changes include removing expert weight sharding on the `q_lora` dimension, which |
| 17 | +# is relatively small (e.g., 512 for DeepSeek), and limiting sharding strategies when |
| 18 | +# EP x FSDP > 512. |
| 19 | +# |
| 20 | +# The `data` axis is preserved for two reasons: first, the pipeline stage acts as a |
| 21 | +# data parallel (DP) domain externally, making the `data` axis a necessary reference; |
| 22 | +# second, it may be required for DCN communication. |
| 23 | +# |
| 24 | +# Finally, the `tensor` axis is used to shard weights when `pipeline_fsdp_ag_once` or |
| 25 | +# `pipeline_fsdp_ag_per_repeat` is enabled, ensuring we have sufficient memory to |
| 26 | +# store prefetched weights. |
| 27 | +mesh_axes: ['data', 'stage', 'fsdp', 'tensor', 'expert'] |
| 28 | +data_sharding: [['data', 'stage', 'fsdp', 'tensor', 'expert']] |
| 29 | +logical_axis_rules: [ |
| 30 | + ['activation_batch', ['data', 'fsdp', 'expert']], |
| 31 | + ['activation_batch_no_exp', ['data', 'fsdp']], |
| 32 | + ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']], |
| 33 | + ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']], |
| 34 | + ['activation_heads', ['tensor']], |
| 35 | + ['activation_kv_heads', ['tensor']], |
| 36 | + ['activation_length', ['expert']], |
| 37 | + ['activation_attn_length', ['expert']], |
| 38 | + ['activation_q_length', ['expert']], |
| 39 | + ['activation_attn_embed', ['tensor']], |
| 40 | + ['activation_embed', ['tensor']], |
| 41 | + ['activation_mlp', ['tensor']], |
| 42 | + ['activation_kv', ['tensor']], |
| 43 | + ['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']], |
| 44 | + ['activation_kv_batch', ['data', 'fsdp', 'expert']], |
| 45 | + ['activation_kv_batch_no_exp', ['data', 'fsdp']], |
| 46 | + ['activation_kv_head_dim', ['tensor']], |
| 47 | + ['activation_vocab', ['tensor']], |
| 48 | + ['activation_stage', 'stage'], |
| 49 | + ['activation_exp', ['expert']], |
| 50 | + ['decode_batch', ['data', 'fsdp', 'expert']], |
| 51 | + ['mlp', ['tensor']], |
| 52 | + ['mlp_no_fsdp', ['tensor']], |
| 53 | + ['vocab', ['tensor']], |
| 54 | + ['heads', ['tensor']], |
| 55 | + ['q_heads', ['tensor']], |
| 56 | + ['kv_heads', ['tensor']], |
| 57 | + ['embed', ['fsdp', 'expert']], |
| 58 | + ['embed_no_exp', ['fsdp']], |
| 59 | + ['q_lora', ['fsdp']], |
| 60 | + ['kv_lora', ['fsdp']], |
| 61 | + ['norm', ['tensor']], |
| 62 | + ['layers', 'stage'], |
| 63 | + ['cache_heads', ['tensor']], |
| 64 | + ['exp', 'expert'], |
| 65 | + ['exp_with_fsdp', 'fsdp'], |
| 66 | + ['paged_kv_heads', ['tensor']], |
| 67 | + ['engram_dim', ['tensor']], |
| 68 | + ] |
0 commit comments