当前位置: 首页 > news >正文

二十八、【人工智能】【机器学习】【PyTorch】- 手写体识别

目录

引言

PyTorch简介

深度学习与手写体识别

实现手写体识别的PyTorch模型

数据预处理

构建模型

训练模型

评估模型

最新进展

代码实现

步骤1: 导入必要的库

步骤2: 准备数据集

步骤3: 定义CNN模型

步骤4: 定义损失函数和优化器

步骤5: 训练模型

步骤6: 测试模型

完整代码

结论


引言

在过去的几十年里,手写体识别一直是计算机视觉和模式识别领域的重要课题。随着深度学习技术的兴起,特别是卷积神经网络(Convolutional Neural Networks, CNNs)的发展,我们已经能够以前所未有的精度和效率识别手写字符。本文将深入探讨如何使用PyTorch这一强大的深度学习框架,实现手写体识别,并介绍一些最新的技术进步。

PyTorch简介

PyTorch是由Facebook的人工智能研究实验室开发的一个开源机器学习库。它提供了动态计算图,使得构建和调整复杂的深度学习模型变得直观而高效。PyTorch的灵活性和易用性使其成为学术界和工业界广泛使用的工具之一。

深度学习与手写体识别

手写体识别的传统方法依赖于特征工程和基于规则的系统,但这些方法往往无法处理手写体的多样性和复杂性。相比之下,深度学习模型,尤其是CNNs,能够自动学习和提取图像中的特征,无需显式的人工特征设计。这使得它们在手写体识别任务上取得了显著的成功。

实现手写体识别的PyTorch模型

我们将使用经典的MNIST数据集作为案例研究,这是一个包含60,000个训练样本和10,000个测试样本的手写数字数据集。下面是如何使用PyTorch构建一个基本的CNN模型的步骤:

数据预处理

  • 加载MNIST数据集:使用torchvision.datasets.MNIST加载并分割训练和测试数据。
  • 数据转换:使用transforms对图像进行归一化和张量化处理。

构建模型

  • 定义CNN架构:包括卷积层、池化层和全连接层。
  • 初始化模型:创建模型实例并选择合适的设备(CPU或GPU)。

训练模型

  • 设置训练参数:如学习率、优化器、损失函数。
  • 训练循环:遍历数据集,前向传播,计算损失,反向传播,更新权重。

评估模型

  • 测试模型:在测试集上评估模型的性能。
  • 分析结果:查看混淆矩阵,评估分类准确率。

最新进展

近年来,手写体识别领域的一些最新进展包括:

  • 注意力机制:引入注意力机制可以增强模型在局部区域的聚焦能力,提高识别准确性。
  • 数据增强:通过旋转、缩放和剪切等操作增加训练集多样性,有助于提高模型的泛化能力。
  • 迁移学习:利用在大型数据集上预训练的模型,通过微调适应手写体识别任务,可以节省时间和计算资源。

代码实现

pytorch实现手写字体的识别。本算法最终识别率在97.76左右。

步骤1: 导入必要的库

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

步骤2: 准备数据集

# 读取文件
filename = r"F:\BaiduNetdiskDownload\mnist_train.csv"
# 使用 loadtxt 读取文件,忽略第一列(标签),并将剩余列转换为整数
train_features = (np.loadtxt(filename, delimiter=',', skiprows=0, usecols=range(1, 785), dtype=int) / 255.0)
train_labels = np.loadtxt(filename, delimiter=',', usecols=(0,), dtype=int)# 加载训练和测试数据
# train_features, train_labels = load_data_from_excel(r'F:\BaiduNetdiskDownload\mnist_train.csv')
# test_features, test_labels = load_data_from_excel(r'F:\BaiduNetdiskDownload\mnist_test.csv')
# 读取文件
fileTestname = r"F:\BaiduNetdiskDownload\mnist_test.csv"
# 使用 loadtxt 读取文件,忽略第一列(标签),并将剩余列转换为整数
test_features = (np.loadtxt(fileTestname, delimiter=',', skiprows=0, usecols=range(1, 785), dtype=int) / 255.0)
test_labels = np.loadtxt(fileTestname, delimiter=',', usecols=(0,), dtype=int)
# 转换为PyTorch的Tensor
train_features = torch.from_numpy(train_features).float()
train_labels = torch.from_numpy(train_labels).long()
test_features = torch.from_numpy(test_features).float()
test_labels = torch.from_numpy(test_labels).long()# 自定义Dataset类
class ExcelDataset(Dataset):def __init__(self, features, labels):self.features = featuresself.labels = labelsdef __len__(self):return len(self.features)def __getitem__(self, idx):return self.features[idx], self.labels[idx]# 创建数据集实例
train_dataset = ExcelDataset(train_features, train_labels)
test_dataset = ExcelDataset(test_features, test_labels)# 创建DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

