当前位置: 首页 > news >正文

《动手学深度学习》(PyTorch版)代码注释 - 26 【GoogLeNet_Learning】

目录

  • 说明
  • 配置环境
  • 此节说明
  • 代码

说明

本博客代码来自开源项目:《动手学深度学习》(PyTorch版)
并且在博主学习的理解上对代码进行了大量注释,方便理解各个函数的原理和用途

配置环境

使用环境:python3.8
平台:Windows10
IDE:PyCharm

此节说明

此节对应书本上5.9节
此节功能为:含并行连结的网络(GoogLeNet)
由于次节相对容易理解,代码注释量较少

代码

# 本书链接https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.8_mlp
# 5.9 含并行连结的网络(GoogLeNet)
# 注释:黄文俊
# E-mail:hurri_cane@qq.com



import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Inception(nn.Module):
    # c1 - c4为每条线路里的层的输出通道数
    def __init__(self, in_c, c1, c2, c3, c4):
        super(Inception, self).__init__()
        # 线路1,单1 x 1卷积层
        self.p1_1 = nn.Conv2d(in_c, c1, kernel_size=1)
        # 线路2,1 x 1卷积层后接3 x 3卷积层
        self.p2_1 = nn.Conv2d(in_c, c2[0], kernel_size=1)
        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
        # 线路3,1 x 1卷积层后接5 x 5卷积层
        self.p3_1 = nn.Conv2d(in_c, c3[0], kernel_size=1)
        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
        # 线路4,3 x 3最大池化层后接1 x 1卷积层
        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.p4_2 = nn.Conv2d(in_c, c4, kernel_size=1)

    def forward(self, x):
        p1 = F.relu(self.p1_1(x))
        p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))
        p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))
        p4 = F.relu(self.p4_2(self.p4_1(x)))
        return torch.cat((p1, p2, p3, p4), dim=1)  # 在通道维上连结输出

# 第一模块使用一个64通道的7×77×7卷积层。
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

# 第二模块使用2个卷积层:首先是64通道的1×11×1卷积层,然后是将通道增大3倍的3×33×3卷积层。
b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),
                   nn.Conv2d(64, 192, kernel_size=3, padding=1),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),
                   Inception(256, 128, (128, 192), (32, 96), 64),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),
                   Inception(512, 160, (112, 224), (24, 64), 64),
                   Inception(512, 128, (128, 256), (24, 64), 64),
                   Inception(512, 112, (144, 288), (32, 64), 64),
                   Inception(528, 256, (160, 320), (32, 128), 128),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),
                   Inception(832, 384, (192, 384), (48, 128), 128),
                   d2l.GlobalAvgPool2d())

net = nn.Sequential(b1, b2, b3, b4, b5, d2l.FlattenLayer(), nn.Linear(1024, 10))
X = torch.rand(1, 1, 96, 96)
for blk in net.children():
    X = blk(X)
    print('output shape: ', X.shape)



batch_size = 512
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)




print("*"*50)

相关文章:

  • 提高.NET Compact Framework 1.0应用程序的窗体加载性能
  • 《动手学深度学习》(PyTorch版)代码注释 - 27 【Batch_normalization_with_zero】
  • 基于.NET Compact Framework的应用程序和库汇总
  • 《动手学深度学习》(PyTorch版)代码注释 - 28 【Batch_normalization_with_simple_way】
  • 《动手学深度学习》(PyTorch版)代码注释 - 29 【ResNet_Learning】
  • WSS3.0 和 MOSS2007 对移动设备浏览器的支持
  • 《动手学深度学习》(PyTorch版)代码注释 - 30 【DenseNet_Learning】
  • WSS3.0 和 MOSS2007 对移动设备浏览器的支持(续)
  • 《动手学深度学习》(PyTorch版)代码注释 - 31 【Language_model_data_set】
  • Windows Mobile 6 SDK 中的 Device Emulator 2.0
  • 《动手学深度学习》(PyTorch版)代码注释 - 32 【RNN_with_zero】
  • Windows Mobile 6 SDK 中的 Cellular Emulator
  • Windows Mobile 6 SDK 中的 GPS 工具
  • 《动手学深度学习》(PyTorch版)代码注释 - 33 【RNN_with_simple_way】
  • 《动手学深度学习》(PyTorch版)代码注释 - 34 【GRU_with_zero】
  • “Material Design”设计规范在 ComponentOne For WinForm 的全新尝试!
  • 《Javascript数据结构和算法》笔记-「字典和散列表」
  • Android 架构优化~MVP 架构改造
  • Angular 4.x 动态创建组件
  • Angularjs之国际化
  • bootstrap创建登录注册页面
  • Git初体验
  • hadoop入门学习教程--DKHadoop完整安装步骤
  • Laravel 菜鸟晋级之路
  • miniui datagrid 的客户端分页解决方案 - CS结合
  • ucore操作系统实验笔记 - 重新理解中断
  • 仿天猫超市收藏抛物线动画工具库
  • 分享一份非常强势的Android面试题
  • 干货 | 以太坊Mist负责人教你建立无服务器应用
  • 简析gRPC client 连接管理
  • 盘点那些不知名却常用的 Git 操作
  • 吴恩达Deep Learning课程练习题参考答案——R语言版
  • 验证码识别技术——15分钟带你突破各种复杂不定长验证码
  • 源码之下无秘密 ── 做最好的 Netty 源码分析教程
  • 白色的风信子
  • [Shell 脚本] 备份网站文件至OSS服务(纯shell脚本无sdk) ...
  • Java总结 - String - 这篇请使劲喷我
  • ​软考-高级-系统架构设计师教程(清华第2版)【第9章 软件可靠性基础知识(P320~344)-思维导图】​
  • ​学习一下,什么是预包装食品?​
  • # 手柄编程_北通阿修罗3动手评:一款兼具功能、操控性的电竞手柄
  • #宝哥教你#查看jquery绑定的事件函数
  • #我与Java虚拟机的故事#连载14:挑战高薪面试必看
  • $(function(){})与(function($){....})(jQuery)的区别
  • $jQuery 重写Alert样式方法
  • (C语言)输入自定义个数的整数,打印出最大值和最小值
  • (delphi11最新学习资料) Object Pascal 学习笔记---第8章第2节(共同的基类)
  • (超简单)构建高可用网络应用:使用Nginx进行负载均衡与健康检查
  • (二)linux使用docker容器运行mysql
  • (附源码)springboot掌上博客系统 毕业设计063131
  • (免费领源码)Python#MySQL图书馆管理系统071718-计算机毕业设计项目选题推荐
  • (十)【Jmeter】线程(Threads(Users))之jp@gc - Stepping Thread Group (deprecated)
  • (原創) 如何解决make kernel时『clock skew detected』的warning? (OS) (Linux)
  • (转)可以带来幸福的一本书
  • (转)如何上传第三方jar包至Maven私服让maven项目可以使用第三方jar包
  • .NET CF命令行调试器MDbg入门(四) Attaching to Processes