当前位置: 首页 > news >正文

转置卷积详解(原理+实验)

文章目录

  • 转置卷积详解
    • 转置卷积理论📝📝📝
    • 转置卷积实验📝📝📝

转置卷积详解

转置卷积理论📝📝📝

​   这里我将通过一个小例子来讲述转置卷积的步骤,并通过代码来验证这个步骤的正确性。首先我们先来看看转置卷积的步骤,如下:

  • 在输入特征图元素间填充s-1行、列0(其中s表示转置卷积的步距,注意这里的步长s和卷积操作中的有些不同)
  • 在输入特征图四周填充k-p-1行、列0(其中k表示转置卷积kernel_size大小,p为转置卷积的padding,注意这里的padding和卷积操作中的有些不同)
  • 将卷积核参数上下、左右翻转
  • 做正常卷积运算(padding=0,s=1)

​    是不是还是懵逼的状态呢,不用急,现在就通过一个例子来讲述这个过程。首先我们假设输入特征图的尺寸为2*2大小,s=2,k=3,p=0,如下图所示:

​   第一步我们需要在特征图元素间填充s-1=1 行、列 0 (即填充1行0,1列0),变换后特征图如下:

​   第二步我们需要在输入特征图四周填充k-p-1=2 行、列0(即填充2行0,2列0),变换后特征图如下:

​   第三步我们需要将卷积核上下、左右翻转,得到新的卷积核【卷积核尺寸为k=3】,卷积核变化过程如下:

image-20220722171034209

​   最后一步,我们做正常的卷积即可【注:拿第二步得到的特征图和第三步翻转后得到的卷积核做正常卷积】,结果如下:

image-20220722180107129

​   至此我们就从完成了转置卷积,从一个2*2大小的特征图变成了一个5*5大小的特征图,如下图所示(忽略了中间步骤):

image-20220722171802281

​   为了让大家更直观的感受转置卷积的过程,我从Github上down了一个此过程动态图供大家参考,如下:【注:需要动态图点击☞☞☞自取】

在这里插入图片描述

​   通过上文的讲述,相信你已经对转置卷积的步骤比较清楚了。这时候你就可以试试图1中结构,看看应用上述的方法能否得到对应的结构。需要注意的是,在第一次转置卷积时,使用的参数k=4,s=1,p=0,后面的参数都为k=4,s=2,p=1,如下图所示:

image-20220722185445223

​   如果你按照我的步骤试了试,可能会发出一些吐槽,这也太麻烦了,我只想计算一下经过转置卷积后特征图的的变化,即知道输入特征图尺寸以及k、s、p算出输出特征图尺寸,这步骤也太复杂了。于是好奇有没有什么公式可以很方便的计算呢?enmmm,我这么说,那肯定有嘛,公式如下图所示:

image-20220722190016752

​   对于上述公式我做3点说明:

  1. 在转置卷积的官方文档中,参数还有output_padding 和dilation参数也会影响输出特征图的大小,但这里我们没使用,公式就不加上这俩了,感兴趣的可以自己去阅读一下文档,写的很详细。🌵🌵🌵
  2. 对于stride[0],stride[1]、padding[0],padding[1]、kernel_size[0],kernel_size[1]该怎么理解?其实啊这些都是卷积的基本知识,这些参数设置时可以设置一个整数或者一个含两个整数的元组,*[0]表示在高度上进行操作,*[1]表示在宽度上进行操作。有关这部分在官方文档上也有写,大家可自行查看。为方便大家,我截了一下这部分的图片,如下:
  3. 这点我带大家宏观的理解一下这个公式,在传统卷积中,往往卷积核k越小、padding越大,得到的特征图尺寸越大;而在转置卷积中,从公式可以看出,卷积核k越大,padding越小,得到的特征图尺寸越大,关于这一点相信你也能从前文所述的转置卷积理论部分有所感受。🌿🌿🌿

​ 现在有了这个公式,大家再去试试叭。


 

转置卷积实验📝📝📝

​   接下来我将通过一个小实验验证上面的过程,代码如下:

import torch
import torch.nn as nn


#转置卷积
def transposed_conv_official():
    feature_map = torch.as_tensor([[1, 2],
                                   [0, 1]], dtype=torch.float32).reshape([1, 1, 2, 2])
    print(feature_map)
    trans_conv = nn.ConvTranspose2d(in_channels=1, out_channels=1,
                                    kernel_size=3, stride=2, bias=False)
    trans_conv.load_state_dict({"weight": torch.as_tensor([[1, 0, 1],
                                                           [1, 1, 0],
                                                           [0, 0, 1]], dtype=torch.float32).reshape([1, 1, 3, 3])})
    print(trans_conv.weight)
    output = trans_conv(feature_map)
    print(output)


