那个网站做视频能挣钱,wordpress媒体库文件,有人知道做网站吗?,网页编辑用户信息原理文章目录 一、修改1.方法2.代码 二、保存和读取1.方法2.代码#xff08;1#xff09;保存#xff08;2#xff09;加载 3.陷阱 一、修改
1.方法
add_module(name: str, module: Module) - None
name 是要添加的子模块的名称。 module 是要添加的子模块。 调用 add_m… 文章目录 一、修改1.方法2.代码 二、保存和读取1.方法2.代码1保存2加载 3.陷阱 一、修改
1.方法
add_module(name: str, module: Module) - None
name 是要添加的子模块的名称。 module 是要添加的子模块。 调用 add_module 方法会向当前模块中添加一个子模块并使用指定的名称进行标识。
2.代码
import torchvision
from torch import nn# 实例化一个未经过预训练的 VGG16 模型
vgg16_false torchvision.models.vgg16(pretrainedFalse)# 实例化一个经过预训练的 VGG16 模型
vgg16_true torchvision.models.vgg16(pretrainedTrue)print(ok)# 输出经过预训练的 VGG16 模型及修改后的模型
print(vgg16_true)
vgg16_true.classifier.add_module(add_linear, nn.Linear(1000, 10))
print(vgg16_true)# 输出未经过预训练的 VGG16 模型及修改后的模型
print(vgg16_false)
vgg16_false.classifier[6] nn.Linear(4096, 10)
print(vgg16_false)修改前的vgg16_true 修改后的vgg16_true 修改前的vgg16_true 修改后的vgg16_true 二、保存和读取
1.方法
保存 torch.save(要保存的模型“文件路径”)
加载 torch.load(“文件路径”)
2.代码
1保存
import torch
import torchvisionvgg16 torchvision.models.vgg16(pretrainedFalse)# 保存方式1模型结构模型参数
torch.save(vgg16, vgg16_module1.pth)# 保存方式2模型参数官方推荐
torch.save(vgg16.state_dict(), vgg16_module2.pth)2加载
import torch
import torchvision# 方式1 加载模型
module1 torch.load(vgg16_module1.pth)
print(module1)#
module2 torch.load(vgg16_module2.pth)
print(module2)# 方式2 加载模型
vgg16 torchvision.models.vgg16(pretrainedFalse)
vgg16.load_state_dict(torch.load(vgg16_module2.pth))
print(vgg16)运行加载的代码后打印结果如下
module1 module2: vgg16 可以看到第二种方式保存的数据加载后是向量形式需要通过别的方法加载为模型
3.陷阱
第一种方式加载在某些条件下可能会报错
例如
假设自定义一个神经网络保存
import torch
import torchvision
from torch import nn# 陷阱
class Guodong(nn.Module):def __init__(self):super(Guodong,self).__init__()self.conv1 nn.Conv2d(3, 64, kernel_size3)def forward(self,x):x self.conv1(x)return xguodong Guodong()
torch.save(guodong,guodong_method1.pth)
在另一个文件中加载
import torch# 陷阱
module torch.load(guodong_method1.pth)
print(module)就会报错
AttributeError: Can’t get attribute ‘Guodong’ on module ‘main’ from ‘E:\deepLearning\Pycharm\pytroch_project\theFirstFile\module_load.py’
解决办法
1把Guodong类放在这个文件里
import torch
from torch import nn
import torchvisionclass Guodong(nn.Module):def __init__(self):super(Guodong,self).__init__()self.conv1 nn.Conv2d(3, 64, kernel_size3)def forward(self,x):x self.conv1(x)return x# 陷阱
module torch.load(guodong_method1.pth)
print(module)
2from module_save import *
module_save是保存自定义模型的文件
from module_save import *# 陷阱
module torch.load(guodong_method1.pth)
print(module)