Skip to content

Commit ee60bdb

Browse files
authored
Merge branch 'PaddlePaddle:develop' into dev_mtp_mask
2 parents 34495c8 + 94a7d3d commit ee60bdb

25 files changed

Lines changed: 1217 additions & 397 deletions

.github/workflows/check-bypass.yml

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
on:
2+
workflow_call:
3+
inputs:
4+
workflow-name:
5+
required: true
6+
type: string
7+
secrets:
8+
github-token:
9+
required: true
10+
outputs:
11+
can-skip:
12+
description: "Whether the workflow can be skipped."
13+
value: ${{ jobs.check-bypass.outputs.can-skip }}
14+
15+
jobs:
16+
check-bypass:
17+
name: Check bypass
18+
runs-on: ubuntu-latest
19+
permissions:
20+
contents: read
21+
env:
22+
CI_TEAM_MEMBERS: '["swgu98", "risemeup1" , "XieYunshen", "tianlef"]'
23+
outputs:
24+
can-skip: ${{ steps.check-bypass.outputs.can-skip }}
25+
steps:
26+
- id: check-bypass
27+
name: Check Bypass
28+
uses: PFCCLab/ci-bypass@v2
29+
with:
30+
github-token: ${{ secrets.GITHUB_TOKEN }}
31+
non-pull-request-event-strategy: 'never-skipped'
32+
type: 'composite'
33+
composite-rule: |
34+
{
35+
"any": [
36+
{
37+
"type": "labeled",
38+
"label": ["skip-ci: ${{ inputs.workflow-name }}", "skip-ci: all"],
39+
"username": ${{ env.CI_TEAM_MEMBERS }}
40+
},
41+
{
42+
"type": "commented",
43+
"comment-pattern": [".*/skip-ci ${{ inputs.workflow-name }}.*", ".*/skip-ci all.*"],
44+
"username": ${{ env.CI_TEAM_MEMBERS }}
45+
}
46+
]
47+
}

