当前位置: 首页 > news >正文

【深度学习】卷积神经网络的架构参考 以MNIST数据集为例(未调节架构)测试集正确率: 98.66%

文章目录

  • 下面是实现的代码
  • 1. 读取数据
  • 2. 卷积神经网络的构建
      • 语法解惑
        • 关于类的初始化参数
        • 关于Sequential
  • 3. 准确率作为评估的标准
    • troch.max()
  • 4. 开始训练网络模型
      • 开启训练的思路
  • 写在最后


本文通过总结网上课程的构建深度学习(卷积)神经网络的主要架构,为大家写CNN提供参考,并在一些代码中间加注了一些解释方法(供大家参考学习)

希望能帮到你 😃 ~


下面是实现的代码

导入数据包

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
%matplotlib
Using matplotlib backend: TkAgg

1. 读取数据

  • 分别构件训练集和测试集
  • DataLoader来迭代数据
# 定义超参数
input_size = 28  # 图像尺寸是 28 * 28
num_classes = 10 # 标签的种类数
num_epochs = 3   # 图像训练的总时长
batch_size = 64  # 一个批次送进出图像数量的大小,这里一次性送进去64张图片

# 训练集
train_dataset = datasets.MNIST(root = './data',
                            train = True,
                           transform = transforms.ToTensor(),
                           download = True)

# 测试集
test_dataset = datasets.MNIST(root = './data',
                             train = False,
                             transform = transforms.ToTensor())

# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                          batch_size = batch_size,
                                          shuffle = True)

test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                         batch_size = batch_size,
                                         shuffle = True)

在构件完成batch之后,我们只需要在batch中一个一个去取数据就行了

DataLoader:

  • dataset(Dataset): 传入的数据集

  • batch_size(int, optional): 每个batch有多少个样本

  • shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序


2. 卷积神经网络的构建

  • 一般是卷积层, relu层, 池化层可以变成一个小单元
  • 在网络的最后,我们应该加上一个特征图,转化为分类或者回归任务
  • 如果卷积希望输出的结果的size一样,则需要设置 padding = (kernel_size - 1) / 2(srtide = 1 的情况)
  • Conv2d不管你每一层输出的特征图大小的,直接照单全收,但是输入的channel数要搞好(相当于只管好自己的卷积核输入就行)

注意def forward 要与init的缩进对齐!

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__() # 子类调用父类的方法使用父类进行初始化
        
        # 第一个卷积单元
        self.conv1 = nn.Sequential( # 输入大小(1, 28, 28) 第一个卷积模块
            nn.Conv2d(
                in_channels = 1,     # 灰度图
                out_channels = 16,  # 得到特征图的个数
                kernel_size = 5,    # 卷积核的大小
                stride = 1,         # 步长
                padding = 2,        # 输出的特征图为 (28, 28, 16)
            ), # 输出结构为
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2), # 2 * 2 大小的矩阵做maxpool
            # 输出(14, 14, 16)
        )
        
        # 第二个卷积单元
        # 上一个模块的channel out 为下一个模块的channel in
        self.conv2 = nn.Sequential(     # 输入(16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2), # 对应的参数上面的一个已经写了
            # 输出大小 (32,14,14)
            nn.ReLU(),
            nn.MaxPool2d(2), # 输出大小 (32,7,7)
        )
        
        # 将输出的向量拉直
        # 输入神经元个数, 输出神经元个数,是否包含偏置值
        self.out = nn.Linear(32 * 7 * 7,  10)
        
    # 定义一个前推 规则方法
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # 将所有的结果转为向量的形式
        x = x.view(x.size(0), -1) # flatten 操作,结果为: (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output

语法解惑

关于类的初始化参数

self:通过这个类实例化后的对象本身,你对这个对象有什么要求

  1. 在定义类中的方法时,不能省略self 如 eat(self)
  2. def init(self, name): 方法相当于构建函数,会自动执行,如果在初始化类的时候没有传入足够的参数,函数将会报错
  3. super语句:子类把父类的__init__()放到自己的__init__()当中,这样一来,子类就有了父类的东西
  4. (续上)也就是我们定义的类CNN拥有了传入参数模型 nn.Module中所定义的东西
  5. (续上)如果初始化的逻辑与父类的不同,不使用父类的方法,自己重新初始化也是可以的

奋斗の博客_解惑(一) ----- super(XXX, self).init()到底是代表什么含义
这篇文章真的写的很好,大家可以去看一下~

关于Sequential

pytorch系列7 -----nn.Sequential讲解_墨氲一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行


3. 准确率作为评估的标准

注解:

troch.max()

该函数对输出的结果进行分类
output = torch.max(input, dim)

  • 输入
    input是softmax函数输出的一个tensor
    dim是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值
  • 输出
    函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引。

https://www.jianshu.com/p/3ed11362b54f

