
[NPU / 910B megatron-swift LoRA fine-tuning] TypeError: Invalid function argument. Expected parameter tensor of type torch.Tensor but got <class 'NoneType'> instead. #9107

@HorizonChaser


Checklist

  • I have searched existing issues, and this is a new bug report.

Bug Description

When fine-tuning Qwen3-4B with LoRA (target_modules = all-linear) via megatron-swift on 16x910B NPUs, training starts normally, but the following error is raised during the backward pass (rank 4 shown as an example):

[rank4]: Traceback (most recent call last):                                                                                                   
[rank4]:   File "/usr/local/python3.10/lib/python3.10/site-packages/swift/cli/_megatron/sft.py", line 7, in <module>                          
[rank4]:     megatron_sft_main()                                                                                                              
[rank4]:   File "/usr/local/python3.10/lib/python3.10/site-packages/swift/megatron/train/sft.py", line 87, in megatron_sft_main               [rank4]:     return MegatronSft(args).main()                                                                                                  [rank4]:   File "/usr/local/python3.10/lib/python3.10/site-packages/swift/llm/base.py", line 49, in main                                      
[rank4]:     result = self.run()                                                                                                              
[rank4]:   File "/usr/local/python3.10/lib/python3.10/site-packages/swift/megatron/train/sft.py", line 77, in run                             [rank4]:     self.trainer.train(train_dataset, val_dataset, data_collator)                                                                    [rank4]:   File "/usr/local/python3.10/lib/python3.10/site-packages/swift/megatron/trainers/base.py", line 1098, in train                     
[rank4]:     pretrain( 
[rank4]:   File "/opt/Megatron-LM/megatron/training/training.py", line 801, in pretrain
[rank4]:     iteration, num_floating_point_operations_so_far = train(
[rank4]:   File "/opt/Megatron-LM/megatron/training/training.py", line 1993, in train
[rank4]:     train_step(forward_step_func,
[rank4]:   File "/usr/local/python3.10/lib/python3.10/site-packages/swift/megatron/trainers/base.py", line 565, in train_step
[rank4]:     return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
[rank4]:   File "/opt/Megatron-LM/megatron/training/training.py", line 1241, in train_step
[rank4]:     losses_reduced = forward_backward_func(
[rank4]:   File "/opt/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 1983, in forward_backward_pipelining_without_interleaving
[rank4]:     config.finalize_model_grads_func(
[rank4]:   File "/opt/Megatron-LM/megatron/core/distributed/finalize_model_grads.py", line 296, in finalize_model_grads
[rank4]:     _allreduce_embedding_grads(model, config)
[rank4]:   File "/opt/Megatron-LM/megatron/core/distributed/finalize_model_grads.py", line 185, in _allreduce_embedding_grads
[rank4]:     _allreduce_word_embedding_grads(model, config)
[rank4]:   File "/opt/Megatron-LM/megatron/core/distributed/finalize_model_grads.py", line 150, in _allreduce_word_embedding_grads
[rank4]:     torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
[rank4]:   File "/opt/Mindspeed-230/mindspeed/core/megatron_basic/requirements_basic.py", line 85, in wrapper
[rank4]:     return fn(*args, **kwargs)
[rank4]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank4]:     return func(*args, **kwargs)
[rank4]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2786, in all_reduce
[rank4]:     _check_single_tensor(tensor, "tensor")
[rank4]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1114, in _check_single_tensor
[rank4]:     raise TypeError(
[rank4]: TypeError: Invalid function argument. Expected parameter `tensor` of type torch.Tensor
[rank4]:              but got <class 'NoneType'> instead.

This looks like an issue with gradient synchronization for the embedding layer, but the embedding layer is not being fine-tuned at all, so I suspect this is a bug.

The loaded model structure is shown under Additional Information below.
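My guess at the mechanism: with pipeline_model_parallel_size=2 and (presumably) tied word embeddings in Qwen3-4B, Megatron all-reduces the shared word-embedding gradient between the first and last pipeline stage. Under LoRA the base embedding weight is frozen, so it never receives a gradient, and all_reduce is handed None. A minimal standalone sketch of just that mechanism, in plain PyTorch without Megatron (the head layer is a hypothetical stand-in for the trainable LoRA adapter):

import torch

emb = torch.nn.Embedding(10, 4)
emb.weight.requires_grad_(False)   # base word embedding frozen, as under LoRA
head = torch.nn.Linear(4, 2)       # hypothetical stand-in for the trainable adapter

loss = head(emb(torch.tensor([1, 2, 3]))).sum()
loss.backward()                    # grads flow into head, not into the frozen embedding

print(emb.weight.grad)             # None -> exactly what all_reduce rejects in the traceback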

How to Reproduce

Component versions (all installed following https://swift.readthedocs.io/zh-cn/v3.12/BestPractices/NPU-support.html#id3):

gpytorch                      1.14
mindspeed                     0.12.1      /opt/Mindspeed-230
ms_swift                      3.12.2
torch                         2.7.1+cpu
torch_npu                     2.7.1
torchaudio                    2.1.0+cpu
torchvision                   0.16.0+cpu

Training script:

#!/bin/bash
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh

export PYTHONPATH=$PYTHONPATH:/opt/Megatron-LM
export MEGATRON_LM_PATH=/opt/Megatron-LM

TP_SIZE=4
PP_SIZE=2
EP_SIZE=1
CP_SIZE=1

MICRO_BATCH_SIZE=1
GLOBAL_BATCH_SIZE=160

MODEL_PATH="models/Qwen3-4B-Instruct-2507"
DATASET_PATH="train_3000.jsonl"

CURRENT_TIME=$(date +"%y%m%d_%H%M")
SAVE_PATH="ckpt/ckpt_lora_npu16_${CURRENT_TIME}"

TB_PATH="tb/lora_newmega_npu16_${CURRENT_TIME}"

export NPROC_PER_NODE=16

megatron sft \
    --model $MODEL_PATH \
    --dataset $DATASET_PATH \
    --save $SAVE_PATH \
    --check_model False \
    --load_safetensors True \
    --save_safetensors True \
    --load_from_cache_file True \
    --finetune True \
    --train_type lora \
    --target_modules all-linear \
    --lora_rank 32 \
    --lora_alpha 64 \
    --lora_dropout 0.05 \
    --no_save_optim True \
    --no_save_rng True \
    --tensor_model_parallel_size $TP_SIZE \
    --pipeline_model_parallel_size $PP_SIZE \
    --expert_model_parallel_size $EP_SIZE \
    --context_parallel_size $CP_SIZE \
    --sequence_parallel True \
    --use_precision_aware_optimizer True \
    --optimizer_cpu_offload True \
    --optimizer_offload_fraction 1 \
    --recompute_granularity full \
    --recompute_method uniform \
    --recompute_num_layers 1 \
    --micro_batch_size $MICRO_BATCH_SIZE \
    --global_batch_size $GLOBAL_BATCH_SIZE \
    --max_length 24208 \
    --packing True \
    --split_dataset_ratio 0.01 \
    --lr 1e-4 \
    --min_lr 1e-6 \
    --lr_warmup_fraction 0.05 \
    --train_iters 1000 \
    --max_epochs 1 \
    --log_interval 1 \
    --attention_backend auto \
    --cross_entropy_loss_fusion True \
    --eval_interval 500 \
    --save_interval 500 \
    --num_workers 8 \
    --dataset_num_proc 1 \
    --tensorboard_dir $TB_PATH \
    --no_gradient_accumulation_fusion True \
    --no_masked_softmax_fusion True \
    --megatron_extra_kwargs '{"use_fused_rmsnorm": false}'
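A possible temporary workaround (untested; the module path and function signature are taken from the traceback, and the patch would have to run on every rank before training starts) is to wrap the failing function so the embedding-grad sync becomes a no-op when the frozen embedding produced no gradient:

import megatron.core.distributed.finalize_model_grads as fmg

_orig_allreduce = fmg._allreduce_word_embedding_grads

def _lora_safe_allreduce(model, config):
    try:
        _orig_allreduce(model, config)
    except TypeError:
        # Frozen base embeddings end up with grad=None, which
        # torch.distributed.all_reduce rejects; with no gradient there is
        # nothing to synchronize, so skipping should be harmless.
        pass

fmg._allreduce_word_embedding_grads = _lora_safe_allreduce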

Additional Information

The model structure after loading:

[INFO:swift] model: PeftModelForCausalLM(                                                                                                      
  (base_model): LoraModel(                                                                                                                     
    (model): GPTModel(                                                                                                                         
      (embedding): LanguageModelEmbedding(                                                                                                     
        (word_embeddings): VocabParallelEmbedding()                                                                                            
        (embedding_dropout): Dropout(p=0.0, inplace=False)                                                                                     
      )                                                                                                                                        
      (rotary_pos_emb): RotaryEmbedding()                                                                                                      
      (decoder): TransformerBlock(                                                                                                             
        (layers): ModuleList(                                                                                                                  
          (0-17): 18 x TransformerLayer(                                                                                                       
            (input_layernorm): IdentityOp()                                                                                                    
            (self_attention): SelfAttention(                                                                                                   
              (core_attention): DotProductAttention(                                                                                           
                (scale_mask_softmax): FusedScaleMaskSoftmax()                                                                                  
                (attention_dropout): Dropout(p=0.0, inplace=False)                                                                             
              )                                                                                                                                
              (linear_proj): RowParallelLinear(in_features=4096, out_features=2560, bias=False, TP=4)                                          
              (linear_qkv): LoraParallelLinear(                                                                                                
                (base_layer): MindSpeedTELayerNormColumnParallelLinear()                                                                       
                (lora_dropout): ModuleDict(                                                                                                    
                  (default): Dropout(p=0.05, inplace=False)                                                                                    
                )                                                                                                                              
                (lora_A): ModuleDict(                                                                                                          
                  (default): Linear(in_features=2560, out_features=32, bias=False)                                                             
                )                                                                                                                              
                (lora_B): ModuleDict(                                                                                                          
                  (default): ColumnParallelLinear(in_features=32, out_features=6144, bias=False, TP=4)                                         
                )                                                                                                                              
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (q_layernorm): RMSNorm()
              (k_layernorm): RMSNorm()
            )
            (pre_cross_attn_layernorm): IdentityOp()
            (cross_attention): IdentityOp()
            (cross_attn_bda): IdentityFuncOp()
            (pre_mlp_layernorm): IdentityOp()
            (mlp): MLP(
              (linear_fc1): LoraParallelLinear(
                (base_layer): MindSpeedTELayerNormColumnParallelLinear()
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): ColumnParallelLinear(in_features=32, out_features=19456, bias=False, TP=4)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (linear_fc2): RowParallelLinear(in_features=9728, out_features=2560, bias=False, TP=4)
            )
          )
        )
      )
    )
  )
)
