北京大学网站开发的需求分析,零基础网页设计制作培训,官网seo优化,模板网站区别训练好一个模型之后#xff0c;我们往往要对其进行保存#xff0c;除非下次用时想再次训练一遍。
下面以一个简单的回归任务来详细讲解模型的保存和加载。 来看这样一组数据#xff1a;
xtorch.linspace(-1,1,50)xx.view(50,1)yx.pow(2)0.3*torch.rand(50).view(50,1)
画…训练好一个模型之后我们往往要对其进行保存除非下次用时想再次训练一遍。
下面以一个简单的回归任务来详细讲解模型的保存和加载。 来看这样一组数据
xtorch.linspace(-1,1,50)xx.view(50,1)yx.pow(2)0.3*torch.rand(50).view(50,1)
画图
plt.scatter(x.numpy(),y.numpy()) 很显然x与y基本呈二次函数关系那么接下来我们就来拟合整个函数。
import torchimport matplotlib.pyplot as pltimport torch.nn as nnimport torch.optim as optimxtorch.linspace(-1,1,50)xx.view(50,1)yx.pow(2)0.3*torch.rand(50).view(50,1)net1nn.Sequential(nn.Linear(1,10), nn.ReLU(), nn.Linear(10,1))criterionnn.MSELoss()optimizeroptim.SGD(net1.parameters(),lr0.2)#训练模型for i in range(1000): prednet1(x) losscriterion(pred,y) optimizer.zero_grad() loss.backward() optimizer.step()
#测试模型net1.eval()with torch.no_grad(): y1net1(x) plt.plot(x.numpy(),y1.numpy(),r-) plt.scatter(x.numpy(),y.numpy()) 结果似乎不错
这里我们得到了一个网络net1它可以被当作一个二次函数用于描述之前的xy数据的关系。
得到这个网络后我们想保存它主要有两种方式
1保存整个网络包括训练后的各个层的参数
#保存整个网络包括训练后的各个层的参数torch.save(net1,net1weight.pkl)
2只保存训练好的网络的参数速度更快
#只保存训练好的网络的参数速度更快torch.save(net1.state_dict(),net1_params.pkl) 假设我们按第一种方式保存那么下次想要使用次网络时需要这样做
networktorch.load(net1weight.pkl)
#测试模型network.eval()with torch.no_grad(): y1network(x) plt.plot(x.numpy(),y1.numpy(),b-) plt.scatter(x.numpy(),y.numpy()) 假设我们按第二种方式保存那么下次想要使用次网络时需要这样做
networknn.Sequential(nn.Linear(1,10), nn.ReLU(), nn.Linear(10,1))network.load_state_dict(torch.load(net1_params.pkl))
#测试模型network.eval()with torch.no_grad(): y1network(x) plt.plot(x.numpy(),y1.numpy(),g-) plt.scatter(x.numpy(),y.numpy()) 可以看出第二次首先需要构造出一个一模一样的模型接着再导入参数即可。当然这只是个简单的回归模型其它模型保存与加载同样如此。 总结一下
模型保存与导入有两种方式
方式一
#模型保存torch.save(net1,net1weight.pkl)#模型导入networktorch.load(net1weight.pkl)
方式二
#模型保存torch.save(net1.state_dict(),net1_params.pkl)#模型导入network.load_state_dict(torch.load(net1_params.pkl))