一、将caffe模型的权重转成dict格式
caffe库的编译可以参考我之前写的一篇博客：《ImportError: dynamic module does not define module export function (PyInit__caffe)问题解决记录》（chen_zn95的博客，CSDN博客）。
安装好后，使用以下脚本便可将caffe模型的参数名和参数保存成dict。

一、将caffe模型的权重转成dict格式
caffe库的编译可以参考我之前写的一篇博客：《ImportError: dynamic module does not define module export function (PyInit__caffe)问题解决记录》（chen_zn95的博客，CSDN博客）。
安装好后，使用以下脚本便可将caffe模型的参数名和参数保存成dict：
import pickle as pkl

import caffe

# Placeholders: replace with the real network definition and weights files.
MODEL_FILE = 'xxx.prototxt'
PRETRAIN_FILE = 'xxx.caffemodel'

if __name__ == '__main__':
    # Dump every Caffe layer's parameters into a nested dict
    # {layer_name: {param_name: numpy array}} and pickle it to weights.pkl.
    net = caffe.Net(MODEL_FILE, PRETRAIN_FILE, caffe.TEST)

    name_weights = {}
    for param_name in net.params.keys():
        name_weights[param_name] = {}
        layer_params = net.params[param_name]
        if len(layer_params) == 1:
            # Layer with a weight blob only (no bias).
            weight = layer_params[0].data
            name_weights[param_name]['weight'] = weight
            print('%s:\n\t%s (weight)' % (param_name, weight.shape))
        elif len(layer_params) == 2:
            # weight
            weight = layer_params[0].data
            name_weights[param_name]['weight'] = weight
            # bias
            bias = layer_params[1].data
            name_weights[param_name]['bias'] = bias
            print('%s:\n\t%s (weight)' % (param_name, weight.shape))
            print('\t%s (bias)' % str(bias.shape))
        elif len(layer_params) == 3:
            # BN: running_mean, running_var, scale factor.
            # The stored mean/variance are divided by the scale factor to get
            # the actual statistics. A zero factor would turn every value into
            # inf/nan, so it is mapped to a multiplier of 0 instead
            # (NOTE(review): this mirrors Caffe's BatchNorm layer - confirm
            # against the Caffe version in use).
            scale = layer_params[2].data[0]
            factor = 0.0 if scale == 0 else 1.0 / scale
            running_mean = layer_params[0].data  # running_mean
            name_weights[param_name]['running_mean'] = running_mean * factor
            running_var = layer_params[1].data  # running_var
            name_weights[param_name]['running_var'] = running_var * factor
            print('%s:\n\t%s (running_var)' % (param_name, running_var.shape))
            print('\t%s (running_mean)' % str(running_mean.shape))
        else:
            raise RuntimeError("error\n")

    # save weight
    # protocol=2 keeps the file readable from both Python 2 and Python 3.
    with open('weights.pkl', 'wb') as f:
        pkl.dump(name_weights, f, protocol=2)
二、pytorch模型加载dict格式的权重
这里有两个思路：一是根据权重名来匹配，二是根据权重的shape来匹配。但第二个方法有个问题：如果网络中有两个以上shape一样的层，那么根据权重的shape来匹配就会出错。下面分别介绍一下以上两个思路：
1、根据权重名匹配
这个方法比较繁琐，要求pytorch模型的参数名要与caffe模型的保持一致；如果不一致，则需要自己写个dict进行映射。具体操作如下：
import copy
import pickle as pkl

import torch

# Placeholder: replace `xxx` with an instance of the PyTorch model whose
# parameter names match the Caffe layer names.
model = xxx
model1 = copy.deepcopy(model)

# weights.pkl: the dict generated in step 1, {layer_name: {param_name: array}}.
# NOTE: pickle can execute arbitrary code; only load files you created yourself.
# encoding='iso-8859-1' is passed so a pickle written under Python 2 can be
# read under Python 3.
with open('weights.pkl', 'rb') as wp:
    name_weights = pkl.load(wp, encoding='iso-8859-1')