def transposed_conv_self():
    feature_map = torch.as_tensor([[0, 0, 0, 0, 0, 0, 0],
                                   [0, 0, 0, 0, 0, 0, 0],
                                   [0, 0, 1, 0, 2, 0, 0],
                                   [0, 0, 0, 0, 0, 0, 0],
                                   [0, 0, 0, 0, 1, 0, 0],
                                   [0, 0, 0, 0, 0, 0, 0],
                                   [0, 0, 0, 0, 0, 0, 0]], dtype=torch.float32).reshape([1, 1, 7, 7])
    print(feature_map)
    conv = nn.Conv2d(in_channels=1, out_channels=1,
                     kernel_size=3, stride=1, bias=False)
    conv.load_state_dict({"weight": torch.as_tensor([[1, 0, 0],
                                                     [0, 1, 1],
                                                     [1, 0, 1]], dtype=torch.float32).reshape([1, 1, 3, 3])})
    print(conv.weight)
    output = conv(feature_map)
    print(output)


def main():
    transposed_conv_official()
    print("---------------")
    transposed_conv_self()


if __name__ == '__main__':
    main()


​   首先我们先通过transposed_conv_official()函数来封装一个转置卷积过程,可以看到我们的输入为[[1,2],[0,1]],卷积核为[[1,0,1],[1,1,0],[0,0,1]],采用k=3,s=2,p=0进行转置卷积【注:这些参数和我前文讲解转置卷积步骤的用例参数是一致的】,我们来看一下程序输出的结果:可以发现程序输出和我们前面理论计算得到的结果是一致的。

image-20220722195837221

​   接着我们封装了transposed_conv_self函数,这个函数定义的是一个正常的卷积,输入是理论第2步得到的特征图,卷积核是第三步翻转后得到的卷积核,经过卷积后输出结果如下:结果和前面的一致。

image-20220722200836873

​   那么通过这个例子就大致证明了转置卷积的步骤确实是我们理论步骤所述。


【呼~~这部分终于讲完了,这部分参考链接如下:参考视频🥎🥎🥎】

 
 

如若文章对你有所帮助,那就🛴🛴🛴

在这里插入图片描述

相关文章:

  • ES字符串从任意位置模糊查询(支持只匹配含连续字符串内容)
  • (附源码)计算机毕业设计ssm基于B_S的汽车售后服务管理系统
  • STM32F103移植FreeRTOS必须搞明白的系列知识---4(FreeRTOSConfig.h配置文件)
  • Java基础2(二维数组、数组的赋值判定)
  • Redis 强化之一
  • 打印设备电磁泄露信息提取和还原技术的matlab仿真实现
  • 【C++】类和对象(中)—— 日期类的实现 | const成员函数
  • 树莓派视频监控项目总结
  • datax与多种数据库间数据类型映射
  • Redis哨兵模式与Redis缓存穿透、击穿和雪崩
  • Ubuntu Budgie 22.04 设置中文语言并安装拼音输入法
  • 4K Star , Github上照片转漫画最强项目
  • Matlab 创建YOLO v2目标检测网络(仅仅是网络)
  • Java集合04:Collection子接口二:Set接口
  • 查看CPU核数、内存使用情况【一文读懂】
  • 实现windows 窗体的自己画,网上摘抄的,学习了
  • el-input获取焦点 input输入框为空时高亮 el-input值非法时
  • JS创建对象模式及其对象原型链探究(一):Object模式
  • js算法-归并排序(merge_sort)
  • Laravel 菜鸟晋级之路
  • spring + angular 实现导出excel
  • uni-app项目数字滚动
  • 从输入URL到页面加载发生了什么
  • 服务器之间,相同帐号,实现免密钥登录
  • 欢迎参加第二届中国游戏开发者大会
  • 判断客户端类型,Android,iOS,PC
  • 如何学习JavaEE,项目又该如何做?
  • 我这样减少了26.5M Java内存!
  • !$boo在php中什么意思,php前戏
  • # Pytorch 中可以直接调用的Loss Functions总结:
  • #if和#ifdef区别
  • $$$$GB2312-80区位编码表$$$$
  • $分析了六十多年间100万字的政府工作报告,我看到了这样的变迁
  • (C#)Windows Shell 外壳编程系列9 - QueryInfo 扩展提示
  • (webRTC、RecordRTC):navigator.mediaDevices undefined
  • (个人笔记质量不佳)SQL 左连接、右连接、内连接的区别
  • (力扣记录)1448. 统计二叉树中好节点的数目
  • (算法)求1到1亿间的质数或素数
  • (原創) 物件導向與老子思想 (OO)
  • (转)淘淘商城系列——使用Spring来管理Redis单机版和集群版
  • .NET业务框架的构建
  • .NET中统一的存储过程调用方法(收藏)
  • /var/log/cvslog 太大
  • [2009][note]构成理想导体超材料的有源THz欺骗表面等离子激元开关——
  • [3D游戏开发实践] Cocos Cyberpunk 源码解读-高中低端机性能适配策略
  • [BeginCTF]真龙之力
  • [c]扫雷
  • [English]英语积累本
  • [IE9] 解决了傲游、搜狗浏览器在IE9下网页截图的问题
  • [JS入门到进阶] 7条关于 async await 的使用口诀,新学 async await?背10遍,以后要考!快收藏
  • [Linux]history 显示命令的运行时间
  • [node] Node.js 缓冲区Buffer
  • [SOC] MBIST (Memory Built-In Self Test) and Memory Built-in Self Repair (BISR)
  • [SpringBoot系列]缓存解决方案
  • [笔记]http权威指南(2)