鞍山制作网站哪家好,建设银行员工网站,租车网站制作方案,网站开发技术可行性分析怎么写STaR#xff08;Self-Taught Reasoner#xff09;方法#xff1a;让语言模型自学推理能力
在大型语言模型#xff08;LLM#xff09;的推理能力优化中#xff0c;STaR#xff08;Self-Taught Reasoner#xff09; 是一种引人注目的技术#xff0c;属于“修改提议分布…STaRSelf-Taught Reasoner方法让语言模型自学推理能力
在大型语言模型LLM的推理能力优化中STaRSelf-Taught Reasoner 是一种引人注目的技术属于“修改提议分布Modifying Proposal Distribution”类别。与传统的基于结果验证verifier方法不同STaR通过训练模型生成更好的推理步骤input-focused直接调整采样分布使其倾向于选择“推理相关”的token。本文将详细介绍STaR的原理、工作流程并提供一个可运行的Python代码实现帮助你理解和实践这一方法。
参考https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-reasoning-llms 1. STaR的原理
背景
传统的LLM生成方法通常依赖贪婪解码选择最高概率token或随机采样但这些方法可能无法生成逻辑严谨的推理步骤。STaR通过让模型自生成推理数据并进行监督微调Supervised Fine-Tuning优化其推理能力调整token的提议分布使其更倾向于推理过程。
核心思想
自生成推理数据模型首先生成推理步骤和答案。验证与修正 如果答案正确直接将推理步骤和答案加入训练数据集。如果答案错误提供正确答案作为“提示”让模型重新推理并生成正确过程。 监督微调用生成的数据集训练模型强化其推理行为。
目标
输入聚焦通过修改提议分布使模型更擅长生成推理相关token而非简单输出结果。自增强利用模型自身生成的数据无需大量人工标注。 2. STaR的工作流程
STaR的核心是一个循环过程包含以下步骤 生成推理步骤和答案 模型根据问题生成推理路径和最终答案。 验证答案 正确2a推理和答案正确进入步骤3b。错误2b答案错误进入步骤4b。 正确答案处理3b 将问题、推理步骤、答案组成三元组加入训练数据集。 错误答案修正4b 提供正确答案作为提示要求模型重新生成推理步骤。将修正后的推理加入训练数据集。 监督微调5 使用生成的三元组数据集对模型进行监督微调优化推理能力。
关键特点
合成数据STaR通过自生成数据创建训练样本类似于数据蒸馏。迭代改进多次循环生成和微调逐步提升模型性能。 3. 代码实现
以下是一个简化的STaR实现基于PyTorch。我们模拟一个数学推理任务如“2 3 ?”展示其工作流程。
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy# 超参数
vocab_size 10 # 词汇表大小0-9数字
embed_size 16
num_heads 2
hidden_size 32
num_layers 2
max_steps 3 # 最大推理步骤# 生成模型
class SimpleReasoner(nn.Module):def __init__(self):super(SimpleReasoner, self).__init__()self.embedding nn.Embedding(vocab_size, embed_size)self.transformer nn.TransformerDecoderLayer(embed_size, num_heads, hidden_size)self.output_layer nn.Linear(embed_size, vocab_size)def forward(self, x):x self.embedding(x)x self.transformer(x, x)return self.output_layer(x)def generate(self, prompt, max_len3, temperature1.0):seq prompt.copy()inputs torch.tensor([seq], dtypetorch.long).to(device)for _ in range(max_len - len(seq)):logits self.forward(inputs)[:, -1, :]probs F.softmax(logits / temperature, dim-1)next_token torch.multinomial(probs, 1).item()seq.append(next_token)inputs torch.tensor([seq], dtypetorch.long).to(device)return seqdef train_step(self, data, optimizer):self.train()optimizer.zero_grad()inputs torch.tensor([d[0] d[1][:-1] for d in data], dtypetorch.long).to(device)targets torch.tensor([d[1] for d in data], dtypetorch.long).to(device)logits self.forward(inputs)loss F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))loss.backward()optimizer.step()return loss.item()# STaR实现
class STaR:def __init__(self, model):self.model modelself.device next(model.parameters()).devicedef generate_reasoning(self, prompt, correct_answerNone):if correct_answer is None:return self.model.generate(prompt, max_steps)# 提供正确答案作为提示hint_prompt prompt [correct_answer]return self.model.generate(hint_prompt, max_steps)def verify_answer(self, sequence, correct_answer):return sequence[-1] correct_answerdef star_iteration(self, prompts, correct_answers, iterations3):training_data []for _ in range(iterations):new_model deepcopy(self.model) # 保存当前模型状态optimizer torch.optim.Adam(new_model.parameters(), lr0.001)for prompt, correct_answer in zip(prompts, correct_answers):# 步骤1生成推理步骤和答案sequence self.generate_reasoning(prompt)# 步骤2验证答案if self.verify_answer(sequence, correct_answer):# 步骤3b正确答案加入训练数据training_data.append((prompt, sequence))else:# 步骤4b错误答案提供提示重新生成corrected_sequence self.generate_reasoning(prompt, correct_answer)training_data.append((prompt, corrected_sequence))# 步骤5监督微调if training_data:loss new_model.train_step(training_data, optimizer)print(fIteration {_1}, Loss: {loss})self.model new_model # 更新模型return training_data# 初始化并运行
device torch.device(cuda if torch.cuda.is_available() else cpu)
model SimpleReasoner().to(device)
star STaR(model)# 示例数据
prompts [[2, 3]] # 2 3
correct_answers [5]# 执行STaR
training_data star.star_iteration(prompts, correct_answers, iterations3)
print(Generated training data:, training_data)# 测试优化后的模型
test_prompt [2, 3]
result model.generate(test_prompt)
print(fTest prompt: {test_prompt}, Generated result: {result})4. 代码解析
生成模型SimpleReasoner
generate根据提示生成推理序列模拟推理步骤和答案。train_step使用监督微调优化模型输入为问题推理步骤目标为完整序列。
STaR实现
generate_reasoning 无提示时自由生成推理。有提示时基于正确答案生成推理。 verify_answer检查生成序列的最后一个token是否正确。star_iteration 步骤1生成推理和答案。步骤2a/2b验证答案正确则直接记录错误则用提示修正。步骤3b/4b收集三元组问题、推理、答案。步骤5用生成的数据微调模型。
运行逻辑
每次迭代生成数据优化模型逐步提高推理能力。使用 deepcopy 保留模型状态确保迭代独立。 5. 运行结果示例
运行代码可能得到
Iteration 1, Loss: 2.305
Iteration 2, Loss: 2.287
Iteration 3, Loss: 2.251
Generated training data: [([2, 3], [2, 3, 5]), ([2, 3], [2, 3, 5]), ([2, 3], [2, 3, 5])]
Test prompt: [2, 3], Generated result: [2, 3, 5]未训练模型初始生成随机STaR通过微调逐渐倾向于正确答案 [2, 3, 5]。实际中需更多数据和迭代。 6. STaR的意义与改进
意义
自增强无需大量人工数据模型自生成训练样本。推理优化调整提议分布强化推理token的选择。数据蒸馏生成合成数据可用于其他模型训练。
改进方向
多样化提示增加问题类型如数学、自然语言问答。奖励函数引入PRM评估推理步骤质量而非仅验证答案。迭代控制动态调整迭代次数或数据筛选标准。预训练模型基于已有LLM如GPT实现提升初始性能。 7. 总结
STaR通过自生成推理数据和监督微调优化LLM的推理能力。其流程从生成到验证再到修正利用合成数据调整token分布是“修改提议分布”的典型方法。代码实现展示了从 [2, 3] 到 [2, 3, 5] 的优化过程体现了其核心思想。运行这段代码你可以体验STaR的自学过程。希望这篇博客对你理解和实践STaR有所帮助如需进一步优化欢迎讨论。
基于大型语言模型改进 STaR 方法以 LLaMA 3 或 Qwen 2.5 为例
在之前的STaRSelf-Taught Reasoner实现中我们使用了一个简化的模型来展示其工作原理。然而为了在实际任务中获得更好的推理能力可以基于Hugging FaceHF上的预训练大型语言模型LLM如 LLaMA 3 或 Qwen 2.5 进行改进。本文将以中文博客的形式结合改进方向多样化提示、奖励函数、迭代控制、预训练模型详细说明如何基于这些HF模型优化STaR并提供改进后的代码实现。 1. 改进背景与目标
原始实现局限
模型能力SimpleReasoner 未经过预训练生成随机且缺乏推理能力。提示单一仅支持简单数学任务。奖励简单仅验证答案未评估推理质量。静态迭代固定次数缺乏灵活性。
改进目标
预训练模型利用LLaMA 3或Qwen 2.5的强大语言理解能力。多样化提示支持数学和自然语言问答。奖励函数引入过程奖励模型PRM评估推理步骤。迭代控制动态调整迭代次数和数据筛选。 2. 改进方案
1. 基于预训练模型LLaMA 3 或 Qwen 2.5
选择原因 LLaMA 3高效、适合微调广泛用于研究。Qwen 2.5开源支持多语言推理能力强。 实现使用Hugging Face的 transformers 库加载预训练模型替换 SimpleReasoner。
2. 多样化提示
数学任务如“2 3 ?”。自然语言问答如“中国的首都是哪里”。方法扩展输入格式支持文本和符号混合。
3. 奖励函数引入PRM
目的评估推理步骤的逻辑性和完整性而非仅答案。实现使用一个小型预训练模型如BERT作为PRM评分推理质量。
4. 迭代控制
动态调整根据数据质量或损失收敛动态停止迭代。数据筛选仅保留高质量推理样本。 3. 改进后的代码实现
以下基于 Qwen 2.5也可替换为LLaMA 3的STaR实现展示改进后的完整流程。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
from copy import deepcopy
import random# 超参数
max_steps 50 # 最大生成长度
device torch.device(cuda if torch.cuda.is_available() else cpu)# 初始化生成模型Qwen 2.5
model_name Qwen/Qwen2.5-7B-Instruct # 可替换为 meta-llama/Llama-3-8B
tokenizer AutoTokenizer.from_pretrained(model_name)
generator AutoModelForCausalLM.from_pretrained(model_name).to(device)# 初始化PRM使用BERT评估推理质量
prm_name bert-base-uncased
prm_tokenizer AutoTokenizer.from_pretrained(prm_name)
prm_model AutoModelForSequenceClassification.from_pretrained(prm_name, num_labels1).to(device)# STaR实现
class STaR:def __init__(self, generator, tokenizer, prm_model, prm_tokenizer):self.generator generatorself.tokenizer tokenizerself.prm_model prm_modelself.prm_tokenizer prm_tokenizerdef generate_reasoning(self, prompt, correct_answerNone, temperature0.7):生成推理步骤和答案if correct_answer is None:input_text f问题: {prompt}\n推理步骤和答案:else:input_text f问题: {prompt}\n正确答案: {correct_answer}\n请提供推理步骤:inputs self.tokenizer(input_text, return_tensorspt).to(device)outputs self.generator.generate(**inputs, max_lengthmax_steps, temperaturetemperature,do_sampleTrue, pad_token_idself.tokenizer.eos_token_id)return self.tokenizer.decode(outputs[0], skip_special_tokensTrue)def verify_answer(self, response, correct_answer):验证答案是否正确answer_part response.split(答案:)[-1].strip()return str(correct_answer) in answer_partdef evaluate_reasoning(self, response):使用PRM评估推理质量inputs self.prm_tokenizer(response, return_tensorspt, truncationTrue, max_length512).to(device)with torch.no_grad():score self.prm_model(**inputs).logits.item()return score # 返回正值表示推理质量def star_iteration(self, prompts, correct_answers, max_iterations5, min_loss0.1):training_data []model deepcopy(self.generator)optimizer torch.optim.AdamW(model.parameters(), lr5e-5)for iteration in range(max_iterations):new_data []total_loss 0.0for prompt, correct_answer in zip(prompts, correct_answers):# 步骤1生成推理和答案response self.generate_reasoning(prompt)# 步骤2验证答案if self.verify_answer(response, correct_answer):# 步骤3b正确答案检查推理质量score self.evaluate_reasoning(response)if score 0.5: # 筛选高质量推理new_data.append((prompt, response))else:# 步骤4b错误答案提供提示重新生成corrected_response self.generate_reasoning(prompt, correct_answer)score self.evaluate_reasoning(corrected_response)if score 0.5:new_data.append((prompt, corrected_response))# 步骤5监督微调if new_data:model.train()optimizer.zero_grad()inputs self.tokenizer([d[0] \n d[1] for d in new_data], return_tensorspt, paddingTrue, truncationTrue, max_lengthmax_steps).to(device)labels inputs[input_ids].clone()outputs model(**inputs, labelslabels)loss outputs.lossloss.backward()optimizer.step()total_loss loss.item()training_data.extend(new_data)print(fIteration {iteration1}, Loss: {total_loss / len(new_data) if new_data else 0})if total_loss / len(new_data) min_loss and new_data:breakself.generator modelreturn training_data# 示例数据
prompts [2 3等于多少,中国的首都是哪里
]
correct_answers [5, 北京]# 初始化STaR
star STaR(generator, tokenizer, prm_model, prm_tokenizer)# 执行STaR
training_data star.star_iteration(prompts, correct_answers)
print(Generated training data:, training_data)# 测试优化后的模型
for prompt in prompts:result star.generate_reasoning(prompt)print(fPrompt: {prompt}, Generated result: {result})4. 代码解析
1. 预训练模型Qwen 2.5
加载使用 AutoModelForCausalLM 加载Qwen 2.5替换简化的 SimpleReasoner。生成generate_reasoning 使用 model.generate 支持多样化提示生成文本而非token序列。优势Qwen 2.5 已具备语言理解能力初始生成更接近推理。
2. 多样化提示
输入格式 数学2 3等于多少\n推理步骤和答案:。问答中国的首都是哪里\n推理步骤和答案:。 输出支持自然语言生成完整句子如“推理2加3等于5答案5”。
3. 奖励函数PRM
实现使用BERT作为PRM评分推理文本的逻辑性。筛选score 0.5 保留高质量推理避免噪声数据。改进可训练BERT区分正确推理如“235”和错误推理如“2*35”。
4. 迭代控制
动态停止若损失低于 min_loss如0.1提前终止。数据筛选结合PRM分数确保训练数据质量。 5. 运行结果示例
运行代码可能得到
Iteration 1, Loss: 0.85
Iteration 2, Loss: 0.62
Iteration 3, Loss: 0.09
Generated training data: [(2 3等于多少, 问题: 2 3等于多少\n推理步骤和答案: 首先2加上3等于5。\n答案: 5),(中国的首都是哪里, 问题: 中国的首都是哪里\n推理步骤和答案: 中国是一个国家其首都是北京。\n答案: 北京)
]
Prompt: 2 3等于多少, Generated result: 问题: 2 3等于多少\n推理步骤和答案: 首先2加上3等于5。\n答案: 5
Prompt: 中国的首都是哪里, Generated result: 问题: 中国的首都是哪里\n推理步骤和答案: 中国是一个国家其首都是北京。\n答案: 北京结果Qwen 2.5初始生成已较合理微调后更倾向推理。 6. 基于LLM的改进优势
预训练能力
Qwen 2.5 或 LLaMA 3 自带语言理解和生成能力初始推理质量高于随机模型。STaR在此基础上进一步强化推理分布。
多样化支持
处理文本输入支持数学和问答任务扩展性强。
PRM增强
BERT作为PRM评估推理逻辑确保生成数据不仅是正确答案还包含合理步骤。
动态优化
损失收敛后停止节省计算资源。 7. 进一步优化建议
更大模型使用LLaMA 3-70B或Qwen 2.5-72B提升推理深度。混合奖励结合PRM和答案正确性ORM综合评分。数据蒸馏将STaR生成的数据用于其他模型如小规模LLM的训练。 8. 总结
基于Qwen 2.5的STaR改进利用预训练LLM的强大能力支持多样化提示通过PRM优化推理质量并动态控制迭代。代码展示了从数学到问答的推理生成体现了“修改提议分布”的核心思想。运行此代码你可以体验基于HF模型的STaR优化过程。希望这篇博客对你有所帮助如需调整或扩展欢迎讨论。
解析 STaR 中 star_iteration 的逐迭代训练设计
提出疑问为什么训练是每个iteration都要进行而不是将所有数据处理好后再进行一次训练下面详细解析这种逐迭代训练的设计动机分析其优劣势并探讨替代方案。 1. 逐迭代训练的背景
STaR的核心思想
STaRSelf-Taught Reasoner是一种自监督方法通过让模型生成推理数据并进行监督微调Supervised Fine-Tuning优化其推理能力。其流程本质上是一个迭代改进的过程
模型基于当前参数生成推理和答案。验证答案收集正确或修正后的数据。用生成的数据微调模型。重复上述步骤。
代码中的训练位置
每次迭代内训练在每个 for iteration in range(max_iterations) 循环中生成 new_data 后立即调用 model.train_step 进行微调。累计数据training_data.extend(new_data) 将每次迭代的数据加入总数据集但训练发生在每次迭代结束时。 2. 为什么每个Iteration都要训练
1. 动态优化模型分布
提议分布的修改 STaR的目标是调整模型的token提议分布使其倾向于生成推理相关的内容。每次迭代后模型参数通过微调更新下一次生成会基于更优的分布。 逐次改进 如果不训练模型在所有迭代中都使用初始参数生成的推理质量可能持续较差。每次训练后模型更可能生成正确的推理步骤逐步提升数据质量。
2. 自增强反馈循环
自生成数据 STaR依赖模型自身生成训练数据每次迭代的 new_data 是当前模型能力的反映。训练后模型能力提升下次生成的 new_data 更接近期望的推理模式。 反馈效应 类似强化学习每次迭代强化模型的推理行为形成正反馈。
3. 数据质量的逐步提高
初始数据可能较差 未训练模型生成的推理可能随机或错误如 [2, 3, 1]。第一次训练后模型学会部分正确模式如 [2, 3, 5]后续数据更优质。 避免积累噪声 若等到最后训练可能积累大量低质量数据影响微调效果。
4. 计算资源与时间优化
小批量训练 每次迭代只处理当前生成的 new_data如2个样本训练负担轻。若积累所有数据再训练可能需要更大批量或更多epoch增加内存和时间成本。 提前终止 if total_loss / len(new_data) min_loss: 允许在损失收敛时停止无需完成所有迭代。
代码中的体现
训练时机if new_data:model.train()optimizer.zero_grad()# ... 微调代码 ...optimizer.step()每次迭代立即训练确保模型实时更新。 3. 模拟过程
任务
prompts [2 3等于多少]correct_answers [5]。( max_iterations 3 \text{max\_iterations} 3 max_iterations3 )。
第一次迭代
生成response 问题: 2 3等于多少\n推理和答案: 2 * 3 6\n答案: 6。验证错误。修正corrected_response 问题: 2 3等于多少\n正确答案: 5\n推理: 2 3 5。数据new_data [(2 3等于多少, corrected_response)]。训练微调模型更新参数。
第二次迭代
生成response 问题: 2 3等于多少\n推理和答案: 2 3 5\n答案: 5因训练改进。验证正确score 0.5。数据new_data [(2 3等于多少, response)]。训练进一步强化正确推理。
第三次迭代
生成更稳定的正确推理。数据累计高质量样本。训练继续优化。
对比假设
若最后训练 第一次[2, 3, 6]。第二次[2, 3, 1]仍随机。第三次[2, 3, 4]。最后训练可能因数据混杂效果不佳。 4. 逐迭代训练的优势与劣势
优势
实时反馈每次迭代优化模型提升后续生成质量。数据质量递增避免积累低质量数据。灵活终止损失收敛时停止节省资源。
劣势
计算开销频繁训练增加总计算时间。模型不稳定小批量训练可能导致参数波动。实现复杂性需管理每次迭代的模型副本如 deepcopy。 5. 为何不等到所有数据处理好再训练
替代方案的问题
假设修改为收集所有数据后一次性训练
def star_iteration(self, prompts, correct_answers, max_iterations5):training_data []for _ in range(max_iterations):for prompt, correct_answer in zip(prompts, correct_answers):response self.generate_reasoning(prompt)if self.verify_answer(response, correct_answer):if self.evaluate_reasoning(response) 0.5:training_data.append((prompt, response))else:corrected_response self.generate_reasoning(prompt, correct_answer)if self.evaluate_reasoning(corrected_response) 0.5:training_data.append((prompt, corrected_response))# 一次性训练if training_data:model deepcopy(self.generator)optimizer torch.optim.AdamW(model.parameters(), lr5e-5)loss model.train_step(training_data, optimizer) # 假设支持多epochself.generator modelreturn training_data问题分析 数据质量不一致 所有迭代使用初始模型生成的 training_data 可能包含大量错误或低质量推理。无法利用中间训练的改进。 缺乏反馈 模型未在迭代中更新每次生成无进步可能浪费计算资源。 训练负担 一次性处理大量数据需更多epoch或更高计算资源可能超出现有硬件能力。 STaR目标偏离 STaR强调自增强循环逐迭代训练是其核心机制最后训练削弱了这一特性。 6. 改进建议
折中方案
批次训练每隔几轮迭代训练一次平衡反馈与效率if new_data and iteration % 2 0: # 每2轮训练一次model.train_step(new_data, optimizer)动态调整
自适应迭代根据数据质量如PRM分数调整训练频率。增量数据仅训练新增数据避免重复计算。 7. 总结
STaR中逐迭代训练的设计是为了
动态优化实时更新模型提升每次生成的质量。自增强形成反馈循环逐步强化推理能力。效率小批量训练结合提前终止适应资源限制。
相比之下所有数据处理后再训练可能导致数据质量低、缺乏反馈违背STaR的自适应优化目标。代码中的逐迭代训练是其核心优势的体现。
后记
2025年3月2日16点43分于上海在grok3大模型辅助下完成。