当前位置: 首页 > news >正文

TextCNN:文本卷积神经网络模型

目录

  • 什么是TextCNN
  • 定义TextCNN类
  • 初始化一个model实例
  • 输出model

什么是TextCNN

  • TextCNN(Text Convolutional Neural Network)是一种用于处理文本数据的卷积神经网(CNN)。通过在文本数据上应用卷积操作来提取局部特征,这些特征可以捕捉到文本中的局部模式,如n-gram(连续的n个单词或字符)。

定义TextCNN类

import torch.nn as nn# 它继承自 PyTorch 的 nn.Module
class TextCNN(nn.Module):# __init__:类的构造函数,初始化模型,包括嵌入层、卷积层、dropout层和全连接层def __init__(self, vocab_size, embed_dim, num_classes, num_filters, kernel_sizes):# 调用父类 nn.Module 的构造函数super(TextCNN, self).__init__()# 创建一个嵌入层,将词汇表中的每个单词映射到一个embed_dim 维的向量空间。vocab_size 是词汇表的大小self.embedding = nn.Embedding(vocab_size, embed_dim)# 创建一个卷积层列表,每个卷积层使用不同的 kernel_size。in_channels 是嵌入向量的维度,out_channels 是每个卷积核输出的特征数量self.convs = nn.ModuleList([nn.Conv1d(in_channels=embed_dim, out_channels=num_filters, kernel_size=k) for k in kernel_sizes])# 创建一个 Dropout 层,用于在训练过程中随机丢弃 50% 的节点,以减少过拟合self.dropout = nn.Dropout(0.5)# 创建一个全连接层,将卷积层的输出连接到最终的分类结果。# 输入特征的数量是卷积核数量乘以每个卷积核的输出特征数量,输出特征数量是分类类别的数量self.fc = nn.Linear(len(kernel_sizes) * num_filters, num_classes)# forward:定义模型的前向传播过程。# x:输入数据,通常是文本的整数序列。def forward(self, x):# 将输入数据通过嵌入层转换为嵌入向量x = self.embedding(x)  # 调整张量维度,以便卷积操作可以在嵌入向量的维度上进行x = x.transpose(1, 2) # 对每个卷积层应用激活函数 ReLU,生成特征图convs = [torch.relu(conv(x)) for conv in self.convs]  # 对每个卷积层的输出应用最大池化,以减少特征图的维度pooled = [torch.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs]  # 将所有卷积层的最大池化结果拼接在一起,形成一个单一的特征向量cat = torch.cat(pooled, 1) # 通过 Dropout 层和全连接层进行分类,输出最终的分类结果return self.fc(self.dropout(cat))

初始化一个model实例

vocab_size = 1000
embed_dim = 128
num_classes = 2
num_filters = 100
kernel_sizes = [3, 4, 5]model = TextCNN(vocab_size, embed_dim, num_classes, num_filters, kernel_sizes)

输出model

