当前位置: 首页 > news >正文

《深度学习》PyTorch 手写数字识别 案例解析及实现 <上>

目录

一、了解MINIST数据集

1、什么是MINIST

2、查看MINIST由来

二、实操代码

1、下载训练数据集

2、下载测试数据集

运行结果:

3、展示手写数字图片

运行结果:

4、打包图片

运行结果:

5、判断当前pytorch使用的设备

1)torch.cuda.is_available()

2)torch.backends.mps.is_available()

3)MPS

运行结果:


一、了解MINIST数据集

1、什么是MINIST

        MINIST是一种基于神经网络的手写数字识别算法。它是LeCun等人在1998年提出的,是深度学习领域的里程碑之一。MINIST数据集包含了大量的手写数字图片,MINIST算法通过训练神经网络,可以有效地识别这些手写数字。MINIST算法在计算机视觉和模式识别中有广泛的应用,被认为是机器学习领域的经典问题之一。

        MNIST包含70,000张手写数字图像,其中60,000张用于训练,10,000张用于测试

        所有的图像都是灰度的,大小为28x28像素的,并且居中的,以减少预处理和加快运行。

2、查看MINIST由来

        进入下列网页,即可查看

https://yann.lecun.com/exdb/mnist/icon-default.png?t=O83Ahttps://yann.lecun.com/exdb/mnist/        打卡即可得到下列画面:

此时可知道这个MINIST数据集中训练集和测试集所占大小等等:

二、实操代码

1、下载训练数据集

训练数据集包含训练用的手写数字图片及其对应的标签

import torch 
print(torch.__version__)   # 查看torch版本号from torch import nn  # 导入神经网络模块,提供了构建网络所需的各种层
from torch.utils.data import DataLoader  # 数据包管理工具,打包数据,它可以将数据集封装成适合批处理的数据加载器。
from torchvision import datasets   # 封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor   一个数据转换操作,用于将图片数据转换为PyTorch张量(Tensor)。PyTorch中的模型只能接受张量作为输入。training_data = datasets.MNIST(   # 跳转到函数的内部源代码,pycharm 按下ctrl +鼠标点击root='data',  # 指定数据集下载后储存的根目录train=True,  # 表示下载的是训练集,如需下载测试集则更改为False即可download=True,   # 表示如果本地没有数据集,则自动下载,有则不再下载transform=ToTensor()   # 指定一个数据转换操作,即将下载的图片转换为pytorch张量tensor,因为pytorch模型只能处理张量类型的数据
)

        将代码和下列测试集代码一起运行。

2、下载测试数据集

        只需将训练数据集中的train参数结果更改为False

test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor()  
)  # NumPy 数组只能在CPU上运行。Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度。print(len(training_data))   # 打印训练集数据条数
print(len(test_data))   # 打印测试集数据条数
运行结果:

3、展示手写数字图片

from matplotlib import pyplot as plt   # 导入绘图库
figure = plt.figure()   # 设置一个空白画布
for i in range(9):img,label = training_data[i+59000]   # 提取第59000张图片开始,共9张,返回图片及其对应的标签值figure.add_subplot(3,3,i+1)   # 在画布创建3行3列的小窗口,通过遍历的值i来确定每个画布展示的图片plt.title(label)   # 设置每个窗口的标题,设置标签为上述返回的标签值plt.axis('off')   # 取消画布中的坐标轴的图像plt.imshow(img.squeeze(),cmap='gray')   # plt.imshow()将NumPy数组data中的数据显示为图像,并在图形窗口中,a = img.squeeze()   # img.squeeze()从张量img中去掉维度为1的。如果该维度的大小不为1,则张量不会改变。
plt.show()

        最后一步img.squeeze降低维度是因为遍历出来的图像有一个冗余的维度没有用,如下所示,维度为1,图像大小为28x28像素的。

运行结果:

4、打包图片

train_dataloader = DataLoader(training_data,batch_size=64)  # 调用上述定义的DataLoader打包库,将训练集的图片和标签,64张图片为一个包,
test_dataloader = DataLoader(test_data,batch_size=64)   # 将测试集的图片和标签,每64张打包成一份
for x,y in test_dataloader:# x是表示打包好的每一个数据包,其形状为[64,1,28,28],64表示批次大小,1表示通道数为1,即灰度图,28表示图像的宽高像素值# y表示每个图片标签print(f"shape of x[N,C,H,W]:{x.shape}")   # 打印图片形状print(f"shape of y:{y.shape}{y.dtype}")   # 打印标签的形状和数据类型break  # 跳出并终止循环,表示只遍历一个包的数据情况
运行结果:

5、判断当前pytorch使用的设备