.github/workflows/cherry-pick.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,8 @@ jobs:
3838
PR_BODY: ${{ github.event.pull_request.body }}
3939
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
4040
MERGE_COMMIT_SHA: ${{ github.event.pull_request.merge_commit_sha }}
41-
BOT_USERNAME: ShigureNyako
42-
BOT_EMAIL: shigure_nyako@outlook.com
43-
REPO_NAME: ShigureNyako/PaddleFormers
41+
BOT_USERNAME: risemeup1111
42+
REPO_NAME: risemeup1111/PaddleFormers
4443
run: |
4544
# Function to post comment
4645
post_comment() {

.github/workflows/fleet-model-test.yml

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -27,40 +27,19 @@ defaults:
2727
shell: bash
2828

2929
jobs:
30-
check_skip:
31-
name: Check skip-fleet-models-ci label
32-
runs-on: ubuntu-latest
33-
34-
outputs:
35-
skip: ${{ steps.check_skip.outputs.skip }}
36-
37-
steps:
38-
- name: Check skip-fleet-models-ci label
39-
id: check_skip
40-
shell: bash
41-
run: |
42-
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
43-
labels='${{ toJson(github.event.pull_request.labels.*.name) }}'
44-
echo "PR labels: $labels"
45-
46-
if echo "$labels" | grep -q "skip-fleet-models-ci"; then
47-
echo "skip=true" >> "$GITHUB_OUTPUT"
48-
else
49-
echo "skip=false" >> "$GITHUB_OUTPUT"
50-
fi
51-
else
52-
echo "skip=false" >> "$GITHUB_OUTPUT"
53-
fi
54-
55-
- name: Skip CI but mark success
56-
if: steps.check_skip.outputs.skip == 'true'
57-
run: |
58-
echo "skip-fleet-models-ci label found"
59-
echo "Downstream GPU jobs will be skipped"
30+
check-bypass:
31+
name: Check bypass
32+
if: ${{ inputs.can-skip != 'true' }}
33+
uses: ./.github/workflows/check-bypass.yml
34+
with:
35+
workflow-name: 'fleet-model-test'
36+
secrets:
37+
github-token: ${{ secrets.GITHUB_TOKEN }}
38+
6039

6140
check_documents_type:
62-
needs: check_skip
63-
if: ${{ needs.check_skip.outputs.skip == 'false' }}
41+
needs: check-bypass
42+
if: ${{ needs.check-bypass.outputs.can-skip == 'false' }}
6443
name: check documents type for pull request
6544
runs-on: ubuntu-latest
6645
env:
@@ -103,8 +82,8 @@ jobs:
10382
echo "is_md_only: $(cat $GITHUB_OUTPUT | grep is_md_only || echo '未找到')"
10483
10584
integration-test-H20-single-card:
106-
needs: [check_documents_type, check_skip]
107-
if: ${{ needs.check_documents_type.outputs.is_md_only == 'false' && needs.check_skip.outputs.skip == 'false' }}
85+
needs: [check_documents_type, check-bypass]
86+
if: ${{ needs.check_documents_type.outputs.is_md_only == 'false' && needs.check-bypass.outputs.can-skip == 'false' }}
10887
name: Integration test (H20, single card)
10988
runs-on:
11089
group: Fleet-H-single-card
@@ -266,8 +245,8 @@ jobs:
266245
267246
268247
integration-test-H20-multi-card:
269-
needs: [check_documents_type, check_skip]
270-
if: ${{ needs.check_documents_type.outputs.is_md_only == 'false' && needs.check_skip.outputs.skip == 'false' }}
248+
needs: [check_documents_type, check-bypass]
249+
if: ${{ needs.check_documents_type.outputs.is_md_only == 'false' && needs.check-bypass.outputs.can-skip == 'false' }}
271250
name: Integration test (H20, multi-card)
272251
runs-on:
273252
group: Fleet-H-multi-card
@@ -548,8 +527,8 @@ jobs:
548527
549528

550529
integration-test-a100:
551-
needs: [check_documents_type, check_skip]
552-
if: ${{ needs.check_documents_type.outputs.is_md_only == 'false' && needs.check_skip.outputs.skip == 'false' }}
530+
needs: [check_documents_type, check-bypass]
531+
if: ${{ needs.check_documents_type.outputs.is_md_only == 'false' && needs.check-bypass.outputs.can-skip == 'false' }}
553532
name: Integration test (A100)
554533
runs-on:
555534
group: Distribute
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
name: Remove Skip-CI Labels
2+
3+
on:
4+
pull_request_target:
5+
types: [synchronize]
6+
7+
permissions:
8+
pull-requests: write
9+
10+
jobs:
11+
remove-skip-ci-labels:
12+
name: Remove skip-ci labels on new commits
13+
runs-on: ubuntu-latest
14+
steps:
15+
- name: Get PR labels
16+
id: get-labels
17+
uses: actions/github-script@v8
18+
with:
19+
github-token: ${{ secrets.GITHUB_TOKEN }}
20+
script: |
21+
const { data: labels } = await github.rest.issues.listLabelsOnIssue({
22+
owner: context.repo.owner,
23+
repo: context.repo.repo,
24+
issue_number: context.issue.number
25+
});
26+
27+
const skipCiLabels = labels
28+
.filter(label => label.name.startsWith('skip-ci:'))
29+
.map(label => label.name);
30+
31+
console.log('Found skip-ci labels:', skipCiLabels);
32+
core.setOutput('skip-ci-labels', JSON.stringify(skipCiLabels));
33+
core.setOutput('has-skip-ci-labels', skipCiLabels.length > 0 ? 'true' : 'false');
34+
35+
- name: Remove skip-ci labels
36+
if: steps.get-labels.outputs.has-skip-ci-labels == 'true'
37+
uses: actions/github-script@v8
38+
with:
39+
github-token: ${{ secrets.GITHUB_TOKEN }}
40+
script: |
41+
const skipCiLabels = JSON.parse('${{ steps.get-labels.outputs.skip-ci-labels }}');
42+
43+
for (const label of skipCiLabels) {
44+
console.log(`Removing label: ${label}`);
45+
await github.rest.issues.removeLabel({
46+
owner: context.repo.owner,
47+
repo: context.repo.repo,
48+
issue_number: context.issue.number,
49+
name: label
50+
});
51+
}
52+
53+
console.log(`Successfully removed ${skipCiLabels.length} skip-ci label(s)`);

examples/experiments/ernie_pretrain/ernie/src/callbacks/moe_correction_bias_adjust_callback.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ def __init__(self, lr, use_sp):
3232
self.use_sp = use_sp
3333

3434
def on_optimizer_end(self, args, state, control, **kwargs):
35+
# Skip bias update when freeze_training is enabled
36+
if getattr(args, "freeze_training", False):
37+
logger.warning("freeze_training is enabled! MoE e_score_correction_bias will NOT be updated.")
38+
return
39+
3540
model = kwargs["model"]
3641

3742
usages = {}

examples/experiments/ernie_pretrain/ernie/src/trainers/pretraining_trainer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,18 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
13001300
self._end_save_time = time.time()
13011301

13021302
def create_scheduler(self, num_training_steps):
1303+
# When freeze_training is enabled, use constant scheduler with lr=0
1304+
if getattr(self.args, "freeze_training", False):
1305+
logger.warning(
1306+
"WARNING: freeze_training is enabled! "
1307+
"Learning rate is set to 0 and model parameters will NOT be updated. "
1308+
"This mode is intended for debugging/profiling only, NOT for actual training."
1309+
)
1310+
from paddleformers.trainer.trainer_utils import get_constant_schedule
1311+
1312+
self.lr_scheduler = get_constant_schedule(learning_rate=0.0)
1313+
return self.lr_scheduler
1314+
13031315
if self.args.warmup_steps > 0:
13041316
warmup = self.args.warmup_steps
13051317
else:

paddleformers/cli/train/ernie_pretrain/src/callbacks/moe_correction_bias_adjust_callback.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ def __init__(self, lr, use_sp):
3434
self.use_sp = use_sp
3535

3636
def on_optimizer_end(self, args, state, control, **kwargs):
37+
# Skip bias update when freeze_training is enabled
38+
if getattr(args, "freeze_training", False):
39+
logger.warning("freeze_training is enabled! MoE e_score_correction_bias will NOT be updated.")
40+
return
41+
3742
model = kwargs["model"]
3843

3944
usages = {}

paddleformers/cli/train/ernie_pretrain/src/trainers/pretraining_trainer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,6 +1314,18 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
13141314
self._end_save_time = time.time()
13151315

13161316
def create_scheduler(self, num_training_steps):
1317+
# When freeze_training is enabled, use constant scheduler with lr=0
1318+
if getattr(self.args, "freeze_training", False):
1319+
logger.warning(
1320+
"WARNING: freeze_training is enabled! "
1321+
"Learning rate is set to 0 and model parameters will NOT be updated. "
1322+
"This mode is intended for debugging/profiling only, NOT for actual training."
1323+
)
1324+
from paddleformers.trainer.trainer_utils import get_constant_schedule
1325+
1326+
self.lr_scheduler = get_constant_schedule(learning_rate=0.0)
1327+
return self.lr_scheduler
1328+
13171329
if self.args.warmup_steps > 0:
13181330
warmup = self.args.warmup_steps
13191331
else:

paddleformers/data/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,14 @@
4646
"tolist",
4747
"DataCollatorForLanguageModeling",
4848
],
49-
"dist_dataloader": ["DummyDataset", "IterableDummyDataset", "DistDataLoader", "init_dataloader_comm_group"],
49+
"dist_dataloader": [
50+
"DummyDataset",
51+
"IterableDummyDataset",
52+
"DistDataLoader",
53+
"init_dataloader_comm_group",
54+
"StreamDistDataLoader",
55+
"init_stream_data_group",
56+
],
5057
"collate": ["Dict", "Pad", "Stack", "Tuple"],
5158
"vocab": ["Vocab"],
5259
"tokenizer": ["BaseTokenizer"],

0 commit comments

Comments
 (0)