当前位置: 首页 > news >正文

【Pytorch】学习记录分享10——TextCNN用于文本分类处理

【Pytorch】学习记录分享10——PyTorchTextCNN用于文本分类处理

      • 1. TextCNN用于文本分类
      • 2. 代码实现

1. TextCNN用于文本分类

具体流程:
在这里插入图片描述
在这里插入图片描述

2. 代码实现

# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Config(object):"""配置参数"""def __init__(self, dataset, embedding):self.model_name = 'TextCNN'self.train_path = dataset + '/data/train.txt'                                # 训练集self.dev_path = dataset + '/data/dev.txt'                                    # 验证集self.test_path = dataset + '/data/test.txt'                                  # 测试集self.class_list = [x.strip() for x in open(dataset + '/data/class.txt').readlines()]                                # 类别名单self.vocab_path = dataset + '/data/vocab.pkl'                                # 词表self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'        # 模型训练结果self.log_path = dataset + '/log/' + self.model_nameself.embedding_pretrained = torch.tensor(np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\if embedding != 'random' else None                                       # 预训练词向量self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # 设备self.dropout = 0.5                                              # 随机失活self.require_improvement = 1000                                 # 若超过1000batch效果还没提升,则提前结束训练self.num_classes = len(self.class_list)                         # 类别数self.n_vocab = 0                                                # 词表大小,在运行时赋值self.num_epochs = 20                                            # epoch数self.batch_size = 128                                           # mini-batch大小self.pad_size = 32                                              # 每句话处理成的长度(短填长切)self.learning_rate = 1e-3                                       # 学习率self.embed = self.embedding_pretrained.size(1)\if self.embedding_pretrained is not None else 300           # 字向量维度self.filter_sizes = (2, 3, 4)                                   # 卷积核尺寸self.num_filters = 256                                          # 卷积核数量(channels数)'''Convolutional Neural Networks for Sentence Classification'''class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()if config.embedding_pretrained is not None:self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)else:self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)self.convs = nn.ModuleList([nn.Conv2d(1, config.num_filters, (k, config.embed)) for k in config.filter_sizes])self.dropout = nn.Dropout(config.dropout)self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes)def conv_and_pool(self, x, conv):x = F.relu(conv(x)).squeeze(3)x = F.max_pool1d(x, x.size(2)).squeeze(2)return xdef forward(self, x):#print (x[0].shape)out = self.embedding(x[0])out = out.unsqueeze(1)out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)out = self.dropout(out)out = self.fc(out)return out

该代码对应上述的图像中的模块实现,CNN用于处理文本数据

相关文章:

  • Linux 修改主机名称并通过主机名称访问服务器
  • 小心JDK20 ZipOutputStream
  • 计算机网络(6):应用层
  • 桌面天气预报软件 Weather Widget free mac特点介绍
  • BRC20 技术分析
  • element-ui table height 属性导致界面卡死
  • Vue 3.4 发布
  • 关于“Python”的核心知识点整理大全61
  • SpringBoot整合Elasticsearch报错
  • 机器学习--回归算法
  • Materail Design 进阶(十一)——MaterialButton使用
  • 【Delphi 基础知识 13】匿名方法的使用
  • DashScope灵积模型服务 java testcase - 特色功能 模型监督学习
  • Java解决峰与谷问题
  • 编程笔记 html5cssjs 027 HTML输入属性(1/2)
  • 《Javascript高级程序设计 (第三版)》第五章 引用类型
  • Akka系列(七):Actor持久化之Akka persistence
  • css系列之关于字体的事
  • exports和module.exports
  • iOS仿今日头条、壁纸应用、筛选分类、三方微博、颜色填充等源码
  • Laravel5.4 Queues队列学习
  • learning koa2.x
  • Linux学习笔记6-使用fdisk进行磁盘管理
  • Map集合、散列表、红黑树介绍
  • SpiderData 2019年2月13日 DApp数据排行榜
  • Spring核心 Bean的高级装配
  • Synchronized 关键字使用、底层原理、JDK1.6 之后的底层优化以及 和ReenTrantLock 的对比...
  • Vim Clutch | 面向脚踏板编程……
  • Webpack4 学习笔记 - 01:webpack的安装和简单配置
  • 等保2.0 | 几维安全发布等保检测、等保加固专版 加速企业等保合规
  • 可能是历史上最全的CC0版权可以免费商用的图片网站
  • 适配iPhoneX、iPhoneXs、iPhoneXs Max、iPhoneXr 屏幕尺寸及安全区域
  • 一个SAP顾问在美国的这些年
  • 译自由幺半群
  • !! 2.对十份论文和报告中的关于OpenCV和Android NDK开发的总结
  • (1/2) 为了理解 UWP 的启动流程,我从零开始创建了一个 UWP 程序
  • (day 2)JavaScript学习笔记(基础之变量、常量和注释)
  • (分布式缓存)Redis持久化
  • (图)IntelliTrace Tools 跟踪云端程序
  • (一)eclipse Dynamic web project 工程目录以及文件路径问题
  • *ST京蓝入股力合节能 着力绿色智慧城市服务
  • .NET Core IdentityServer4实战-开篇介绍与规划
  • .net core webapi 大文件上传到wwwroot文件夹
  • .NET3.5下用Lambda简化跨线程访问窗体控件,避免繁复的delegate,Invoke(转)
  • .NET使用存储过程实现对数据库的增删改查
  • .project文件
  • .vimrc php,修改home目录下的.vimrc文件,vim配置php高亮显示
  • ::
  • @AliasFor注解
  • @Import注解详解
  • [ Linux 长征路第二篇] 基本指令head,tail,date,cal,find,grep,zip,tar,bc,unname
  • [ vulhub漏洞复现篇 ] Apache APISIX 默认密钥漏洞 CVE-2020-13945
  • [ vulhub漏洞复现篇 ] GhostScript 沙箱绕过(任意命令执行)漏洞CVE-2019-6116
  • [ vulhub漏洞复现篇 ] struts2远程代码执行漏洞 S2-005 (CVE-2010-1870)
  • [【JSON2WEB】 13 基于REST2SQL 和 Amis 的 SQL 查询分析器