网站空间维护,个人怎么注册一个品牌,长春网站优化平台,苏州网页制作简介 基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback#xff0c;RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步#xff0c;它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。然而#xff0c;它也给 NLP 引入了一些 RL 相关…简介 基于人类反馈的强化学习 (Reinforcement Learning from Human FeedbackRLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。然而它也给 NLP 引入了一些 RL 相关的复杂性: 既要构建一个好的奖励函数并训练一个模型用以估计每个状态的价值 (value); 又要注意最终生成的 LLM 不能与原始模型相差太远如果太远的话会使得模型容易产生乱码而非有意义的文本。该过程非常复杂涉及到许多复杂的组件而这些组件本身在训练过程中又是动态变化的因此把它们料理好并不容易。
Rafailov、Sharma、Mitchell 等人最近发表了一篇论文 Direct Preference Optimization论文提出将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标这一做法大大简化了 LLM 的提纯过程。
本文介绍了直接偏好优化 (Direct Preference OptimizationDPO) 法该方法现已集成至 TRL 库 中。同时我们还展示了如何在 stack-exchange preference 数据集上微调最新的 Llama v2 7B 模型nbsp;stack-exchange preferencenbsp;数据集中包含了各个nbsp;stack-exchangenbsp;门户上的各种问题及其排序后的回答。
DPO 与 PPO
在通过 RL 优化人类衍生偏好时一直以来的传统做法是使用一个辅助奖励模型来微调目标模型以通过 RL 机制最大化目标模型所能获得的奖励。直观上我们使用奖励模型向待优化模型提供反馈以促使它多生成高奖励输出少生成低奖励输出。同时我们使用冻结的参考模型来确保输出偏差不会太大且继续保持输出的多样性。这通常需要在目标函数设计时除了奖励最大化目标外再添加一个相对于参考模型的 KL 惩罚项这样做有助于防止模型学习作弊或钻营奖励模型。
DPO 绕过了建模奖励函数这一步这源于一个关键洞见: 从奖励函数到最优 RL 策略的分析映射。这个映射直观地度量了给定奖励函数与给定偏好数据的匹配程度。有了它作者就可与将基于奖励和参考模型的 RL 损失直接转换为仅基于参考模型的损失从而直接在偏好数据上优化语言模型因此DPO 从寻找最小化 RLHF 损失的最佳方案开始通过改变参量的方式推导出一个nbsp;仅需nbsp;参考模型的损失
有了它我们可以直接优化该似然目标而不需要奖励模型或繁琐的强化学习优化过程。
如何使用 TRL 进行训练
如前所述一个典型的 RLHF 流水线通常包含以下几个环节:
有监督微调 (supervised fine-tuningSFT)用偏好标签标注数据基于偏好数据训练奖励模型RL 优化
TRL 库包含了所有这些环节所需的工具程序。而 DPO 训练直接消灭了奖励建模和 RL 这两个环节 (环节 3 和 4)直接根据标注好的偏好数据优化 DPO 目标。
使用 DPO我们仍然需要执行环节 1但我们仅需在 TRL 中向nbsp;DPOTrainernbsp;提供环节 2 准备好的偏好数据而不再需要环节 3 和 4。标注好的偏好数据需要遵循特定的格式它是一个含有以下 3 个键的字典:
promptnbsp;: 即推理时输入给模型的提示chosennbsp;: 即针对给定提示的较优回答rejectednbsp;: nbsp;即针对给定提示的较劣回答或非给定提示的回答
例如对于nbsp;stack-exchange preferencenbsp;数据集我们可以通过以下工具函数将数据集中的样本映射至上述字典格式并删除所有原始列:
defnbsp;return_prompt_and_responses(samples)nbsp;-nbsp;Dict[str,nbsp;str,nbsp;str]:nbsp;nbsp;nbsp;nbsp;returnnbsp;{nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;prompt:nbsp;[nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;Question:nbsp;nbsp;nbsp;questionnbsp;nbsp;\n\nAnswer:nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;fornbsp;questionnbsp;innbsp;samples[question]nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;],nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;chosen:nbsp;samples[response_j],nbsp;#nbsp;ratednbsp;betternbsp;thannbsp;knbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;nbsp;rejected:nbsp;samples[response_k],nbsp;#nbsp;ratednbsp;worsenbsp;thannbsp;jnbsp;nbsp;nbsp;nbsp;}datasetnbsp;nbsp;load_dataset(nbsp;nbsp;nbsp;nbsp;lvwerra/stack-exchange-paired,nbsp;nbsp;nbsp;nbsp;splittrain,nbsp;nbsp;nbsp;nbsp;data_dirdata/rl)original_columnsnbsp;nbsp;dataset.column_namesdataset.map(nbsp;nbsp;nbsp;nbsp;return_prompt_and_responses,nbsp;nbsp;nbsp;nbsp;batchedTrue,nbsp;nbsp;nbsp;nbsp;remove_columnsoriginal_columns)
一旦有了排序数据集DPO 损失其实本质上就是一种有监督损失其经由参考模型获得隐式奖励。因此从上层来看DPOTrainernbsp;需要我们输入待优化的基础模型以及参考模型:
dpo_trainernbsp;nbsp;DPOTrainer(nbsp;nbsp;nbsp;nbsp;model,nbsp;#nbsp;经nbsp;SFTnbsp;的基础模型nbsp;nbsp;nbsp;nbsp;model_ref,nbsp;#nbsp;一般为经nbsp;SFTnbsp;的基础模型的一个拷贝nbsp;nbsp;nbsp;nbsp;beta0.1,nbsp;#nbsp;DPOnbsp;的温度超参nbsp;nbsp;nbsp;nbsp;train_datasetdataset,nbsp;#nbsp;上文准备好的数据集nbsp;nbsp;nbsp;nbsp;tokenizertokenizer,nbsp;#nbsp;分词器nbsp;nbsp;nbsp;nbsp;argstraining_args,nbsp;#nbsp;训练参数如:nbsp;batchnbsp;size,nbsp;学习率等)
其中超参nbsp;betanbsp;是 DPO 损失的温度通常在nbsp;0.1nbsp;到nbsp;0.5nbsp;之间。它控制了我们对参考模型的关注程度betanbsp;越小我们就越忽略参考模型。对训练器初始化后我们就可以简单调用以下方法使用给定的nbsp;training_argsnbsp;在给定数据集上进行训练了:
dpo_trainer.train()
基于 Llama v2 进行实验
在 TRL 中实现 DPO 训练器的好处是人们可以利用 TRL 及其依赖库 (如 Peft 和 Accelerate) 中已有的 LLM 相关功能。有了这些库我们甚至可以使用 bitsandbytes 库提供的 QLoRA 技术 来训练 Llama v2 模型。
有监督微调
如上文所述我们先用 TRL 的nbsp;SFTTrainernbsp;在 SFT 数据子集上使用 QLoRA 对 7B Llama v2 模型进行有监督微调:
#nbsp;loadnbsp;thenbsp;basenbsp;modelnbsp;innbsp;4-bitnbsp;quantizationbnb_confignbsp;nbsp;BitsAndBytesConfig(nbsp;nbsp;nbsp;nbsp;load_in_4bitTrue,nbsp;nbsp;nbsp;nbsp;bnb_4bit_quant_typenf4,nbsp;nbsp;nbsp;nbsp;bnb_4bit_compute_dtypetorch.bfloat16,)base_modelnbsp;nbsp;AutoModelForCausalLM.from_pretrained(nbsp;nbsp;nbsp;nbsp;script_args.model_name,nbsp;#nbsp;meta-llama/Llama-2-7b-hfnbsp;nbsp;nbsp;nbsp;quantization_configbnb_config,nbsp;nbsp;nbsp;nbsp;device_map{:nbsp;0},nbsp;nbsp;nbsp;nbsp;trust_remote_codeTrue,nbsp;nbsp;nbsp;nbsp;use_auth_tokenTrue,)base_model.config.use_cachenbsp;nbsp;False#nbsp;addnbsp;LoRAnbsp;layersnbsp;onnbsp;topnbsp;ofnbsp;thenbsp;quantizednbsp;basenbsp;modelpeft_confignbsp;nbsp;LoraConfig(nbsp;nbsp;nbsp;nbsp;rscript_args.lora_r,nbsp;nbsp;nbsp;nbsp;lora_alphascript_args.lora_alpha,nbsp;nbsp;nbsp;nbsp;lora_dropoutscript_args.lora_dropout,nbsp;nbsp;nbsp;nbsp;target_modules[q_proj,nbsp;v_proj],nbsp;nbsp;nbsp;nbsp;biasnone,nbsp;nbsp;nbsp;nbsp;task_typeCAUSAL_LM,)...trainernbsp;nbsp;SFTTrainer(nbsp;nbsp;nbsp;nbsp;modelbase_model,nbsp;nbsp;nbsp;nbsp;train_datasettrain_dataset,nbsp;nbsp;nbsp;nbsp;eval_dataseteval_dataset,nbsp;nbsp;nbsp;nbsp;peft_configpeft_config,nbsp;nbsp;nbsp;nbsp;packingTrue,nbsp;nbsp;nbsp;nbsp;max_seq_lengthNone,nbsp;nbsp;nbsp;nbsp;tokenizertokenizer,nbsp;nbsp;nbsp;nbsp;argstraining_args,nbsp;#nbsp;HFnbsp;Trainernbsp;arguments)trainer.train()
DPO 训练
SFT 结束后我们保存好生成的模型。接着我们继续进行 DPO 训练我们把 SFT 生成的模型作为 DPO 的基础模型和参考模型并在上文生成的nbsp;stack-exchange preferencenbsp;数据上以 DPO 为目标函数训练模型。我们选择对模型进行 LoRa 微调因此我们使用 Peft 的nbsp;AutoPeftModelForCausalLMnbsp;函数加载模型:
modelnbsp;nbsp;AutoPeftModelForCausalLM.from_pretrained(nbsp;nbsp;nbsp;nbsp;script_args.model_name_or_path,nbsp;#nbsp;locationnbsp;ofnbsp;savednbsp;SFTnbsp;modelnbsp;nbsp;nbsp;nbsp;low_cpu_mem_usageTrue,nbsp;nbsp;nbsp;nbsp;torch_dtypetorch.float16,nbsp;nbsp;nbsp;nbsp;load_in_4bitTrue,nbsp;nbsp;nbsp;nbsp;is_trainableTrue,)model_refnbsp;nbsp;AutoPeftModelForCausalLM.from_pretrained(nbsp;nbsp;nbsp;nbsp;script_args.model_name_or_path,nbsp;#nbsp;samenbsp;modelnbsp;asnbsp;thenbsp;mainnbsp;onenbsp;nbsp;nbsp;nbsp;low_cpu_mem_usageTrue,nbsp;nbsp;nbsp;nbsp;torch_dtypetorch.float16,nbsp;nbsp;nbsp;nbsp;load_in_4bitTrue,)...dpo_trainernbsp;nbsp;DPOTrainer(nbsp;nbsp;nbsp;nbsp;model,nbsp;nbsp;nbsp;nbsp;model_ref,nbsp;nbsp;nbsp;nbsp;argstraining_args,nbsp;nbsp;nbsp;nbsp;betascript_args.beta,nbsp;nbsp;nbsp;nbsp;train_datasettrain_dataset,nbsp;nbsp;nbsp;nbsp;eval_dataseteval_dataset,nbsp;nbsp;nbsp;nbsp;tokenizertokenizer,nbsp;nbsp;nbsp;nbsp;peft_configpeft_config,)dpo_trainer.train()dpo_trainer.save_model()
可以看出我们以 4 比特的方式加载模型然后通过nbsp;peft_confignbsp;参数选择 QLora 方法对其进行训练。训练器还会用评估数据集评估训练进度并报告一些关键指标例如可以选择通过 WandB 记录并显示隐式奖励。最后我们可以将训练好的模型推送到 HuggingFace Hub。
总结
SFT 和 DPO 训练脚本的完整源代码可在该目录 examples/stack_llama_2 处找到训好的已合并模型也已上传至 HF Hub (见 此处)。