当前位置: 首页 > news >正文

『PyTorch x TensorFlow』第八弹_基本nn.Module层函数

『TensorFlow』网络操作API_上  

『TensorFlow』网络操作API_中

『TensorFlow』网络操作API_下

之前也说过,tf 和 t 的层本质区别就是 tf 的是层函数,调用即可,t 的是类,需要初始化后再调用实例(实例都是callable的)

 

卷积

tensorflow.nn.conv2d

import tensorflow as tf

sess = tf.Session()
input = tf.Variable(tf.random_normal([1,3,3,5]))

# 卷积核尺寸*2,输入通道,输出通道,
filter = tf.Variable(tf.random_normal([1,1,5,1])) # 《-----卷积核初始化

conv = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')

sess.run(tf.global_variables_initializer())
print(sess.run(conv).shape)
(1, 3, 3, 1)

torch.nn.Conv2d

troch集成了初始化核的部分,所以自行初始化时需要直接修改变量的data

本篇很多例子中都对module的属性直接操作,其大多数是可学习参数,一般会随着学习的进行而不断改变。实际使用中除非需要使用特殊的初始化,应尽量不要直接修改这些参数。

import torch as t
input = t.normal(means=t.zeros([1,5,3,3]), std=t.Tensor([0.1]).expand([1,5,3,3]))
input = t.autograd.Variable(input)

# 输入通道,输出通道,卷积核尺寸,步长,是否偏执
conv = t.nn.Conv2d(5, 1, (1, 1), 1, bias=False)

# 输出通道,输入通道,卷积核尺寸*2
print([n for n,p in conv.named_parameters()])
conv.weight.data = t.ones([1,5,1,1]) # 《-----卷积核初始化,可有可无

out = conv(input)
print(out.size())
['weight']
torch.Size([1, 1, 3, 3])

 

池化

tensorflow.nn.avg_pool

torch.nn.AvgPool2d

可以验证没有学习参数

pool = nn.AvgPool2d(2,2)
list(pool.parameters())
[]

 

线性

torch.nn.Linear

# 输入 batch_size=2,维度3
input = V(t.randn(2, 3))
linear = nn.Linear(3, 4)
h = linear(input)
print(h)
Variable containing:
-1.4189 -0.2045  1.2143 -1.5404
 0.8471 -0.3154 -0.5855  0.0153
[torch.FloatTensor of size 2x4]

 

BatchNorm

『TensorFlow』批处理类

torch.nn.BatchNorm1d

BatchNorm:批规范化层,分为1D、2D和3D。除了标准的BatchNorm之外,还有在风格迁移中常用到的InstanceNorm层。

# 4 channel,初始化标准差为4,均值为0
bn = nn.BatchNorm1d(4)
print([n for n,p in bn.named_parameters()])
bn.weight.data = t.ones(4) * 4
bn.bias.data = t.zeros(4)

bn_out = bn(h)
# 注意输出的均值和方差
# 方差是标准差的平方,计算无偏方差分母会减1
# 使用unbiased=False 分母不减1
bn_out.size(), bn_out.mean(0), bn_out.var(0, unbiased=False)
['weight', 'bias']
(torch.Size([2, 4]), 

Variable containing: 1.00000e-06 * 0.0000 -1.0729 0.0000 0.1192 [torch.FloatTensor of size 4],

Variable containing: 15.9999 15.9481 15.9998 15.9997 [torch.FloatTensor of size 4])

 

Dropout

tensorflow.nn.dropout

torch.nn.Dropout

dropout层,用来防止过拟合,同样分为1D、2D和3D。 下面通过例子来说明它们的使用。

 # 每个元素以0.5的概率舍弃
dropout = nn.Dropout(0.5)
o = dropout(bn_out)
o # 有一半左右的数变为0
Variable containing:
-7.9895 -7.9931  7.9991  7.9973
 0.0000  0.0000 -7.9991 -7.9973
[torch.FloatTensor of size 2x4]

 

激活函数

PyTorch实现了常见的激活函数,其具体的接口信息可参见官方文档^3,这些激活函数可作为独立的layer使用。这里将介绍最常用的激活函数ReLU,其数学表达式为:

relu = nn.ReLU(inplace=True)
input = V(t.randn(2, 3))
print(input)
output = relu(input)
print(output) # 小于0的都被截断为0
# 等价于input.clamp(min=0)
Variable containing:
-0.8472  1.0046  0.7245
 0.3567  0.0032 -0.5200
[torch.FloatTensor of size 2x3]

