当前位置: 首页 > news >正文

神经网络识别数字图像案例

学习资料:从零设计并训练一个神经网络,你就能真正理解它了_哔哩哔哩_bilibili

这个视频讲得相当清楚。本文是学习笔记,不是原创,图都是从视频上截图的。

1. 神经网络

2. 案例说明

具体来说,设计一个三层的神经网络。以数字图像作为输入,经过神经网络的计算,识别出图像中的数字是几,从而实现数字图像的分类。

3. 视频讲解内容的提纲

4. 神经网络的设计和实现

我们要处理的数据是28*28像素的灰色通道图像。

这样的灰色图像包括了28*28=784个数据点。需要先将他展平为1*784大小的向量。然后将这个向量输入到神经网络中。

用一个三层神经网络处理图片对应的向量X。输入成需要接收784维的图片向量X。X里面每个维度的数据都有一个神经元来接收。因此输入层要包含784个神经元。

隐藏成用于特征提取特征向量,将输入的特征向量处理成更高级的特征向量。

因为手写数字图像识别并不复杂,所以将隐藏层的神经元个数设置为256。这样,输入层和隐藏层之间就会有个784*256的线性层。它可以将一个784维的输入向量转换为256维的输出向量。

该输出向量会继续向前传播到达输出层。

由于最终要将数字图像识别为0~9,十种可能的数字。因此,输出层需要定义10个神经元,对应这十种数字。

256维的向量在经过隐藏层和输出层之间的线性层计算后,就得到了10维的输出结果。这个10维的向量就代表了10个数字的预测得分。

为了继续得到输出层的预测概率,还要将输出层的输出输入到softmax层。softmax层会将10维的向量转换为10个概率值p0~p9。p0~p9相加的总和等于1.

5. 神经网络的Pytorch实现

import torch
from torch import nn# 定义神经网络Network
class Network(nn.Module):def __init__(self):super().__init__()# 线性层1,输入层和隐藏层之间的线性层self.layer1 = nn.Linear(784, 258)# 线性层2,隐藏层和输出层之间的线性层self.layer2 = nn.Linear(256, 10)# 在前向传播,forward函数中,输入为图像xdef forward(self, x):x = x.view(-1, 28 * 28) # 使用view函数,将x展平x = self.layer1(x) # 将x输入到layer1x = torch.relu(x) # 使用relu激活return self.layer2(x) # 输入至layer2计算结果# 这里没有直接定义softmax层,因为后面会使用CrossEntropyLoss损失函数# 在这个损失函数中,会实现softmax的计算

6. 训练数据的准备和处理

from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 初学只要知道大致的数据处理流程即可
if __name__ == '__main__'# 实现图像的预处理pipelinetransform = trnasforms.Compose([# 转换成单通道灰度图transforms.Grayscale(num_output_channels=1),# 转换为张量transforms.ToTensor()])# 使用ImageFolder函数,读取数据文件夹,构建数据集dataset# 这个函数会将保持数据的文件夹的名字,作为数据的标签,组织数据train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)# 打印他们的长度print("train_dataset length: ", len(train_dataset))print("test_dataset length: ", len(test_dataset))# 使用train_loader, 实现小批量的数据读取# 这里设置小批量的大小,batch_size=64. 也就是每个批次,包括64个数据train_loader = DataLoader(train_datase, batch_size=64, shuffle=True)# 打印train_loader的长度print("train_loader length: ", len(train_loader))# 6000个训练数据,如果每个小批量,读入64个样本,那么60000个数据会被分成938组# 938*64=60032,说明最后一组不够64个数据# 循环遍历train_loader# 每一次循环,都会取出64个图像数据,作为一个小批量batchfor batch_idx, (data, label) in enumerate(train_loader)if batch_idx == 3:breakprint("batch_idx: ", batch_idx)print("data.shape: ", data.shape) # 数据的尺寸print("label: ", label.shape) # 图像中的数字print(label)

7. 模型的训练和测试

import torch
from torch import nn
from torch import optim
from model import Network
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoaderif __name__ == '__main__'# 图像的预处理transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# 读入并构造数据集train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)print("train_dataset length: ", len(train_dataset))# 小批量的数据读入train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)print("train_loader length: ", len(train_loader))# 在使用Pytorch训练模型时,需要创建三个对象:model = Network() # 1.模型本身,就是我们设计的神经网络optimizer = optim.Adam(model.parameters()) #2.优化器,优化模型中的参数criterion = nn.CrossEntropyLoss() #3.损失函数,分类问题,使用交叉熵损失误差# 进入模型的循环迭代# 外层循环,代表了整个训练数据集的遍历次数for epoch in range(10):# 内层循环使用train_loader, 进行小批量的数据读取for batch_idx, (data, label) in enumerate(train_loader):# 内层每循环一次,就会进行一次梯度下降算法# 包括了5个步骤# 这5个步骤是使用pytorch框架训练模型的定式,初学时先记住即可# 1. 计算神经网络的前向传播结果output = model(data)# 2. 计算output和标签label之间的损失lossloss = criterion(output, label)# 3. 使用backward计算梯度loss.backward()# 4. 使用optimizer.step更新参数optimizer.step()# 5.将梯度清零optimizer.zero_grad()if batch_idx % 100 == 0:print(f"Epoch {epoch + 1}/10"f"| Batch {batch_idx}/{len(train_loader)}"f"| Loss: {loss.item():.4f}")torch.save(model.state_dict(), 'mnist.pth')

