当前位置: 首页 > news >正文

《动手学深度学习》(PyTorch版)代码注释 - 29 【ResNet_Learning】

目录

  • 说明
  • 配置环境
  • 此节说明
  • 代码

说明

本博客代码来自开源项目:《动手学深度学习》(PyTorch版)
并且在博主学习的理解上对代码进行了大量注释,方便理解各个函数的原理和用途

配置环境

使用环境:python3.8
平台:Windows10
IDE:PyCharm

此节说明

此节对应书本上5.11节
此节功能为:残差网络(ResNet)
由于次节相对容易理解,代码注释量较少

代码

# 本书链接https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.8_mlp
# 5.11 残差网络(ResNet)
# 注释:黄文俊
# E-mail:hurri_cane@qq.com


import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Residual(nn.Module):  # 本类已保存在d2lzh_pytorch包中方便以后使用
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return F.relu(Y + X)


blk = Residual(3, 3)
X = torch.rand((4, 3, 6, 6))
print(blk(X).shape)

# 我们也可以在增加输出通道数的同时减半输出的高和宽。
blk = Residual(3, 6, use_1x1conv=True, stride=2)
print(blk(X).shape)

print("*"*50)


# ResNet模型
net = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    # num_residuals:残差数
    if first_block:
        assert in_channels == out_channels # 第一个模块的通道数同输入通道数一致
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk)


net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
net.add_module("resnet_block2", resnet_block(64, 128, 2))
net.add_module("resnet_block3", resnet_block(128, 256, 2))
net.add_module("resnet_block4", resnet_block(256, 512, 2))

net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, 512, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(512, 10)))

X = torch.rand((1, 1, 224, 224))
for name, layer in net.named_children():
    X = layer(X)
    print(name, ' output shape:\t', X.shape)


batch_size = 512
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)





print("*"*50)

相关文章:

  • WSS3.0 和 MOSS2007 对移动设备浏览器的支持
  • 《动手学深度学习》(PyTorch版)代码注释 - 30 【DenseNet_Learning】
  • WSS3.0 和 MOSS2007 对移动设备浏览器的支持(续)
  • 《动手学深度学习》(PyTorch版)代码注释 - 31 【Language_model_data_set】
  • Windows Mobile 6 SDK 中的 Device Emulator 2.0
  • 《动手学深度学习》(PyTorch版)代码注释 - 32 【RNN_with_zero】
  • Windows Mobile 6 SDK 中的 Cellular Emulator
  • Windows Mobile 6 SDK 中的 GPS 工具
  • 《动手学深度学习》(PyTorch版)代码注释 - 33 【RNN_with_simple_way】
  • 《动手学深度学习》(PyTorch版)代码注释 - 34 【GRU_with_zero】
  • MEDC2007北京游记 - WindowsMobile Ophone
  • 《动手学深度学习》(PyTorch版)代码注释 - 35 【GRU_with_simple_way】
  • 祝贺CICI拿到VISA
  • 《动手学深度学习》(PyTorch版)代码注释 - 36 【LSTM_with_zero】
  • WPF/E去了,Silverlight来了
  • CentOS学习笔记 - 12. Nginx搭建Centos7.5远程repo
  • electron原来这么简单----打包你的react、VUE桌面应用程序
  • Java编程基础24——递归练习
  • overflow: hidden IE7无效
  • Rancher如何对接Ceph-RBD块存储
  • SQL 难点解决:记录的引用
  • STAR法则
  • webpack4 一点通
  • 计算机常识 - 收藏集 - 掘金
  • 数组的操作
  • 一个SAP顾问在美国的这些年
  • raise 与 raise ... from 的区别
  • ​如何使用ArcGIS Pro制作渐变河流效果
  • !!Dom4j 学习笔记
  • # MySQL server 层和存储引擎层是怎么交互数据的?
  • ###51单片机学习(2)-----如何通过C语言运用延时函数设计LED流水灯
  • #define
  • #QT(TCP网络编程-服务端)
  • (2)STM32单片机上位机
  • (DFS + 剪枝)【洛谷P1731】 [NOI1999] 生日蛋糕
  • (附源码)spring boot智能服药提醒app 毕业设计 102151
  • (十一)图像的罗伯特梯度锐化
  • (四) 虚拟摄像头vivi体验
  • (四)Android布局类型(线性布局LinearLayout)
  • (四)docker:为mysql和java jar运行环境创建同一网络,容器互联
  • (四)搭建容器云管理平台笔记—安装ETCD(不使用证书)
  • (一)SpringBoot3---尚硅谷总结
  • (转)C#开发微信门户及应用(1)--开始使用微信接口
  • (转)http-server应用
  • .bat批处理(八):各种形式的变量%0、%i、%%i、var、%var%、!var!的含义和区别
  • .NET 程序如何获取图片的宽高(框架自带多种方法的不同性能)
  • .Net开发笔记(二十)创建一个需要授权的第三方组件
  • .NET应用架构设计:原则、模式与实践 目录预览
  • [ vulhub漏洞复现篇 ] Django SQL注入漏洞复现 CVE-2021-35042
  • [AndroidStudio]_[初级]_[修改虚拟设备镜像文件的存放位置]
  • [Angular] 笔记 6:ngStyle
  • [C++][基础]1_变量、常量和基本类型
  • [CISCN2019 华东北赛区]Web2
  • [cocos creator]EditBox,editing-return事件,清空输入框
  • [C语言]——函数递归