dedecms 网站首页,网站编辑是个长期做的工作吗,有创意营销型网站建设,如何快速提升自己在Pytorch的2.2版本更新文档中#xff0c;官方重点强调了通过实现FlashAtteneion-v2实现了对scaled_dot_product_attention约2X左右的加速。 今天抽空亲自试了下#xff0c;看看加速效果是否如官方所说。测试前需要将Pytorch的版本更新到2.2及以上#xff0c;下面是测试代码…在Pytorch的2.2版本更新文档中官方重点强调了通过实现FlashAtteneion-v2实现了对scaled_dot_product_attention约2X左右的加速。 今天抽空亲自试了下看看加速效果是否如官方所说。测试前需要将Pytorch的版本更新到2.2及以上下面是测试代码一个是原始手写的Self-Attention的实现一个是使用Pytorch官方的scaled_dot_product_attention接口
import time
import torch
import torch.nn.functional as Fdef main():repeat 100device torch.device(cuda:0)dtype torch.float16query torch.rand(32, 8, 128, 64, dtypedtype, devicedevice)key torch.rand(32, 8, 128, 64, dtypedtype, devicedevice)value torch.rand(32, 8, 128, 64, dtypedtype, devicedevice)scale_factor 0.125ori_time_list []for _ in range(repeat):torch.cuda.synchronize(devicedevice)time_start time.perf_counter()# 原始Self-Attention实现res torch.softmax(query key.transpose(-2, -1) * scale_factor, dim-1) valuetorch.cuda.synchronize(devicedevice)time_end time.perf_counter()ori_time_list.append(time_end - time_start)fa_time_list []for _ in range(repeat):torch.cuda.synchronize(devicedevice)time_start time.perf_counter()with torch.backends.cuda.sdp_kernel(enable_mathFalse):# 使用Pytorch官方提供的FA实现res_fa F.scaled_dot_product_attention(query, key, value, scalescale_factor)torch.cuda.synchronize(devicedevice)time_end time.perf_counter()fa_time_list.append(time_end - time_start)diff (res - res_fa).abs().max()ratio [ori_time_list[i] / fa_time_list[i] for i in range(repeat)]avg_ratio sum(ratio[1:]) / len(ratio[1:])print(fmax diff: {diff})print(favg speed up ratio: {avg_ratio})if __name__ __main__:main()
执行以上代码终端输出如下
max diff: 0.00048828125
avg speed up ratio: 2.2846881043417118这里使用的设备是RTX4070跑了很多次发现确实加速2X左右看来以后训练或者推理时可以考虑直接使用官方的scaled_dot_product_attention接口了。但是这里也发现了两个问题一个是原始手写的Self-Attention的计算结果和直接调用scaled_dot_product_attention接口得到的结果差异有点大注意这里计算的Tensor都是FP16精度的如果我切换到FP32精度差异会再小两个数量级。第二个问题是如果使用FP32的话实测没有明显加速这个就很奇怪了官方文档里并没有说专门针对FP16精度优化的后面找了个A100的GPU试了下发现FP32也是有加速的。