from model import Network
from torchvision import transforms
from torchvision import datasets
import torchif __name__ == '__main__'transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# 读取测试数据集test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)print("test_dataset length: ", len(test_dataset))model = Network() # 定义神经网络模型model.load_state_dict(torch.load('mnist.pth')) # 加载刚刚训练好的模型文件rigth = 0 # 保存正确识别的数量for i, (x, y) in enumerate(test_dataset):output = model(x) # 将其中的数据x输入到模型predict = output.argmax(1).item() # 选择概率最大标签的作为预测结果# 对比预测值predict和真实标签yif predict == y:right += 1else:# 将识别错误的样例打印出来img_path = test_dataset.samples[i][0]print(f"wrong case: predict = {predict} y = {y} img_path = {img_path}")# 计算出测试效果sample_num = len(test_dataset)acc = right * 1.0 / sample_numprint("test accuracy = %d / %d = %.3lf" % (right, sample_num, acc))

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 昇思训练营打卡第二十四天(LSTM+CRF序列标注)
  • uniapp 小程序注册全局弹窗组件(无需引入,无需写标签)
  • 缓存与分布式锁
  • T113-i 倒车低概率性无反应,没有进入倒车视频界面
  • Spring-Cache 缓存
  • Zookeeper背景优缺点,以及应用场景
  • 头歌资源库(32)n皇后问题
  • 【坑】微信小程序开发wx.uploadFile和wx.request的返回值格式不同
  • 如何找工作 校招 | 社招 | 秋招 | 春招 | 提前批
  • Docker Compose部署Kafka集群并在宿主机Windows连接开发
  • 对AAC解码的理解
  • Linux C++ 054-设计模式之外观模式
  • leetcode日记(38)字母异位词分组
  • C++数组
  • 【密码学】消息认证
  • 30秒的PHP代码片段(1)数组 - Array
  • iOS帅气加载动画、通知视图、红包助手、引导页、导航栏、朋友圈、小游戏等效果源码...
  • JavaScript对象详解
  • java正则表式的使用
  • Java知识点总结(JavaIO-打印流)
  • js算法-归并排序(merge_sort)
  • PHP面试之三:MySQL数据库
  • Python十分钟制作属于你自己的个性logo
  • Vue小说阅读器(仿追书神器)
  • Vue组件定义
  • 关于 Cirru Editor 存储格式
  • 函数式编程与面向对象编程[4]:Scala的类型关联Type Alias
  • 跨域
  • 聊聊flink的TableFactory
  • 爬虫进阶 -- 神级程序员:让你的爬虫就像人类的用户行为!
  • 前端设计模式
  • 思否第一天
  • 通过npm或yarn自动生成vue组件
  • 微信小程序上拉加载:onReachBottom详解+设置触发距离
  • 为什么要用IPython/Jupyter?
  • 再次简单明了总结flex布局,一看就懂...
  • gunicorn工作原理
  • ​flutter 代码混淆
  • ​埃文科技受邀出席2024 “数据要素×”生态大会​
  • ​软考-高级-系统架构设计师教程(清华第2版)【第1章-绪论-思维导图】​
  • ​十个常见的 Python 脚本 (详细介绍 + 代码举例)
  • ​一些不规范的GTID使用场景
  • #stm32整理(一)flash读写
  • #如何使用 Qt 5.6 在 Android 上启用 NFC
  • #我与虚拟机的故事#连载20:周志明虚拟机第 3 版:到底值不值得买?
  • (04)odoo视图操作
  • (145)光线追踪距离场柔和阴影
  • (Bean工厂的后处理器入门)学习Spring的第七天
  • (bean配置类的注解开发)学习Spring的第十三天
  • (C语言)二分查找 超详细
  • (二)测试工具
  • (二)丶RabbitMQ的六大核心
  • (接上一篇)前端弄一个变量实现点击次数在前端页面实时更新
  • (六)Hibernate的二级缓存
  • (收藏)Git和Repo扫盲——如何取得Android源代码