跳到主要内容

SFT 监督微调

问题

什么是 SFT?监督微调的流程和关键点是什么?

答案

SFT(Supervised Fine-Tuning) 是使用标注好的 (instruction, response) 数据对,以监督学习方式微调预训练模型。

一、SFT 训练流程

二、数据格式

[
{
"instruction": "将以下英文翻译成中文",
"input": "Hello, how are you?",
"output": "你好,你怎么样?"
},
{
"instruction": "解释什么是递归",
"input": "",
"output": "递归是函数调用自身的编程技巧..."
}
]

对话格式(ChatML)

{
"messages": [
{"role": "system", "content": "你是一个有用的助手"},
{"role": "user", "content": "什么是闭包?"},
{"role": "assistant", "content": "闭包是指函数可以访问其词法作用域外的变量..."}
]
}

三、训练代码示例(使用 Transformers)

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
from datasets import load_dataset

# 1. 加载模型和分词器
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")

# 2. 加载数据集
dataset = load_dataset("json", data_files="sft_data.json")

# 3. 配置训练参数
training_args = TrainingArguments(
output_dir="./sft-output",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=4, # 等效 batch_size = 16
learning_rate=2e-5,
warmup_ratio=0.1,
logging_steps=10,
save_strategy="epoch",
bf16=True, # 混合精度训练
)

# 4. SFT Trainer
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
tokenizer=tokenizer,
max_seq_length=2048,
)

# 5. 开始训练
trainer.train()

四、关键超参数

参数推荐值说明
learning_rate1e-5 ~ 5e-5太大遗忘预训练知识,太小学不到新知识
epochs2-5过多会过拟合
batch_size16-64越大越稳定,受 GPU 内存限制
warmup_ratio0.05-0.1学习率热身比例
max_seq_length根据数据覆盖 95% 样本长度即可
灾难性遗忘

SFT 最大风险是灾难性遗忘——模型在学习新任务时忘记预训练能力。防御方法:

  • 低学习率(1e-5 ~ 2e-5)
  • 少 Epoch(2-3)
  • 混入通用数据(10-20%)

常见面试问题

Q1: SFT 和预训练的区别?

答案

维度预训练SFT
目标学习语言能力学习遵循指令
数据海量无标注文本标注的 instruction-response 对
规模万亿 Token千-百万条数据
成本极高(数百万美元)中等(数百-数千美元)
学习率较高较低(避免遗忘)

Q2: SFT 数据量多少合适?

答案

  • 最小量:1000 条高质量数据可以看到明显效果
  • 推荐量:10K-100K 条
  • 关键判断:数据质量远比数量重要——1000 条精标数据 > 10 万条低质量数据

相关链接