"""判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""  # 返回cuda,mps,cpu,
device = "cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")  # 字符串的格式化。CUDA驱动软件的功能:pytorch能够去执行cuda的命令,cuda通过GPU指令集
# 神经网络的模型也需要传入到GPU,1个batchsize的数据集也需要传入到GPU,才可以进行训练。
        1)torch.cuda.is_available()

                 检查CUDA是否在当前系统上可用。CUDA是NVIDIA的并行计算平台和编程模型,它允许软件利用NVIDIA图形处理单元(GPU)进行加速计算。如果CUDA可用,这意味着你的系统有NVIDIA GPU,并且PyTorch已经配置为可以使用CUDA。

        2)torch.backends.mps.is_available()

                检查MPS是否可用。请注意,这个检查通常只在Apple Silicon Macs上返回True

        3)MPS

                MPS是Apple提供的一套高性能图形和计算框架,专门设计用于Apple Silicon Macs上的Metal API。虽然MPS不直接对应于PyTorch的CUDA,但PyTorch从1.8版本开始增加了对Apple Silicon Macs的支持,通过MPS后端进行加速。

        表示如果torch.cuda.is_available()返回的是True则返回cuda,即当前使用的设备是cuda,如果返回False即执行下面的判断语句,即如果torch.backends.mps.is_available()返回的是True则返回mps,即当前使用的是苹果设备的mps,反之则使用的是cpu设备来计算。

运行结果:

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 浙大数据结构:02-线性结构3 Reversing Linked List
  • RFM模型
  • 数字证书学习
  • Docker 部署 Seata (图文并茂超详细)
  • Python数据处理利器,pivot与melt让表格变得灵活
  • Java架构师未来篇大模型
  • c++ 链表详细介绍
  • C++vector类 (带你一篇文章搞定C++中的vector类)
  • 区块链审计 如何测试solidity的bool值占用几个字节
  • 基于SpringBoot+Vue+MySQL的画师约稿平台系统
  • 【Unity-Lua】音乐播放器循环滚动播放音乐名
  • 【微服务】Ribbon(负载均衡,服务调用)+ OpenFeign(服务发现,远程调用)【详解】
  • 【Kubernetes】常见面试题汇总(二)
  • JVM: JDK内置命令 - JPS
  • 微信小程序-formData使用
  • 《Java8实战》-第四章读书笔记(引入流Stream)
  • 【vuex入门系列02】mutation接收单个参数和多个参数
  • MySQL-事务管理(基础)
  • PhantomJS 安装
  • React Native移动开发实战-3-实现页面间的数据传递
  • vuex 笔记整理
  • Vue--数据传输
  • 浅析微信支付:申请退款、退款回调接口、查询退款
  • 思否第一天
  • 微信小程序设置上一页数据
  • 一起来学SpringBoot | 第十篇:使用Spring Cache集成Redis
  • 用 Swift 编写面向协议的视图
  • #控制台大学课堂点名问题_课堂随机点名
  • $.extend({},旧的,新的);合并对象,后面的覆盖前面的
  • (6) 深入探索Python-Pandas库的核心数据结构:DataFrame全面解析
  • (7)STL算法之交换赋值
  • (9)目标检测_SSD的原理
  • (delphi11最新学习资料) Object Pascal 学习笔记---第8章第2节(共同的基类)
  • (html5)在移动端input输入搜索项后 输入法下面为什么不想百度那样出现前往? 而我的出现的是换行...
  • (Java数据结构)ArrayList
  • (solr系列:一)使用tomcat部署solr服务
  • (web自动化测试+python)1
  • (八)五种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB
  • (补充):java各种进制、原码、反码、补码和文本、图像、音频在计算机中的存储方式
  • (草履虫都可以看懂的)PyQt子窗口向主窗口传递参数,主窗口接收子窗口信号、参数。
  • (二)PySpark3:SparkSQL编程
  • (二)windows配置JDK环境
  • (附源码)springboot“微印象”在线打印预约系统 毕业设计 061642
  • (附源码)springboot社区居家养老互助服务管理平台 毕业设计 062027
  • (附源码)ssm捐赠救助系统 毕业设计 060945
  • (六)DockerCompose安装与配置
  • (实测可用)(3)Git的使用——RT Thread Stdio添加的软件包,github与gitee冲突造成无法上传文件到gitee
  • (一)、python程序--模拟电脑鼠走迷宫
  • (杂交版)植物大战僵尸
  • (转)Linux下编译安装log4cxx
  • (转)从零实现3D图像引擎:(8)参数化直线与3D平面函数库
  • .NET HttpWebRequest、WebClient、HttpClient
  • .NET Remoting Basic(10)-创建不同宿主的客户端与服务器端
  • .net wcf memory gates checking failed
  • .NET 应用启用与禁用自动生成绑定重定向 (bindingRedirect),解决不同版本 dll 的依赖问题