def accuracy(predictions, labels):
    # prediction的形状(batch_size * 10)
    # 按照行优先,每10个10个地取出来该行内的最大值(下标)
    # 下标对应的恰好就是0-9这10个数值
    pred = torch.max(predictions.data, 1)[1]
    
    # 与原来的结果进行比较
    # 对了的就是1,不对就是0,然后求和
    rights = pred.eq(labels.data.view_as(pred)).sum()
    
    return rights, len(labels)

4. 开始训练网络模型

DataLoader使用enumerate来遍历,逐个取batch

开启训练的思路

  • 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
  • optimizer.zero_grad() 清空过往梯度;
  • loss.backward() 反向传播,计算当前梯度;
  • optimizer.step() 根据梯度更新网络参数
# 实例化
net = CNN()
# 损失函数
criterion = nn.CrossEntropyLoss()
# 优化器
optimizer = optim.Adam(net.parameters(), lr = 0.001) # 定义优化器

# 开始训练循环
for epoch in range(num_epochs):
    # 当前的epoch的结果保存下来
    # 每一轮循环开始前先清零
    train_rights = []
    
    for batch_idx, (data, target) in enumerate(train_loader):
        # 针对容器中的每一个batch进行循环
        net.train()
        output = net(data)
        loss = criterion(output,  target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        right = accuracy(output, target)
        train_rights.append(right)
        
        if batch_idx % 100 == 0:
            
            net.eval()
            val_rights = []
            
            for(data, target) in test_loader:
                output = net(data)
                right = accuracy(output, target)
                val_rights.append(right)
                
            # 计算准确率
            train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
            val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))
            
            print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(
                    epoch, batch_idx * batch_size, len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.data,
                    100. * train_r[0].numpy() / train_r[1],
                    100. * val_r[0].numpy() / val_r[1]))
当前epoch: 0 [0/60000 (0%)]	损失: 2.295074	训练集准确率: 15.62%	测试集正确率: 12.67%
当前epoch: 0 [6400/60000 (11%)]	损失: 0.286456	训练集准确率: 77.10%	测试集正确率: 92.20%
当前epoch: 0 [12800/60000 (21%)]	损失: 0.252975	训练集准确率: 85.43%	测试集正确率: 94.98%
当前epoch: 0 [19200/60000 (32%)]	损失: 0.117255	训练集准确率: 88.77%	测试集正确率: 96.69%
当前epoch: 0 [25600/60000 (43%)]	损失: 0.093021	训练集准确率: 90.83%	测试集正确率: 97.19%
当前epoch: 0 [32000/60000 (53%)]	损失: 0.171856	训练集准确率: 92.10%	测试集正确率: 96.74%
当前epoch: 0 [38400/60000 (64%)]	损失: 0.024938	训练集准确率: 92.94%	测试集正确率: 97.28%
当前epoch: 0 [44800/60000 (75%)]	损失: 0.045772	训练集准确率: 93.60%	测试集正确率: 98.13%
当前epoch: 0 [51200/60000 (85%)]	损失: 0.036130	训练集准确率: 94.11%	测试集正确率: 98.08%
当前epoch: 0 [57600/60000 (96%)]	损失: 0.019689	训练集准确率: 94.52%	测试集正确率: 98.32%
当前epoch: 1 [0/60000 (0%)]	损失: 0.014988	训练集准确率: 100.00%	测试集正确率: 98.50%
当前epoch: 1 [6400/60000 (11%)]	损失: 0.039536	训练集准确率: 97.94%	测试集正确率: 98.42%
当前epoch: 1 [12800/60000 (21%)]	损失: 0.053510	训练集准确率: 97.99%	测试集正确率: 98.42%
当前epoch: 1 [19200/60000 (32%)]	损失: 0.168834	训练集准确率: 98.18%	测试集正确率: 98.56%
当前epoch: 1 [25600/60000 (43%)]	损失: 0.084785	训练集准确率: 98.16%	测试集正确率: 98.46%
当前epoch: 1 [32000/60000 (53%)]	损失: 0.036668	训练集准确率: 98.21%	测试集正确率: 98.43%
当前epoch: 1 [38400/60000 (64%)]	损失: 0.019413	训练集准确率: 98.23%	测试集正确率: 98.57%
当前epoch: 1 [44800/60000 (75%)]	损失: 0.063640	训练集准确率: 98.25%	测试集正确率: 98.57%
当前epoch: 1 [51200/60000 (85%)]	损失: 0.022920	训练集准确率: 98.28%	测试集正确率: 98.48%
当前epoch: 1 [57600/60000 (96%)]	损失: 0.019340	训练集准确率: 98.31%	测试集正确率: 98.80%
当前epoch: 2 [0/60000 (0%)]	损失: 0.050389	训练集准确率: 98.44%	测试集正确率: 98.67%
当前epoch: 2 [6400/60000 (11%)]	损失: 0.007058	训练集准确率: 98.90%	测试集正确率: 98.73%
当前epoch: 2 [12800/60000 (21%)]	损失: 0.054998	训练集准确率: 98.76%	测试集正确率: 98.74%
当前epoch: 2 [19200/60000 (32%)]	损失: 0.050147	训练集准确率: 98.79%	测试集正确率: 98.70%
当前epoch: 2 [25600/60000 (43%)]	损失: 0.012523	训练集准确率: 98.77%	测试集正确率: 98.83%
当前epoch: 2 [32000/60000 (53%)]	损失: 0.046355	训练集准确率: 98.76%	测试集正确率: 98.71%
当前epoch: 2 [38400/60000 (64%)]	损失: 0.061045	训练集准确率: 98.78%	测试集正确率: 98.78%
当前epoch: 2 [44800/60000 (75%)]	损失: 0.060261	训练集准确率: 98.79%	测试集正确率: 98.71%
当前epoch: 2 [51200/60000 (85%)]	损失: 0.036363	训练集准确率: 98.83%	测试集正确率: 98.66%
当前epoch: 2 [57600/60000 (96%)]	损失: 0.041269	训练集准确率: 98.81%	测试集正确率: 98.66%