步骤3: 定义CNN模型

# 定义模型
class MnistModel(nn.Module):def __init__(self):super(MnistModel, self).__init__()self.fc = nn.Sequential(nn.Linear(784, HIDDEN_SIZE_1),nn.ReLU(),nn.Linear(HIDDEN_SIZE_1, HIDDEN_SIZE_2),nn.ReLU(),nn.Linear(HIDDEN_SIZE_2, 32),nn.ReLU(),nn.Linear(32, 10),)def forward(self, x):x = x.view(x.size(0), -1)return self.fc(x)

步骤4: 定义损失函数和优化器

model = MnistModel()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

步骤5: 训练模型

# 训练模型
for epoch in range(NUM_EPOCHS):for i, (images, labels) in enumerate(train_loader):# 清零优化器中累积的梯度optimizer.zero_grad()# 构建训练模型outputs = model(images)# 计算损失函数loss = criterion(outputs, labels)# 启动反向传播过程loss.backward()# 使用优化算法来更新模型的参数optimizer.step()if (i + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{NUM_EPOCHS}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')# 保存模型
torch.save(model.state_dict(), MODEL_PATH)

步骤6: 测试模型

# 加载模型进行测试
model.load_state_dict(torch.load(MODEL_PATH))
model.eval()
#
# # 测试模型
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy: {} %'.format(100 * correct / total))

完整代码

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim# 定义超参数
BATCH_SIZE = 64
NUM_EPOCHS = 30
LEARNING_RATE = 0.0009
HIDDEN_SIZE_1 = 128
HIDDEN_SIZE_2 = 64
MODEL_PATH = r'F:\ai\moudle.ckpt'# 读取文件
filename = r"F:\BaiduNetdiskDownload\mnist_train.csv"
# 使用 loadtxt 读取文件,忽略第一列(标签),并将剩余列转换为整数
train_features = (np.loadtxt(filename, delimiter=',', skiprows=0, usecols=range(1, 785), dtype=int) / 255.0)
train_labels = np.loadtxt(filename, delimiter=',', usecols=(0,), dtype=int)# 加载训练和测试数据
# train_features, train_labels = load_data_from_excel(r'F:\BaiduNetdiskDownload\mnist_train.csv')
# test_features, test_labels = load_data_from_excel(r'F:\BaiduNetdiskDownload\mnist_test.csv')
# 读取文件
fileTestname = r"F:\BaiduNetdiskDownload\mnist_test.csv"
# 使用 loadtxt 读取文件,忽略第一列(标签),并将剩余列转换为整数
test_features = (np.loadtxt(fileTestname, delimiter=',', skiprows=0, usecols=range(1, 785), dtype=int) / 255.0)
test_labels = np.loadtxt(fileTestname, delimiter=',', usecols=(0,), dtype=int)
# 转换为PyTorch的Tensor
train_features = torch.from_numpy(train_features).float()
train_labels = torch.from_numpy(train_labels).long()
test_features = torch.from_numpy(test_features).float()
test_labels = torch.from_numpy(test_labels).long()# 自定义Dataset类
class ExcelDataset(Dataset):def __init__(self, features, labels):self.features = featuresself.labels = labelsdef __len__(self):return len(self.features)def __getitem__(self, idx):return self.features[idx], self.labels[idx]# 创建数据集实例
train_dataset = ExcelDataset(train_features, train_labels)
test_dataset = ExcelDataset(test_features, test_labels)# 创建DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)# 定义模型
class MnistModel(nn.Module):def __init__(self):super(MnistModel, self).__init__()self.fc = nn.Sequential(nn.Linear(784, HIDDEN_SIZE_1),nn.ReLU(),nn.Linear(HIDDEN_SIZE_1, HIDDEN_SIZE_2),nn.ReLU(),nn.Linear(HIDDEN_SIZE_2, 32),nn.ReLU(),nn.Linear(32, 10),)def forward(self, x):x = x.view(x.size(0), -1)return self.fc(x)model = MnistModel()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)# 训练模型
for epoch in range(NUM_EPOCHS):for i, (images, labels) in enumerate(train_loader):# 清零优化器中累积的梯度optimizer.zero_grad()# 构建训练模型outputs = model(images)# 计算损失函数loss = criterion(outputs, labels)# 启动反向传播过程loss.backward()# 使用优化算法来更新模型的参数optimizer.step()if (i + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{NUM_EPOCHS}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')# 保存模型
torch.save(model.state_dict(), MODEL_PATH)# 加载模型进行测试
# model.load_state_dict(torch.load(MODEL_PATH))
# model.eval()
#
# # 测试模型
# with torch.no_grad():
#     correct = 0
#     total = 0
#     for images, labels in test_loader:
#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()
#     print('Test Accuracy: {} %'.format(100 * correct / total))

