【Pytorch】学习记录分享10——PyTorch TextCNN 用于文本分类处理
目录：1. TextCNN用于文本分类　2. 代码实现
1. TextCNN用于文本分类
具体流程：
2. 代码实现
# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
…
【Pytorch】学习记录分享10——PyTorch TextCNN 用于文本分类处理
目录：1. TextCNN用于文本分类　2. 代码实现
1. TextCNN用于文本分类
具体流程
2. 代码实现
# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Config(object):
    """Hyper-parameters and file locations for the TextCNN classifier."""

    def __init__(self, dataset, embedding):
        """Build the configuration for one dataset directory.

        Args:
            dataset: Path prefix of the dataset directory. Every data, log
                and checkpoint path is formed by string concatenation
                onto this prefix.
            embedding: File name (looked up under ``<dataset>/data/``) of
                a NumPy archive that has an ``"embeddings"`` entry, or the
                literal string ``'random'`` to use no pre-trained vectors.

        Raises:
            OSError: If ``<dataset>/data/class.txt`` (or the embedding
                file, when one is requested) cannot be opened.
        """
        self.model_name = 'TextCNN'
        self.train_path = dataset + '/data/train.txt'  # training set
        self.dev_path = dataset + '/data/dev.txt'      # validation set
        self.test_path = dataset + '/data/test.txt'    # test set

        # Class names, one per line. The file is read inside a ``with``
        # block so the handle is closed deterministically.
        with open(dataset + '/data/class.txt') as class_file:
            self.class_list = [line.strip() for line in class_file]

        self.vocab_path = dataset + '/data/vocab.pkl'  # vocabulary
        # Where the trained weights are written.
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'
        self.log_path = dataset + '/log/' + self.model_name

        # Pre-trained embedding matrix as a float32 tensor, or None when
        # the caller asked for randomly initialised embeddings.
        if embedding != 'random':
            archive = np.load(dataset + '/data/' + embedding)
            self.embedding_pretrained = torch.tensor(
                archive["embeddings"].astype('float32'))
        else:
            self.embedding_pretrained = None

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.dropout = 0.5               # dropout probability
        # Stop early if there is no improvement after this many batches.
        self.require_improvement = 1000
        self.num_classes = len(self.class_list)
        # Vocabulary size; a placeholder that must be assigned at run time
        # before a Model with random embeddings is built.
        self.n_vocab = 0
        self.num_epochs = 20
        self.batch_size = 128            # mini-batch size
        # Length every sentence is padded or truncated to.
        self.pad_size = 32
        self.learning_rate = 1e-3
        # Embedding dimension: taken from the pre-trained matrix when one
        # is loaded, otherwise 300.
        if self.embedding_pretrained is not None:
            self.embed = self.embedding_pretrained.size(1)
        else:
            self.embed = 300
        self.filter_sizes = (2, 3, 4)    # convolution kernel heights
        self.num_filters = 256           # output channels per kernel size


class Model(nn.Module):
    """Convolutional Neural Networks for Sentence Classification."""

    def __init__(self, config):
        """Create the embedding, convolution, dropout and output layers.

        Args:
            config: Object exposing ``embedding_pretrained``, ``n_vocab``,
                ``embed``, ``num_filters``, ``filter_sizes``, ``dropout``
                and ``num_classes`` (see ``Config``).
        """
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            # freeze=False keeps the pre-trained vectors trainable.
            self.embedding = nn.Embedding.from_pretrained(
                config.embedding_pretrained, freeze=False)
        else:
            # The last vocabulary index is used as the padding index.
            self.embedding = nn.Embedding(
                config.n_vocab, config.embed,
                padding_idx=config.n_vocab - 1)
        # One 2-D convolution per kernel height; each kernel is as wide as
        # the embedding dimension.
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.num_filters, (k, config.embed))
             for k in config.filter_sizes])
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(
            config.num_filters * len(config.filter_sizes),
            config.num_classes)

    def conv_and_pool(self, x, conv):
        """Apply ``conv`` + ReLU, then max-pool over the remaining length.

        Args:
            x: 4-D input to ``conv``.
            conv: One of the convolution layers in ``self.convs``.

        Returns:
            The pooled activations with dimensions 3 and 2 squeezed away.
        """
        # Dimension 3 is squeezed after the convolution (it has size 1
        # when the kernel width equals the input width).
        x = F.relu(conv(x)).squeeze(3)
        # Pool over the full length of dimension 2, then drop it.
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def forward(self, x):
        """Compute class scores for a batch.

        Args:
            x: Indexable whose first item is the tensor of token indices
                fed to the embedding layer; other items are ignored here.
                Presumably a (token_ids, seq_len) pair from the data
                iterator -- confirm against the caller.

        Returns:
            The output of the final linear layer (one score per class).
        """
        out = self.embedding(x[0])
        # Insert a channel dimension for Conv2d.
        out = out.unsqueeze(1)
        # Concatenate the pooled features of every kernel size.
        out = torch.cat(
            [self.conv_and_pool(out, conv) for conv in self.convs], 1)
        out = self.dropout(out)
        out = self.fc(out)
        return out
该代码对应上图中的各个模块，实现了使用 CNN 处理文本数据。