当前位置: 首页 > news >正文

CONV1D卷积神经网络运算过程(举例:n行3列➡n行6列)

文章目录

        • 背景
        • 计算过程
        • 计算过程图示
        • Conv1d() 代码举例
        • Linear() 的原理


背景

一维卷积的运算过程网上很多人说不清楚,示意图画的也不清楚。因此,本人针对一维卷积的过程,绘制了计算过程,以我的知识量解释一下 pytorch 中 Conv1d() 函数的机理。


计算过程

假设我们现在有 n 行,3列数据。n 行可以是 n 个点,也可以是 n 个样本数据。3列可以视为3列特征,即特征向量。我们想要通过 MLP 将其从3列升维度为6维度,就需要用 Conv1d() 函数。具体过程就是让每一行数据点乘一个卷积核,得到一个数,6个卷积核就是6个数,这样就把一个点的3列变成了6列。然后逐行遍历每个点,就可以得到新的得分矩阵。

备注: 从6列变成12列,就点乘12个卷积核。


计算过程图示

①、第1行数据参与卷积

②、第2行数据参与卷积

③、第n行数据参与卷积


Conv1d() 代码举例

我们以 PointNet 中分类的主干模型 (多层感知机,MLP) 来说,Conv1d(64, 128, 1) 其实就是用 128 个 64 行 1 列的卷积核和前面 n 行 64 列的矩阵逐行点积,升维到 128 列。

class STNkd(nn.Module):
    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x

Linear() 的原理

就是解 Y = X · A.T + b。其中的 A.T 是权重矩阵的转置矩阵,b 为偏置矩阵。nn.Linear(1024, 512) 就是将 n 行 1024 列的 X 矩阵降维到 n 行 512 列的矩阵。只要 A.T 为 1024 行 512 列的矩阵,和 X 点乘,就可以得到 n 行 512 列的矩阵,达到升维的目的。
①、Linear() 计算 (忽略偏置矩阵)
在这里插入图片描述

相关文章:

  • 数据结构c语言版第二版(严蔚敏)第一章练习
  • python练习Ⅱ--函数
  • 3D多模态成像市场现状及未来发展趋势分析
  • vscode 1.71变化与关注点(多配置预设/旧合并器回归等)
  • SQL面试题之区间合并问题
  • Linux用户和权限之一
  • 回溯法就是学不会2 —— 括号生成问题
  • ESP32 ESP-IDF TFT-LCD(ST7735 128x160) LVGL演示
  • 信息论学习笔记(二):离散无噪声系统
  • CentOS7启动SSH服务报错
  • 大咖说*计算讲谈社|商用车智能驾驶商业化实践
  • python笔记Ⅶ--函数返回值、作用域与命名空间、递归
  • 03 RocketMQ - Broker 源码分析
  • Java日志系列——规范化日志
  • 00前言说明-Qt自定义控件大全
  • Apache的80端口被占用以及访问时报错403
  • C# 免费离线人脸识别 2.0 Demo
  • ECMAScript入门(七)--Module语法
  • FineReport中如何实现自动滚屏效果
  • HTTP 简介
  • IDEA 插件开发入门教程
  • JDK 6和JDK 7中的substring()方法
  • JDK9: 集成 Jshell 和 Maven 项目.
  • leetcode98. Validate Binary Search Tree
  • spring-boot List转Page
  • 聊聊flink的BlobWriter
  • 深度学习中的信息论知识详解
  • 收藏好这篇,别再只说“数据劫持”了
  • 做一名精致的JavaScripter 01:JavaScript简介
  • MyCAT水平分库
  • #、%和$符号在OGNL表达式中经常出现
  • #我与Java虚拟机的故事#连载15:完整阅读的第一本技术书籍
  • (2.2w字)前端单元测试之Jest详解篇
  • (搬运以学习)flask 上下文的实现
  • (分享)自己整理的一些简单awk实用语句
  • (附源码)spring boot建达集团公司平台 毕业设计 141538
  • (附源码)ssm高校运动会管理系统 毕业设计 020419
  • (附源码)计算机毕业设计SSM保险客户管理系统
  • (十六)串口UART
  • (十五)使用Nexus创建Maven私服
  • (一)搭建springboot+vue前后端分离项目--前端vue搭建
  • (转)ABI是什么
  • (转)大型网站架构演变和知识体系
  • .h头文件 .lib动态链接库文件 .dll 动态链接库
  • .NET CORE 3.1 集成JWT鉴权和授权2
  • .NET Core 中的路径问题
  • .NET HttpWebRequest、WebClient、HttpClient
  • .net 无限分类
  • .NET开源项目介绍及资源推荐:数据持久层 (微软MVP写作)
  • .net知识和学习方法系列(二十一)CLR-枚举
  • @31省区市高考时间表来了,祝考试成功
  • @DateTimeFormat 和 @JsonFormat 注解详解
  • @Transactional注解下,循环取序列的值,但得到的值都相同的问题
  • [ vulhub漏洞复现篇 ] Django SQL注入漏洞复现 CVE-2021-35042
  • []FET-430SIM508 研究日志 11.3.31