当前位置: 首页 > news >正文

《动手学深度学习》(PyTorch版)代码注释 - 30 【DenseNet_Learning】

目录

  • 说明
  • 配置环境
  • 此节说明
  • 代码

说明

本博客代码来自开源项目:《动手学深度学习》(PyTorch版)
并且在博主学习的理解上对代码进行了大量注释,方便理解各个函数的原理和用途

配置环境

使用环境:python3.8
平台:Windows10
IDE:PyCharm

此节说明

此节对应书本上5.12节
此节功能为:稠密连接网络(DenseNet)
由于次节相对容易理解,代码注释量较少

代码

# 本书链接https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.8_mlp
# 5.11 残差网络(ResNet)
# 注释:黄文俊
# E-mail:hurri_cane@qq.com

import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# DenseNet的主要构建模块是稠密块(dense block)和过渡层(transition layer)。
# 前者定义了输入和输出是如何连结的,后者则用来控制通道数,使之不过大。
def conv_block(in_channels, out_channels):
    blk = nn.Sequential(nn.BatchNorm2d(in_channels),
                        nn.ReLU(),
                        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
    return blk


class DenseBlock(nn.Module):
    def __init__(self, num_convs, in_channels, out_channels):
        super(DenseBlock, self).__init__()
        net = []
        for i in range(num_convs):
            in_c = in_channels + i * out_channels
            net.append(conv_block(in_c, out_channels))
        self.net = nn.ModuleList(net)
        self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数

    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = torch.cat((X, Y), dim=1)  # 在通道维上将输入和输出连结
        return X


blk1 = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk1(X)
print(Y.shape)

# 5.12.2 过渡层
# 通过1×1卷积层来减小通道数,并使用步幅为2的平均池化层减半高和宽,从而进一步降低模型复杂度。
def transition_block(in_channels, out_channels):
    blk = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.AvgPool2d(kernel_size=2, stride=2))  # 平均池化
    return blk

blk = transition_block(23, 10)
print(blk(Y).shape)


# 构造DenseNet模型

net = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))


num_channels, growth_rate = 64, 32  # num_channels为当前的通道数
num_convs_in_dense_blocks = [4, 4, 4, 4]

for i, num_convs in enumerate(num_convs_in_dense_blocks):
    DB = DenseBlock(num_convs, num_channels, growth_rate)
    net.add_module("DenseBlosk_%d" % i, DB)
    # 上一个稠密块的输出通道数
    num_channels = DB.out_channels
    # 在稠密块之间加入通道数减半的过渡层
    if i != len(num_convs_in_dense_blocks) - 1:
        net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))
        num_channels = num_channels // 2


# 最后接上全局池化层和全连接层来输出
net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10)))

X = torch.rand((1, 1, 96, 96))
for name, layer in net.named_children():
    X = layer(X)
    print(name, ' output shape:\t', X.shape)

batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)



print("*"*50)

相关文章:

  • WSS3.0 和 MOSS2007 对移动设备浏览器的支持(续)
  • 《动手学深度学习》(PyTorch版)代码注释 - 31 【Language_model_data_set】
  • Windows Mobile 6 SDK 中的 Device Emulator 2.0
  • 《动手学深度学习》(PyTorch版)代码注释 - 32 【RNN_with_zero】
  • Windows Mobile 6 SDK 中的 Cellular Emulator
  • Windows Mobile 6 SDK 中的 GPS 工具
  • 《动手学深度学习》(PyTorch版)代码注释 - 33 【RNN_with_simple_way】
  • 《动手学深度学习》(PyTorch版)代码注释 - 34 【GRU_with_zero】
  • MEDC2007北京游记 - WindowsMobile Ophone
  • 《动手学深度学习》(PyTorch版)代码注释 - 35 【GRU_with_simple_way】
  • 祝贺CICI拿到VISA
  • 《动手学深度学习》(PyTorch版)代码注释 - 36 【LSTM_with_zero】
  • WPF/E去了,Silverlight来了
  • iPhone - 少一点自恋,多一点现实 !
  • 《动手学深度学习》(PyTorch版)代码注释 - 37 【LSTM_with_simple_way】
  • Apache的基本使用
  • docker python 配置
  • Linux学习笔记6-使用fdisk进行磁盘管理
  • Linux中的硬链接与软链接
  • Lsb图片隐写
  • ng6--错误信息小结(持续更新)
  • supervisor 永不挂掉的进程 安装以及使用
  • 闭包,sync使用细节
  • 程序员最讨厌的9句话,你可有补充?
  • 得到一个数组中任意X个元素的所有组合 即C(n,m)
  • 给第三方使用接口的 URL 签名实现
  • 数据库写操作弃用“SELECT ... FOR UPDATE”解决方案
  • 学习笔记TF060:图像语音结合,看图说话
  • 3月27日云栖精选夜读 | 从 “城市大脑”实践,瞭望未来城市源起 ...
  • 带你开发类似Pokemon Go的AR游戏
  • ​软考-高级-信息系统项目管理师教程 第四版【第19章-配置与变更管理-思维导图】​
  • $分析了六十多年间100万字的政府工作报告,我看到了这样的变迁
  • (MATLAB)第五章-矩阵运算
  • (阿里巴巴 dubbo,有数据库,可执行 )dubbo zookeeper spring demo
  • (笔试题)分解质因式
  • (附源码)node.js知识分享网站 毕业设计 202038
  • (附源码)springboot掌上博客系统 毕业设计063131
  • (附源码)计算机毕业设计SSM智能化管理的仓库管理
  • *p++,*(p++),*++p,(*p)++区别?
  • . NET自动找可写目录
  • .NET Micro Framework 4.2 beta 源码探析
  • .net 获取url的方法
  • .NET/C# 利用 Walterlv.WeakEvents 高性能地定义和使用弱事件
  • .Net+SQL Server企业应用性能优化笔记4——精确查找瓶颈
  • .Net程序帮助文档制作
  • .net连接oracle数据库
  • .net图片验证码生成、点击刷新及验证输入是否正确
  • .NET与 java通用的3DES加密解密方法
  • @Not - Empty-Null-Blank
  • @property括号内属性讲解
  • @Resource和@Autowired的区别
  • [Arduino学习] ESP8266读取DHT11数字温湿度传感器数据
  • [AutoSar]BSW_Com07 CAN报文接收流程的函数调用
  • [BUUCTF 2018]Online Tool
  • [BZOJ 3680]吊打XXX(模拟退火)