TextCNN((embedding): Embedding(8, 128)(convs): ModuleList((0): Conv1d(128, 100, kernel_size=(3,), stride=(1,))(1): Conv1d(128, 100, kernel_size=(4,), stride=(1,))(2): Conv1d(128, 100, kernel_size=(5,), stride=(1,)))(dropout): Dropout(p=0.5, inplace=False)(fc): Linear(in_features=300, out_features=2, bias=True)
)
  • Embedding(8, 128):这是一个嵌入层,它将词汇表中的每个单词映射到一个128维的向量空间。这里的8表示词汇表的大小(即输入序列中可能的最大单词索引),128表示每个单词将被映射到的向量维度。
  • convs: ModuleList[...]:这是一个包含多个一维卷积层(Conv1d)的模块列表。每个卷积层都用于提取文本数据的不同局部特征。
  • Conv1d(128, 100, kernel_size=(3,), stride=(1,)):每个卷积层有128个输入通道(与嵌入层的输出维度相同)和100个输出通道(即100个滤波器)。kernel_size=3表示每个滤波器的窗口大小为3个词。stride=1表示滤波器在文本序列上滑动的步长为1
  • Dropout(p=0.5, inplace=False):这是一个Dropout层,它在训练过程中随机丢弃50%的节点,以减少过拟合。inplace=False表示Dropout操作不会在原地修改输入张量。
  • fc: Linear(in_features=300, out_features=2, bias=True):这是一个全连接层,它将卷积层和Dropout层的输出转换为最终的分类结果。in_features=300表示全连接层的输入特征数量(这是由卷积层的数量和每个卷积层的输出特征数量决定的,即3个卷积层各100个特征)。out_features=2表示输出特征的数量,这通常与分类任务的类别数相对应(在这个例子中,可能是二分类问题)。bias=True表示全连接层的权重将包含偏置项。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • leetcode-581. 最短无序连续子数组
  • MySQL高级功能-窗口函数
  • Vue.js中computed的使用方法
  • 前端开发深入了解webpack
  • 【中秋月饼系列】2024年立体月饼新鲜出炉----python画月饼(1)附完整代码
  • 【Unity学习心得】如何使用Unity制作“饥荒”风格的俯视角2.5D游戏
  • 【随手笔记】
  • 安宝特案例 | AR如何大幅提升IC封装厂检测效率?
  • 安卓显示驱动
  • Unreal Engine——AI生成高精度的虚拟人物和环境(虚拟世界构建、电影场景生成)(一)
  • 喜报 | 知从科技荣获 “AutoSec 安全之星 - 优秀汽车软件供应链安全方案奖”
  • Linux创建虚拟磁盘并分区格式化
  • 剑灵服务端源码(c#版本+数据库+配套客户端+服务端)
  • 嵌入式学习——数据结构——顺序表
  • 20. 如何在MyBatis中处理多表关联查询?常见的实现方式有哪些?
  • [微信小程序] 使用ES6特性Class后出现编译异常
  • “寒冬”下的金三银四跳槽季来了,帮你客观分析一下局面
  • 5、React组件事件详解
  • Java 23种设计模式 之单例模式 7种实现方式
  • Nodejs和JavaWeb协助开发
  • Redux系列x:源码分析
  • Spark学习笔记之相关记录
  • SpiderData 2019年2月13日 DApp数据排行榜
  • SwizzleMethod 黑魔法
  • Vue官网教程学习过程中值得记录的一些事情
  • 不上全站https的网站你们就等着被恶心死吧
  • 初识MongoDB分片
  • 猴子数据域名防封接口降低小说被封的风险
  • 解析带emoji和链接的聊天系统消息
  • 码农张的Bug人生 - 初来乍到
  • 突破自己的技术思维
  • 问:在指定的JSON数据中(最外层是数组)根据指定条件拿到匹配到的结果
  • 用mpvue开发微信小程序
  • 06-01 点餐小程序前台界面搭建
  • MPAndroidChart 教程:Y轴 YAxis
  • 不要一棍子打翻所有黑盒模型,其实可以让它们发挥作用 ...
  • 机器人开始自主学习,是人类福祉,还是定时炸弹? ...
  • ​flutter 代码混淆
  • # 再次尝试 连接失败_无线WiFi无法连接到网络怎么办【解决方法】
  • #LLM入门|Prompt#2.3_对查询任务进行分类|意图分析_Classification
  • #鸿蒙生态创新中心#揭幕仪式在深圳湾科技生态园举行
  • #控制台大学课堂点名问题_课堂随机点名
  • (007)XHTML文档之标题——h1~h6
  • (6) 深入探索Python-Pandas库的核心数据结构:DataFrame全面解析
  • (webRTC、RecordRTC):navigator.mediaDevices undefined
  • (windows2012共享文件夹和防火墙设置
  • (附源码)springboot 基于HTML5的个人网页的网站设计与实现 毕业设计 031623
  • (每日持续更新)jdk api之StringBufferInputStream基础、应用、实战
  • (转)真正的中国天气api接口xml,json(求加精) ...
  • (转)总结使用Unity 3D优化游戏运行性能的经验
  • .NET CF命令行调试器MDbg入门(四) Attaching to Processes
  • .net 设置默认首页
  • .net 逐行读取大文本文件_如何使用 Java 灵活读取 Excel 内容 ?
  • .NET/C#⾯试题汇总系列:集合、异常、泛型、LINQ、委托、EF!(完整版)
  • .NET6 命令行启动及发布单个Exe文件