当前位置: 首页 > news >正文

Pytorch-08 实战:手写数字识别

手写数字识别项目在机器学习中经常被用作入门练习,因为它相对简单,但又涵盖了许多基本的概念。这个项目可以视为机器学习中的 “Hello World”,因为它涉及到数据收集、特征提取、模型选择、训练和评估等机器学习中的基本步骤,所以手写数字识别项目是一个很好的起点。

我们的要做的是,训练出一个人工神经网络,使它能够识别手写数字(如下图所示):

以下是一个简单的示例代码,展示如何使用PyTorch创建一个手写数字识别的模型,包括数据集加载、训练和测试过程。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 检查GPU是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
print(f"训练集第1张图像形状 = {train_dataset.__getitem__(0)[0].shape}")
print(f"训练集第1张图像标签 = {train_dataset.__getitem__(0)[1]}")
print(f"测试集第1张图像形状 = {test_dataset.__getitem__(0)[0].shape}")
print(f"测试集第1张图像标签 = {test_dataset.__getitem__(0)[1]}")# 使用数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 定义神经网络模型并将其移至GPU
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc = nn.Sequential(nn.Linear(28*28, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU(),nn.Linear(64, 10))def forward(self, x):x = x.view(x.size(0), -1)x = self.fc(x)return xmodel = Net().to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型,训练过程输出损失值
num_epochs = 5
for epoch in range(num_epochs):model.train()for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 将数据移至GPUoptimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 测试模型,输出数字识别准确率
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 将数据移至GPUoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy on the test set: {100 * correct / total}%')

程序运行后,输出如下:

训练集第1张图像形状 = torch.Size([1, 28, 28])
训练集第1张图像标签 = 5
测试集第1张图像形状 = torch.Size([1, 28, 28])
测试集第1张图像标签 = 7
Epoch [1/5], Loss: 0.3935443162918091
Epoch [2/5], Loss: 0.1757822483778
Epoch [3/5], Loss: 0.1337398886680603
Epoch [4/5], Loss: 0.03868262842297554
Epoch [5/5], Loss: 0.025882571935653687
Accuracy on the test set: 96.85%进程已结束,退出代码为 0

相关文章:

  • 力扣刷题---3146. 两个字符串的排列差
  • 开源内网穿透神器:中微子代理(neutrino-proxy)实现内网穿刺
  • python毕设项目选题汇总(全)
  • 27寸2K显示器 - HKC G27H2
  • ELK 日志监控平台(一)- 快速搭建
  • springboot 两个相同类型的Bean使用@Resouce加载
  • 数据库工具类
  • CHI dataless 传输——CHI(4)
  • 【图像超分】论文精读:Residual Non-local Attention Networks for Image Restoration(RNAN)
  • netty-socketio 集群随记
  • 如何在cPanel面板中开启盗链保护
  • 瑞芯微RV1126——人脸识别框架分析
  • Go语言的命名规范是怎样的?
  • 【数据结构】数据结构中的隐藏玩法——栈与队列
  • BTC系列-系统学习铭文(二)-序数理论
  • 时间复杂度分析经典问题——最大子序列和
  • ➹使用webpack配置多页面应用(MPA)
  • 4个实用的微服务测试策略
  • 5分钟即可掌握的前端高效利器:JavaScript 策略模式
  • canvas绘制圆角头像
  • classpath对获取配置文件的影响
  • JavaScript对象详解
  • Mysql优化
  • 从@property说起(二)当我们写下@property (nonatomic, weak) id obj时,我们究竟写了什么...
  • 从0到1:PostCSS 插件开发最佳实践
  • 飞驰在Mesos的涡轮引擎上
  • 开源中国专访:Chameleon原理首发,其它跨多端统一框架都是假的?
  • 聊聊sentinel的DegradeSlot
  • 微信小程序设置上一页数据
  • 我从编程教室毕业
  • 译自由幺半群
  • ​【C语言】长篇详解,字符系列篇3-----strstr,strtok,strerror字符串函数的使用【图文详解​】
  • ​configparser --- 配置文件解析器​
  • #每天一道面试题# 什么是MySQL的回表查询
  • (3)选择元素——(14)接触DOM元素(Accessing DOM elements)
  • (Java实习生)每日10道面试题打卡——JavaWeb篇
  • (动态规划)5. 最长回文子串 java解决
  • (附源码)spring boot智能服药提醒app 毕业设计 102151
  • (附源码)ssm高校志愿者服务系统 毕业设计 011648
  • (离散数学)逻辑连接词
  • (力扣)1314.矩阵区域和
  • (一) springboot详细介绍
  • (译) 理解 Elixir 中的宏 Macro, 第四部分:深入化
  • (转)大型网站的系统架构
  • **PHP分步表单提交思路(分页表单提交)
  • .NET 的静态构造函数是否线程安全?答案是肯定的!
  • .NET 中让 Task 支持带超时的异步等待
  • .NET/C# 使用 ConditionalWeakTable 附加字段(CLR 版本的附加属性,也可用用来当作弱引用字典 WeakDictionary)
  • .NET设计模式(8):适配器模式(Adapter Pattern)
  • .pub是什么文件_Rust 模块和文件 - 「译」
  • [.net 面向对象程序设计进阶] (19) 异步(Asynchronous) 使用异步创建快速响应和可伸缩性的应用程序...
  • [Android Studio] 开发Java 程序
  • [Android]RecyclerView添加HeaderView出现宽度问题
  • [Android]使用Retrofit进行网络请求
  • [ARC066F]Contest with Drinks Hard