微网站介绍,网站报404错误怎么解决办法,做拼团网站,北京南昌网站制作以Llama-2为例#xff0c;在生成模型中使用自定义StoppingCriteria 1. 前言2. 场景介绍3. 解决方法4. 结语 1. 前言
在之前的文章中#xff0c;介绍了使用transformers模块创建的模型#xff0c;其generate方法的详细原理和使用方法#xff0c;文章链接#xff1a;
以be… 以Llama-2为例在生成模型中使用自定义StoppingCriteria 1. 前言2. 场景介绍3. 解决方法4. 结语 1. 前言
在之前的文章中介绍了使用transformers模块创建的模型其generate方法的详细原理和使用方法文章链接
以beam search为例详解transformers中generate方法上 以beam search为例详解transformers中generate方法下
其中提到了用户参与生成过程的两个关键组件logits_processor和stopping_criteria使用这两个类是用户控制生成过程的主要手段。其中logits_processor用来在生成过程中根据用户设置的指定规则强行修改当前step在词表空间上的概率分布而stopping_criteria根据用户所规定的规则来中止生成。
这两个组件在transformers模块中都有一些预设的类可以直接使用预设类的基本信息介绍可参考以beam search为例详解transformers中generate方法上。
本文将结合实际应用场景介绍用户如何根据自己的需求来设计并实现一个自定义的stopping_criteria来控制生成过程提前结束。
2. 场景介绍
这次介绍的场景是使用Llama-2的生成能力对一段新闻进行概括希望能够生成一句简短的话来概括新闻中发生的最核心的事情。
通过给定对话背景结合历史样例的方式希望Llama-2能够输出期望的结果。
对话的prompt构造方法可以参考之前的内容NLP实践——Llama-2 多轮对话prompt构建。
然而即便是采用了in-context learning的方式Llama-2生成的结果仍然过于冗长。
例如对于这样一篇新闻
text , Photo Credit : Associated Press Four air crew members were missing after an Australian army helicopter ditched into waters off the Queensland state coast during joint military exercises with the United States, officials said Saturday. The MRH-90 Taipan helicopter went down near Lindeman Island, a Great Barrier... ...
# 后边忽略若干内容模型生成的结果为
Four Australian army air crew members are missing after an Australian army MRH-90 Taipan helicopter ditched into waters off the Queensland state coast during joint military exercises with the United States. The helicopter went down near Lindeman Island, a Great Barrier Reef tourist resort, at around 11 pm on Friday. A search involving US, Canadian, and Australian personnel is underway to find the missing crew, who are all Australian men. Debris that appeared to be from a helicopter has been recovered, according to Queensland Police Assistant Commissioner Douglas McDonald. The Taipan was taking part in Talisman Sabre, a biennial joint US-Australian military exercise that is largely based in Queensland. This years exercise involves 13 nations and over 30,000 military personnel. Defense Minister Richard Marles said the helicopter ditched, which refers to an emergency landing on water. He added that defense exercises, which are so necessary for the readiness of our defense force, are serious and carry risk. US Defense Secretary... ...
# 后边忽略若干内容可以看出并不是模型生成的结果不好但是它太啰嗦了而对于我的需求而言模型只需要输出其中的第一句话就足够了。
这时候可能有人就会觉得“那我分句然后把第一句话保留下来不就好了”
——这样做虽然也可以达成效果但是这个生成过程时间和算力已经被消耗了。
所以需要采取方法让模型在生成到第一个句号的时候就停止生成返回结果。于是就需要用到今天的主角——Stopping Criteria。
3. 解决方法
transformers模块中内置了几个默认的stopping criteria然而在很多情况下它们并不能满足需求这时就需要创建自定义的stopping criteria。
首先需要引用基类
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings其中
StoppingCriteriaList是一个容器需要将所有的criteria都添加到其中generate时传入的是这个容器StoppingCriteria是基础类自定义的criteria需要继承这个基础类。
接下来就实现一个criteria效果是遇到指定的token时就停止生成
class StopAtSpecificTokenCriteria(StoppingCriteria):当生成出第一个指定token时立即停止生成---------------ver: 2023-08-02by: changhongyudef __init__(self, token_id_list: List[int] None)::param token_id_list: 停止生成的指定token的id的列表self.token_id_list token_id_listadd_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) - bool:# return np.argmax(scores[-1].detach().cpu().numpy()) in self.token_id_list# 储存scores会额外占用资源所以直接用input_ids进行判断return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list那么如果希望遇到句号就停止生成那就用句号对应的token_id去实例化一个这样的stopping criteria并将它添加到容器中
# Llama-2的词表中英文句号的id是29889
stopping_criteria StoppingCriteriaList()
stopping_criteria.append(StopAtSpecificTokenCriteria(token_id_list[29889]))然后在生成的时候假如原本的生成指令是
model.generate(**inputs)那么再把stopping criteria作为参数传入进去就可以发挥效果了
model.generate(stopping_criteriastopping_criteria, **inputs)4. 结语
Stopping Criteria用于在每一个step的生成结束时判断生成过程是否要结束是用户控制生成过程的有效手段其发挥作用的方式也比较直接实现自定义criteria也并不复杂只需要确保该类的调用方法返回值是bool值并覆盖全部情况即可。
Logits Processor是用户控制生成的另一个有效工具在接下来的博客中还将介绍自定义logits processor是如何使用的欢迎感兴趣的同学继续关注。