共训练了2个epoch,达到了99.66%的准确率,说明在图像分类的任务里面,小卷积核的CNN有不错的表现效果~


写在最后

各位看官,都看到这里了,麻烦动动手指头给博主来个点赞8,您的支持作者最大的创作动力哟!
本文代码参考课程:五大深度神经网络基础 Lesson14
才疏学浅,若有纰漏,恳请斧正
本文章仅用于各位作为学习交流之用,不作任何商业用途,若涉及版权问题请速与作者联系,望悉知

相关文章:

  • C++ 哈希桶模拟实现(补充)
  • Rethinking the Inception Architecture for Computer Vision--Christian Szegedy
  • 安卓毕业设计成品基于Uniapp+SSM实现的智能课堂管理APP在线学习网
  • 基于metaRTC嵌入式webrtc的H265网页播放器实现(我与metaRTC的缘分)
  • 【设计模式】Java设计模式 - 组合模式
  • Android之Handler(上)
  • 网络协议:网络安全
  • php防止SQL注入的网上二手交易平台的设计与实现毕业设计-附源码241552
  • 美团笔试题目(Java后端5题2小时)
  • HTML期末大学生网页设计作业——奇恩动漫HTML (1页面) HTML CSS JS网页设计期末课程大作业
  • 浅谈如何学习网络编程
  • 【MYSQL】表的增删改查
  • 中国地板工具租赁服务行业竞争态势与经营效益预测报告2022-2028年
  • 查看docker 容器的端口
  • xubuntu16.04系统中隐藏网络连接的弹窗提示
  • 分享的文章《人生如棋》
  • 「前端早读君006」移动开发必备:那些玩转H5的小技巧
  • 2018以太坊智能合约编程语言solidity的最佳IDEs
  • Hexo+码云+git快速搭建免费的静态Blog
  • Java-详解HashMap
  • js中的正则表达式入门
  • k8s如何管理Pod
  • MySQL数据库运维之数据恢复
  • opencv python Meanshift 和 Camshift
  • Transformer-XL: Unleashing the Potential of Attention Models
  • TypeScript实现数据结构(一)栈,队列,链表
  • Vue 动态创建 component
  • 从0到1:PostCSS 插件开发最佳实践
  • 对话 CTO〡听神策数据 CTO 曹犟描绘数据分析行业的无限可能
  • 基于MaxCompute打造轻盈的人人车移动端数据平台
  • 理清楚Vue的结构
  • 世界编程语言排行榜2008年06月(ActionScript 挺进20强)
  • 一文看透浏览器架构
  • 第二十章:异步和文件I/O.(二十三)
  • ​低代码平台的核心价值与优势
  • #pragma预处理命令
  • #Z2294. 打印树的直径
  • $L^p$ 调和函数恒为零
  • (react踩过的坑)Antd Select(设置了labelInValue)在FormItem中initialValue的问题
  • (Redis使用系列) Springboot 实现Redis消息的订阅与分布 四
  • (附源码)计算机毕业设计SSM疫情居家隔离服务系统
  • (免费领源码)python#django#mysql公交线路查询系统85021- 计算机毕业设计项目选题推荐
  • (十八)SpringBoot之发送QQ邮件
  • .apk文件,IIS不支持下载解决
  • .bat批处理(十):从路径字符串中截取盘符、文件名、后缀名等信息
  • .NET 使用 XPath 来读写 XML 文件
  • .one4-V-XXXXXXXX勒索病毒数据怎么处理|数据解密恢复
  • [ vulhub漏洞复现篇 ] JBOSS AS 5.x/6.x反序列化远程代码执行漏洞CVE-2017-12149
  • [ 攻防演练演示篇 ] 利用通达OA 文件上传漏洞上传webshell获取主机权限
  • [bzoj 3124][sdoi 2013 省选] 直径
  • [BZOJ1040][P2607][ZJOI2008]骑士[树形DP+基环树]
  • [C#小技巧]如何捕捉上升沿和下降沿
  • [C++] cout、wcout无法正常输出中文字符问题的深入调查(1):各种编译器测试
  • [C++数据结构](22)哈希表与unordered_set,unordered_map实现
  • [codeforces]Levko and Permutation