Variable containing:
 0.0000  1.0046  0.7245
 0.3567  0.0032  0.0000
[torch.FloatTensor of size 2x3]

有关inplace:

ReLU函数有个inplace参数,如果设为True,它会把输出直接覆盖到输入中,这样可以节省内存/显存。之所以可以覆盖是因为在计算ReLU的反向传播时,只需根据输出就能够推算出反向传播的梯度。但是只有少数的autograd操作支持inplace操作(如variable.sigmoid_()),除非你明确地知道自己在做什么,否则一般不要使用inplace操作。

 

交叉熵

import torch as t
from torch.autograd import Variable as V

score = V(t.randn(3,2))
label = V(t.Tensor([1,0,1])).long()
loss_fn = t.nn.CrossEntropyLoss()
loss = loss_fn(score,label)
print(loss)

Variable containing:
 1.3535
[torch.FloatTensor of size 1]

损失函数和nn.Module的其他class没什么不同,不过实际使用时往往单独提取出来(书上语)。

 

ReLU(x)=max(0,x)

 

相关文章:

  • 解决-webkit-box-orient: vertical;(文本溢出)属性在webpack打包后无法编译的问题
  • Linux快速构建LAMP网站平台
  • python 装饰器(一)
  • kali:加速WEP黑客攻击,ARP请求重播攻击
  • DM8127-UART驱动
  • 利用表格分页显示数据的js组件datatable的使用
  • RAID磁盘阵列种类及区别
  • Linux LVM 之LV
  • 语音识别技术受追捧,无法独立工作的“速记神器”何时才能成为新亮点?
  • 还在啃老?是该来场逼格满满的产品展示了!
  • 2018年微信小程序风口最新发展趋势分析
  • Fortinet安全能力融入华为CloudEPN 联合防御网络威胁
  • 【Java资源免费分享,网盘自己拿】
  • 洛谷2774:[网络流24题]方格取数问题——题解
  • 第五届全球云计算大会暨国际网络通信展览会·中国站圆满落幕
  • [译] 怎样写一个基础的编译器
  • 【399天】跃迁之路——程序员高效学习方法论探索系列(实验阶段156-2018.03.11)...
  • angular组件开发
  • Docker 1.12实践:Docker Service、Stack与分布式应用捆绑包
  • Dubbo 整合 Pinpoint 做分布式服务请求跟踪
  • happypack两次报错的问题
  • Linux中的硬链接与软链接
  • PHP的类修饰符与访问修饰符
  • Python进阶细节
  • React as a UI Runtime(五、列表)
  • vue数据传递--我有特殊的实现技巧
  • 阿里云前端周刊 - 第 26 期
  • 不发不行!Netty集成文字图片聊天室外加TCP/IP软硬件通信
  • 程序员最讨厌的9句话,你可有补充?
  • 官方解决所有 npm 全局安装权限问题
  • 微信小程序--------语音识别(前端自己也能玩)
  • 学习HTTP相关知识笔记
  • 字符串匹配基础上
  • media数据库操作,可以进行增删改查,实现回收站,隐私照片功能 SharedPreferences存储地址:
  • MPAndroidChart 教程:Y轴 YAxis
  • mysql 慢查询分析工具:pt-query-digest 在mac 上的安装使用 ...
  • 格斗健身潮牌24KiCK获近千万Pre-A轮融资,用户留存高达9个月 ...
  • ​Java并发新构件之Exchanger
  • ​学习一下,什么是预包装食品?​
  • #NOIP 2014#day.2 T1 无限网络发射器选址
  • #QT项目实战(天气预报)
  • #微信小程序:微信小程序常见的配置传值
  • (MIT博士)林达华老师-概率模型与计算机视觉”
  • (多级缓存)缓存同步
  • (三)Hyperledger Fabric 1.1安装部署-chaincode测试
  • .bat批处理(二):%0 %1——给批处理脚本传递参数
  • .bat批处理(十):从路径字符串中截取盘符、文件名、后缀名等信息
  • .Net Core缓存组件(MemoryCache)源码解析
  • .NET 设计模式初探
  • .netcore 如何获取系统中所有session_如何把百度推广中获取的线索(基木鱼,电话,百度商桥等)同步到企业微信或者企业CRM等企业营销系统中...
  • [\u4e00-\u9fa5] //匹配中文字符
  • [android] 练习PopupWindow实现对话框
  • [AutoSar]BSW_Com02 PDU详解
  • [bzoj 3534][Sdoi2014] 重建
  • [C# 开发技巧]如何使不符合要求的元素等于离它最近的一个元素