# Flatten to PyTorch-style keys: "<layer_name>.<param_name>".
state_dict = {}
for key, value in name_weights.items():
    for k, v in value.items():
        state_dict[key + '.' + k] = torch.from_numpy(v)

# strict=True: fail loudly if any key is missing or unexpected.
model1.load_state_dict(state_dict, strict=True)
另一种实现是直接对pytorch模型的参数赋值，代码如下：
import copy
import itertools
import pickle as pkl

import torch

# Placeholder: replace `xxx` with an instance of the PyTorch model whose
# parameter names match the Caffe layer names.
model = xxx
model2 = copy.deepcopy(model)

# NOTE: pickle can execute arbitrary code; only load files you created yourself.
with open('weights.pkl', 'rb') as wp:
    name_weights = pkl.load(wp, encoding='iso-8859-1')

# named_parameters() does not yield BatchNorm running_mean / running_var
# (they are registered as buffers), so buffers are walked as well.
tensors = itertools.chain(model2.named_parameters(), model2.named_buffers())
for name, tensor in tensors:
    parts = name.split('.')
    if len(parts) < 2:
        # No "<layer>.<param>" structure, so nothing to look up.
        continue
    # The first two name components select the layer and the parameter,
    # e.g. "conv1.weight" -> name_weights["conv1"]["weight"].
    array = name_weights.get(parts[0], {}).get(parts[1])
    if array is not None:
        tensor.data = torch.from_numpy(array)
2、根据权重shape匹配
import copy
import itertools
import pickle as pkl

import torch

# LightCNN_ir_eye is the author's model class; it must be defined or imported
# before this script runs.
model = LightCNN_ir_eye()
model3 = copy.deepcopy(model)

# NOTE: pickle can execute arbitrary code; only load files you created yourself.
with open('weights.pkl', 'rb') as wp:
    name_weights = pkl.load(wp, encoding='iso-8859-1')

# Flatten once to "<layer_name>.<param_name>" -> array so each model tensor
# needs a single lookup instead of a scan over every stored array.
flat_weights = {
    key + '.' + k: v
    for key, value in name_weights.items()
    for k, v in value.items()
}

# named_parameters() does not yield BatchNorm running_mean / running_var
# (they are registered as buffers), so buffers are walked as well.
tensors = itertools.chain(model3.named_parameters(), model3.named_buffers())
for name, tensor in tensors:
    array = flat_weights.get(name)
    if array is None:
        continue
    # Shapes must agree; the name lookup above prevents errors when several
    # weights share the same shape.
    if tuple(tensor.shape) == tuple(array.shape):
        tensor.data = torch.from_numpy(array)
3、检查以上模型初始化方法是否正确
import cv2
import numpy as np
import torch

# Placeholders: replace 'xxx.jpg' with a real image path and define
# `width` / `height` as the input size the network expects.
img = cv2.imread('xxx.jpg')
if img is None:
    # cv2.imread returns None instead of raising when the file is unreadable.
    raise IOError('failed to read image: xxx.jpg')
img = cv2.resize(img, (width, height))
img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
img = np.expand_dims(img, axis=0)   # add the batch dimension

inputs = torch.from_numpy(img).float()

# model1 / model2 / model3 are the models initialised by the three loading
# scripts. Use eval mode so dropout / BatchNorm behave deterministically and
# the three outputs are comparable.
for m in (model1, model2, model3):
    m.eval()

with torch.no_grad():
    out1 = model1(inputs)
    out2 = model2(inputs)
    out3 = model3(inputs)

print(out1)
print(out2)
print(out3)
# Element-wise comparison: all True means the initialisation methods agree.
for o1, o2, o3 in zip(out1, out2, out3):
    print(o1 == o2)
    print(o1 == o3)