当前位置: 首页 > news >正文

PyTorch VGG16手写数字识别教程

手写数字识别教程:使用PyTorch和VGG16

1. 环境准备

确保你已安装以下库:

pip install torch torchvision
2. 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
3. 数据预处理

我们需要对MNIST数据集进行转换,使其适合输入VGG16模型。由于VGG16的输入要求为224x224的图像,因此我们需要调整图像大小,并进行标准化处理。

transform = transforms.Compose([transforms.Resize((224, 224)),  # 将图像大小调整为224x224transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize((0.5,), (0.5,))  # 标准化处理,均值和标准差
])# 下载并加载训练和测试数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
4. 定义VGG16模型

VGG16由多个卷积层和全连接层组成。我们将调整输入通道以适应单通道的MNIST数据。

class VGG16(nn.Module):def __init__(self):super(VGG16, self).__init__()# 定义卷积层self.vgg = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3, padding=1),  # 将输入通道设置为1(灰度图)nn.ReLU(),  # 激活函数nn.MaxPool2d(kernel_size=2, stride=2),  # 最大池化层,减小特征图尺寸nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)# 定义全连接层self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),  # 第一个全连接层nn.ReLU(),nn.Dropout(),  # 随机失活,防止过拟合nn.Linear(4096, 4096),  # 第二个全连接层nn.ReLU(),nn.Dropout(),nn.Linear(4096, 10)  # 输出层,10个类(数字0-9))def forward(self, x):x = self.vgg(x)  # 通过卷积层x = x.view(x.size(0), -1)  # 展平特征图x = self.classifier(x)  # 通过全连接层return x
5. 训练模型

我们将使用交叉熵损失函数和Adam优化器,并训练模型。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 检测可用的设备
model = VGG16().to(device)  # 实例化模型并移动到设备上
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器# 训练循环
for epoch in range(5):  # 训练5个epochmodel.train()  # 设置为训练模式for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 移动到设备optimizer.zero_grad()  # 清空梯度outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}')  # 输出当前epoch的损失
6. 测试模型

在测试阶段,我们将计算模型的准确率。

model.eval()  # 设置为评估模式
with torch.no_grad():  # 禁用梯度计算correct = 0total = 0for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 移动到设备outputs = model(images)  # 前向传播_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += labels.size(0)  # 统计总样本数correct += (predicted == labels).sum().item()  # 统计正确预测的数量print(f'Accuracy: {100 * correct / total:.2f}%')  # 输出准确率

总结

这个教程详细介绍了如何使用VGG16模型对MNIST数据集进行手写数字识别。通过调整网络参数和训练轮数,你可以进一步提高模型的性能。希望这个教程能帮助你更好地理解PyTorch及深度学习的应用!

相关文章:

  • pick你的第一个人形机器人——青龙强化学习环境测试
  • 数字安全二之密钥结合消息摘要
  • 利用Java easyExcel库实现高效Excel数据处理
  • 基于RealSense D435相机实现手部姿态重定向
  • 水仙花数求解-C语言
  • 另外知识与网络总结
  • 5V继电器模块详解(STM32)
  • 多IP站群服务器对SEO优化的几大好处
  • 算法打卡:第十一章 图论part08
  • 在C#中使用JSON
  • 【test】google cloud
  • Vxe UI vue vxe-table vxe-grid 单元格与表尾单元格如何格式化数据
  • 微服务--初识MQ
  • 车辆重识别(去噪扩散概率模型)论文阅读2024/9/27
  • centos7 yum 更新 nginx 到最新版本 1.26
  • [js高手之路]搞清楚面向对象,必须要理解对象在创建过程中的内存表示
  • chrome扩展demo1-小时钟
  • express.js的介绍及使用
  • Java IO学习笔记一
  • JAVA_NIO系列——Channel和Buffer详解
  • JavaSE小实践1:Java爬取斗图网站的所有表情包
  • Octave 入门
  • react-core-image-upload 一款轻量级图片上传裁剪插件
  • Vue ES6 Jade Scss Webpack Gulp
  • yii2中session跨域名的问题
  • 初探 Vue 生命周期和钩子函数
  • 给初学者:JavaScript 中数组操作注意点
  • 基于Volley网络库实现加载多种网络图片(包括GIF动态图片、圆形图片、普通图片)...
  • 面试题:给你个id,去拿到name,多叉树遍历
  • 如何正确配置 Ubuntu 14.04 服务器?
  • 我感觉这是史上最牛的防sql注入方法类
  • 小程序 setData 学问多
  • 要让cordova项目适配iphoneX + ios11.4,总共要几步?三步
  • 移动端唤起键盘时取消position:fixed定位
  • 7行Python代码的人脸识别
  • ​2020 年大前端技术趋势解读
  • # 达梦数据库知识点
  • (LeetCode) T14. Longest Common Prefix
  • (rabbitmq的高级特性)消息可靠性
  • (附源码)计算机毕业设计ssm-Java网名推荐系统
  • (附源码)计算机毕业设计SSM智能化管理的仓库管理
  • (黑马出品_高级篇_01)SpringCloud+RabbitMQ+Docker+Redis+搜索+分布式
  • (三)docker:Dockerfile构建容器运行jar包
  • (四)模仿学习-完成后台管理页面查询
  • (转)【Hibernate总结系列】使用举例
  • ******IT公司面试题汇总+优秀技术博客汇总
  • *setTimeout实现text输入在用户停顿时才调用事件!*
  • .CSS-hover 的解释
  • .NET 中各种混淆(Obfuscation)的含义、原理、实际效果和不同级别的差异(使用 SmartAssembly)
  • .NET企业级应用架构设计系列之结尾篇
  • .NET性能优化(文摘)
  • @EnableWebSecurity 注解的用途及适用场景
  • @Transactional 参数详解
  • [ vulhub漏洞复现篇 ] Apache Flink目录遍历(CVE-2020-17519)
  • [23] GaussianAvatars: Photorealistic Head Avatars with Rigged 3D Gaussians