结论

手写体识别是深度学习技术应用的一个生动例子,展示了AI在理解和解析人类创造的内容方面的能力。随着算法和硬件的进步,我们可以期待未来在手写体识别和其他相关领域看到更多令人兴奋的成果。

需要训练集的同学可以访问以下链接获取:

链接:https://pan.baidu.com/s/1afPQFahKy9Ei8IjNk8o8pw?pwd=so5x 
提取码:so5x

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 下一个排列
  • FFmpeg有理数相关的源码:AVRational结构体和其相关的函数分析
  • 英伟达显卡查看占用情况
  • 设计模式实战:报表生成系统的设计与实现
  • Chapter 22 数据可视化——折线图
  • Chapter 26 Python魔术方法
  • 用phpstudy搭建MySQL数据库
  • WebKit 的简介及工作流程
  • 科普文:JUC系列之多线程门闩同步器CountDownLatch的使用和源码
  • C++STL专题-string类
  • 低代码: 技术实现概述
  • 部署k8s+conatinerd环境
  • 【学习笔记】后缀自动机(SAM)
  • 【MySQL】索引——索引的引入、认识磁盘、磁盘的组成、扇区、磁盘访问、磁盘和MySQL交互、索引的概念
  • 微信小程序 - 自定义计数器 - 优化(键盘输入校验)
  • 《网管员必读——网络组建》(第2版)电子课件下载
  • CentOS7简单部署NFS
  • GitUp, 你不可错过的秀外慧中的git工具
  • Idea+maven+scala构建包并在spark on yarn 运行
  • js继承的实现方法
  • js中的正则表达式入门
  • React中的“虫洞”——Context
  • Windows Containers 大冒险: 容器网络
  • windows-nginx-https-本地配置
  • windows下mongoDB的环境配置
  • 工作中总结前端开发流程--vue项目
  • 爬虫进阶 -- 神级程序员:让你的爬虫就像人类的用户行为!
  • 前端每日实战:70# 视频演示如何用纯 CSS 创作一只徘徊的果冻怪兽
  • 前端面试之闭包
  • 设计模式(12)迭代器模式(讲解+应用)
  • 手机app有了短信验证码还有没必要有图片验证码?
  • 通过npm或yarn自动生成vue组件
  • ​configparser --- 配置文件解析器​
  • ​iOS实时查看App运行日志
  • # 20155222 2016-2017-2 《Java程序设计》第5周学习总结
  • ## 基础知识
  • ###STL(标准模板库)
  • #stm32整理(一)flash读写
  • #我与Java虚拟机的故事#连载08:书读百遍其义自见
  • (2)STL算法之元素计数
  • (a /b)*c的值
  • (done) 两个矩阵 “相似” 是什么意思?
  • (补充)IDEA项目结构
  • (代码示例)使用setTimeout来延迟加载JS脚本文件
  • (非本人原创)我们工作到底是为了什么?​——HP大中华区总裁孙振耀退休感言(r4笔记第60天)...
  • (附源码)ssm本科教学合格评估管理系统 毕业设计 180916
  • (附源码)ssm学生管理系统 毕业设计 141543
  • (附源码)计算机毕业设计SSM智能化管理的仓库管理
  • (没学懂,待填坑)【动态规划】数位动态规划
  • (三十五)大数据实战——Superset可视化平台搭建
  • (十)c52学习之旅-定时器实验
  • (一) storm的集群安装与配置
  • (一)Dubbo快速入门、介绍、使用
  • (原創) 系統分析和系統設計有什麼差別? (OO)
  • (转)Oracle存储过程编写经验和优化措施