当前位置: 首页 > news >正文

《动手学深度学习》(PyTorch版)代码注释 - 25 【NiN_Learning】

目录

  • 说明
  • 配置环境
  • 此节说明
  • 代码

说明

本博客代码来自开源项目:《动手学深度学习》(PyTorch版)
并且在博主学习的理解上对代码进行了大量注释,方便理解各个函数的原理和用途

配置环境

使用环境:python3.8
平台:Windows10
IDE:PyCharm

此节说明

此节对应书本上5.8节
此节功能为:网络中的网络(NiN)
由于次节相对容易理解,代码注释量较少

代码

# 本书链接https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.8_mlp
# 5.8 网络中的网络(NiN)
# 注释:黄文俊
# E-mail:hurri_cane@qq.com

import time
import torch
from torch import nn, optim

import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def nin_block(in_channels, out_channels, kernel_size, stride, padding):
    blk = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
                        nn.ReLU(),
                        nn.Conv2d(out_channels, out_channels, kernel_size=1),
                        nn.ReLU(),
                        nn.Conv2d(out_channels, out_channels, kernel_size=1),
                        nn.ReLU())
    return blk


# 已保存在d2lzh_pytorch
import torch.nn.functional as F
class GlobalAvgPool2d(nn.Module):
    # 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()
    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])

net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, stride=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(96, 256, kernel_size=5, stride=1, padding=2),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(256, 384, kernel_size=3, stride=1, padding=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Dropout(0.5),
    # 标签类别数是10
    nin_block(384, 10, kernel_size=3, stride=1, padding=1),
    GlobalAvgPool2d(),
    # 将四维的输出转成二维的输出,其形状为(批量大小, 10)
    d2l.FlattenLayer())

# 我们构建一个数据样本来查看每一层的输出形状。
X = torch.rand(1, 1, 224, 224)
for name, blk in net.named_children():
    X = blk(X)
    print(name, 'output shape: ', X.shape)

batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)

lr, num_epochs = 0.002, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)




print("*"*50)

相关文章:

  • SQL Server CE:没有足够的存储空间来完成该操作[CODE:8007000E]
  • 《动手学深度学习》(PyTorch版)代码注释 - 26 【GoogLeNet_Learning】
  • 提高.NET Compact Framework 1.0应用程序的窗体加载性能
  • 《动手学深度学习》(PyTorch版)代码注释 - 27 【Batch_normalization_with_zero】
  • 基于.NET Compact Framework的应用程序和库汇总
  • 《动手学深度学习》(PyTorch版)代码注释 - 28 【Batch_normalization_with_simple_way】
  • 《动手学深度学习》(PyTorch版)代码注释 - 29 【ResNet_Learning】
  • WSS3.0 和 MOSS2007 对移动设备浏览器的支持
  • 《动手学深度学习》(PyTorch版)代码注释 - 30 【DenseNet_Learning】
  • WSS3.0 和 MOSS2007 对移动设备浏览器的支持(续)
  • 《动手学深度学习》(PyTorch版)代码注释 - 31 【Language_model_data_set】
  • Windows Mobile 6 SDK 中的 Device Emulator 2.0
  • 《动手学深度学习》(PyTorch版)代码注释 - 32 【RNN_with_zero】
  • Windows Mobile 6 SDK 中的 Cellular Emulator
  • Windows Mobile 6 SDK 中的 GPS 工具
  • [译]Python中的类属性与实例属性的区别
  • 2017 前端面试准备 - 收藏集 - 掘金
  • crontab执行失败的多种原因
  • Elasticsearch 参考指南(升级前重新索引)
  • golang 发送GET和POST示例
  • JavaScript 是如何工作的:WebRTC 和对等网络的机制!
  • Linux后台研发超实用命令总结
  • Redis中的lru算法实现
  • ucore操作系统实验笔记 - 重新理解中断
  • Vue学习第二天
  • 第十八天-企业应用架构模式-基本模式
  • 深度学习入门:10门免费线上课程推荐
  • 算法系列——算法入门之递归分而治之思想的实现
  • 译有关态射的一切
  • 硬币翻转问题,区间操作
  • ​什么是bug?bug的源头在哪里?
  • #!/usr/bin/python与#!/usr/bin/env python的区别
  • #### go map 底层结构 ####
  • #我与Java虚拟机的故事#连载01:人在JVM,身不由己
  • ()、[]、{}、(())、[[]]等各种括号的使用
  • (C语言)编写程序将一个4×4的数组进行顺时针旋转90度后输出。
  • (C语言)输入自定义个数的整数,打印出最大值和最小值
  • (MIT博士)林达华老师-概率模型与计算机视觉”
  • (备忘)Java Map 遍历
  • (附源码)spring boot智能服药提醒app 毕业设计 102151
  • (附源码)ssm教材管理系统 毕业设计 011229
  • (接口自动化)Python3操作MySQL数据库
  • (四)鸿鹄云架构一服务注册中心
  • (一)使用IDEA创建Maven项目和Maven使用入门(配图详解)
  • (转)http协议
  • .bat批处理(七):PC端从手机内复制文件到本地
  • .net core使用ef 6
  • .Net Web窗口页属性
  • .NET4.0并行计算技术基础(1)
  • .net知识和学习方法系列(二十一)CLR-枚举
  • @private @protected @public
  • [Angular] 笔记 16:模板驱动表单 - 选择框与选项
  • [BZOJ5250][九省联考2018]秘密袭击(DP)
  • [ERROR ImagePull]: failed to pull image k8s.gcr.io/kube-controller-manager失败
  • [leveldb] 2.open操作介绍