
[Deep Learning] Reviewing the Old to Learn the New (4): Handwritten Digit Recognition with an MLP and a CNN - Complete, Runnable Code

Multilayer Perceptron Version

import torch
import torch.nn as nn
import numpy as np
import torch.utils
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import matplotlib
import os
# Upfront configuration:
matplotlib.use('Agg')

class Config():
    base_dir = os.path.dirname(os.path.abspath(__file__))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Hyperparameters:
    batch_size = 128
    lr = 0.0001
# Initial dataset preparation
# (set download=True on the first run if the data is not yet on disk)
train_ds = torchvision.datasets.MNIST(os.path.join(Config.base_dir, "data"), train=True, download=False, transform=transforms.ToTensor())
test_ds = torchvision.datasets.MNIST(os.path.join(Config.base_dir, "data"), train=False, download=False, transform=transforms.ToTensor())

# Build the DataLoaders
train_dl = DataLoader(train_ds, batch_size=Config.batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=Config.batch_size)

def show_pic_and_label():
    # Inspect the dataloader
    print(len(train_dl.dataset))
    # Look at its imgs and labels
    imgs, labels = next(iter(train_dl))
    sample_img = imgs[0:10]
    sample_label = labels[0:10]
    print(sample_img, sample_label)
    for idx, npimg in enumerate(sample_img, 1):
        # squeeze the channel dimension away: (1, 28, 28) -> (28, 28)
        npimg = npimg.squeeze()
        # npimg = npimg.reshape(28, 28)  # an equivalent alternative
        plt.subplot(1, 10, idx)
        plt.imshow(npimg)
        plt.axis('off')
    plt.savefig(os.path.join(Config.base_dir, "1.jpg"))
    print(sample_label)
# Build the model
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Layer 1: 28*28 -> 120
        self.liner1 = nn.Linear(28*28, 120)
        # Layer 2: -> 84
        self.liner2 = nn.Linear(120, 84)
        # Layer 3: -> 10
        self.liner3 = nn.Linear(84, 10)

    def forward(self, input):
        x = input.view(-1, 28*28)
        # @todo tripped up here: it is not the module nn.ReLU but the function torch.relu
        x = torch.relu(self.liner1(x))
        x = torch.relu(self.liner2(x))
        x = self.liner3(x)
        return x

model = Model().to(Config.device)
# print(model)
optim = torch.optim.Adam(model.parameters(), lr=Config.lr)
loss_fn = nn.CrossEntropyLoss()

def model_test():
    """Sanity check that input and output shapes are as expected."""
    res = model(torch.randn(10, 28*28).to(Config.device))
    print(res.shape)
    print(res)

def accuracy(y_pred, y_true):
    # Count how many predictions match the labels
    y_pred = (torch.argmax(y_pred, dim=1) == y_true).type(torch.int64)
    return y_pred.sum()
# The training loop
def train(dataloader, model, loss_fn, optimizer):
    total_row_count = len(dataloader.dataset)
    total_batch_count = len(dataloader)
    total_acc = 0
    total_loss = 0
    for X, y in dataloader:
        X, y = X.to(Config.device), y.to(Config.device)
        y_pred = model(X)
        acc = accuracy(y_pred, y)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            total_acc += acc
            total_loss += loss
    total_acc = total_acc / total_row_count
    total_loss = total_loss / total_batch_count
    return total_loss, total_acc

# The evaluation loop
def test(dataloader, model, loss_fn):
    total_row_count = len(dataloader.dataset)
    total_batch_count = len(dataloader)
    total_acc = 0
    total_loss = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(Config.device), y.to(Config.device)
            y_pred = model(X)
            acc = accuracy(y_pred, y)
            loss = loss_fn(y_pred, y)
            total_acc += acc
            total_loss += loss
    total_acc = total_acc / total_row_count
    total_loss = total_loss / total_batch_count
    return total_loss, total_acc

epochs = 50
train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):
    epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optim)
    epoch_test_loss, epoch_test_acc = test(test_dl, model, loss_fn)
    template = "epoch:{:2d}, train_loss:{:.5f}, train_acc:{:.1f}%, test_loss:{:.5f},test_acc:{:.1f}%"
    print(template.format(epoch, epoch_loss.data.item(), epoch_acc.data.item()*100, epoch_test_loss.data.item(), epoch_test_acc.data.item()*100))
if __name__ == "__main__":
    # model_test()
    pass
    # y_pred = torch.tensor([
    #     [1, 2, 3],
    #     [2, 1, 3],
    #     [3, 2, 1],
    # ])
    # y_true = torch.tensor([2, 0, 1])
    # res = accuracy(y_pred, y_true)
    # print(res)
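A note on the pitfall flagged in Model.forward: nn.ReLU is a Module that must be instantiated before it can be called, whereas torch.relu is a plain function that is applied directly. A minimal sketch of the two interchangeable forms (the tensor here is just a dummy):

import torch
import torch.nn as nn

x = torch.randn(4, 8)

act = nn.ReLU()        # module form: instantiate first, then call
y1 = act(x)
y2 = torch.relu(x)     # functional form: call directly

print(torch.equal(y1, y2))  # True: both apply max(0, x) elementwise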
(pytorchbook) (base) justin@justin-System-Product-Name:~/Desktop/code/python_project/mypaper$ /home/justin/miniconda3/envs/pytorchbook/bin/python /home/justin/Desktop/code/python_project/mypaper/pytorchbook/chapter4/手写体识别.py
epoch: 0, train_loss:1.17435, train_acc:70.1%, test_loss:0.47829,test_acc:88.7%
epoch: 1, train_loss:0.39913, train_acc:89.5%, test_loss:0.33029,test_acc:91.0%
epoch: 2, train_loss:0.31837, train_acc:91.1%, test_loss:0.28821,test_acc:91.8%
epoch: 3, train_loss:0.28331, train_acc:92.0%, test_loss:0.26157,test_acc:92.5%
epoch: 4, train_loss:0.26049, train_acc:92.5%, test_loss:0.24704,test_acc:93.1%
epoch: 5, train_loss:0.24122, train_acc:93.1%, test_loss:0.22766,test_acc:93.4%
epoch: 6, train_loss:0.22516, train_acc:93.6%, test_loss:0.21446,test_acc:93.7%
epoch: 7, train_loss:0.21048, train_acc:94.0%, test_loss:0.20211,test_acc:94.2%
epoch: 8, train_loss:0.19786, train_acc:94.4%, test_loss:0.19200,test_acc:94.5%
epoch: 9, train_loss:0.18692, train_acc:94.6%, test_loss:0.18458,test_acc:94.7%
epoch:10, train_loss:0.17689, train_acc:95.0%, test_loss:0.17440,test_acc:94.9%
epoch:11, train_loss:0.16766, train_acc:95.2%, test_loss:0.16584,test_acc:95.0%
epoch:12, train_loss:0.15932, train_acc:95.5%, test_loss:0.16011,test_acc:95.3%
epoch:13, train_loss:0.15149, train_acc:95.7%, test_loss:0.15269,test_acc:95.5%
epoch:14, train_loss:0.14443, train_acc:95.9%, test_loss:0.14685,test_acc:95.5%
epoch:15, train_loss:0.13801, train_acc:96.0%, test_loss:0.14179,test_acc:95.7%
epoch:16, train_loss:0.13172, train_acc:96.2%, test_loss:0.13724,test_acc:95.8%
epoch:17, train_loss:0.12594, train_acc:96.3%, test_loss:0.13256,test_acc:96.1%
epoch:18, train_loss:0.12016, train_acc:96.5%, test_loss:0.13012,test_acc:96.1%
epoch:19, train_loss:0.11557, train_acc:96.7%, test_loss:0.12416,test_acc:96.2%
epoch:20, train_loss:0.11037, train_acc:96.8%, test_loss:0.12220,test_acc:96.4%
epoch:21, train_loss:0.10601, train_acc:97.0%, test_loss:0.11851,test_acc:96.5%
epoch:22, train_loss:0.10160, train_acc:97.1%, test_loss:0.11445,test_acc:96.6%
epoch:23, train_loss:0.09774, train_acc:97.2%, test_loss:0.11242,test_acc:96.5%
epoch:24, train_loss:0.09388, train_acc:97.3%, test_loss:0.10876,test_acc:96.6%
epoch:25, train_loss:0.09008, train_acc:97.4%, test_loss:0.10713,test_acc:96.7%
epoch:26, train_loss:0.08692, train_acc:97.5%, test_loss:0.10526,test_acc:96.7%
epoch:27, train_loss:0.08370, train_acc:97.6%, test_loss:0.10490,test_acc:96.8%
epoch:28, train_loss:0.08067, train_acc:97.7%, test_loss:0.10183,test_acc:96.8%
epoch:29, train_loss:0.07805, train_acc:97.7%, test_loss:0.10172,test_acc:96.9%
epoch:30, train_loss:0.07480, train_acc:97.8%, test_loss:0.09779,test_acc:97.0%
epoch:31, train_loss:0.07235, train_acc:97.8%, test_loss:0.09650,test_acc:97.0%
epoch:32, train_loss:0.06958, train_acc:98.0%, test_loss:0.09472,test_acc:97.1%
epoch:33, train_loss:0.06747, train_acc:98.0%, test_loss:0.09349,test_acc:97.1%
epoch:34, train_loss:0.06504, train_acc:98.1%, test_loss:0.09270,test_acc:97.1%
epoch:35, train_loss:0.06236, train_acc:98.2%, test_loss:0.09221,test_acc:97.2%
epoch:36, train_loss:0.06039, train_acc:98.3%, test_loss:0.09187,test_acc:97.2%
epoch:37, train_loss:0.05850, train_acc:98.3%, test_loss:0.08917,test_acc:97.3%
epoch:38, train_loss:0.05624, train_acc:98.4%, test_loss:0.08657,test_acc:97.3%
epoch:39, train_loss:0.05456, train_acc:98.4%, test_loss:0.08722,test_acc:97.4%
epoch:40, train_loss:0.05246, train_acc:98.5%, test_loss:0.08660,test_acc:97.4%
epoch:41, train_loss:0.05088, train_acc:98.5%, test_loss:0.08511,test_acc:97.4%
epoch:42, train_loss:0.04919, train_acc:98.6%, test_loss:0.08628,test_acc:97.4%
epoch:43, train_loss:0.04726, train_acc:98.7%, test_loss:0.08620,test_acc:97.4%
epoch:44, train_loss:0.04571, train_acc:98.7%, test_loss:0.08298,test_acc:97.5%
epoch:45, train_loss:0.04408, train_acc:98.8%, test_loss:0.08309,test_acc:97.5%
epoch:46, train_loss:0.04274, train_acc:98.8%, test_loss:0.08241,test_acc:97.5%
epoch:47, train_loss:0.04122, train_acc:98.9%, test_loss:0.08229,test_acc:97.6%
epoch:48, train_loss:0.03967, train_acc:98.9%, test_loss:0.08120,test_acc:97.6%
epoch:49, train_loss:0.03829, train_acc:99.0%, test_loss:0.08134,test_acc:97.5%

Question 1:
epoch: 0, train_loss:1.17435, train_acc:70.1%, test_loss:0.47829,test_acc:88.7%
Why is train_acc so much lower than test_acc in the first epoch? Because the first epoch starts from random weights: train_acc is averaged batch by batch while the model is still improving, and only after all the training batches are done is the test set evaluated, using the already-improved model. So the test number comes out higher.
Then why is test lower than train in every later epoch? Because even though train is measured batch by batch, the model can still overfit, fitting the training data ever more closely, so the test set cannot keep up.
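A quick way to check the first explanation (a sketch that reuses the train and test functions defined above): after an epoch finishes, evaluate the training set again with the now-fixed weights. The re-evaluated training accuracy should come out close to, or above, the test accuracy even in epoch 0.

# Sketch: compare the running (batch-averaged) train accuracy with the
# accuracy of the frozen model on the same training data.
for epoch in range(epochs):
    epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optim)  # weights moving during measurement
    frozen_loss, frozen_acc = test(train_dl, model, loss_fn)        # same data, weights fixed
    epoch_test_loss, epoch_test_acc = test(test_dl, model, loss_fn)
    print("epoch:{}, running train_acc:{:.1f}%, frozen train_acc:{:.1f}%, test_acc:{:.1f}%".format(
        epoch, epoch_acc.item()*100, frozen_acc.item()*100, epoch_test_acc.item()*100))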


CNN Version

Only the model needs to be swapped out; nothing else has to change.

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)   # 1x28x28 -> 6x24x24, pooled to 6x12x12
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)  # 6x12x12 -> 16x8x8, pooled to 16x4x4
        self.liner_1 = nn.Linear(16*4*4, 256)
        self.liner_2 = nn.Linear(256, 10)

    def forward(self, input):
        x = torch.max_pool2d(torch.relu(self.conv1(input)), 2)
        x = torch.max_pool2d(torch.relu(self.conv2(x)), 2)
        # Flatten
        x = x.view(-1, 16*4*4)
        x = torch.relu(self.liner_1(x))
        x = self.liner_2(x)
        return x
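To confirm the shape arithmetic in the comments (28 -> 24 -> 12 -> 8 -> 4), here is a quick sanity check with a dummy batch (a sketch; the batch size of 10 is arbitrary):

m = Model()
dummy = torch.randn(10, 1, 28, 28)  # a batch of 10 fake MNIST images
out = m(dummy)
print(out.shape)                    # expected: torch.Size([10, 10])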
epoch: 0, train_loss:1.13144, train_acc:74.3%, test_loss:0.36698,test_acc:90.6%
epoch: 1, train_loss:0.30213, train_acc:91.6%, test_loss:0.22672,test_acc:93.5%
epoch: 2, train_loss:0.21874, train_acc:93.7%, test_loss:0.17848,test_acc:94.9%
epoch: 3, train_loss:0.17849, train_acc:94.8%, test_loss:0.14941,test_acc:95.4%
epoch: 4, train_loss:0.15203, train_acc:95.5%, test_loss:0.12645,test_acc:96.2%
epoch: 5, train_loss:0.13339, train_acc:96.1%, test_loss:0.11351,test_acc:96.5%
epoch: 6, train_loss:0.11952, train_acc:96.5%, test_loss:0.09954,test_acc:96.9%
epoch: 7, train_loss:0.10876, train_acc:96.7%, test_loss:0.09198,test_acc:97.3%
epoch: 8, train_loss:0.09943, train_acc:97.1%, test_loss:0.08412,test_acc:97.3%
epoch: 9, train_loss:0.09255, train_acc:97.2%, test_loss:0.07788,test_acc:97.6%
epoch:10, train_loss:0.08576, train_acc:97.4%, test_loss:0.07551,test_acc:97.6%
epoch:11, train_loss:0.08089, train_acc:97.5%, test_loss:0.06757,test_acc:97.9%
epoch:12, train_loss:0.07635, train_acc:97.7%, test_loss:0.06399,test_acc:98.0%
epoch:13, train_loss:0.07175, train_acc:97.8%, test_loss:0.05942,test_acc:98.1%
epoch:14, train_loss:0.06862, train_acc:97.9%, test_loss:0.05657,test_acc:98.2%
epoch:15, train_loss:0.06509, train_acc:98.0%, test_loss:0.05776,test_acc:98.1%
epoch:16, train_loss:0.06273, train_acc:98.1%, test_loss:0.05381,test_acc:98.3%
epoch:17, train_loss:0.05940, train_acc:98.2%, test_loss:0.05134,test_acc:98.4%
epoch:18, train_loss:0.05681, train_acc:98.3%, test_loss:0.05330,test_acc:98.2%
epoch:19, train_loss:0.05434, train_acc:98.4%, test_loss:0.04689,test_acc:98.6%
epoch:20, train_loss:0.05175, train_acc:98.5%, test_loss:0.04500,test_acc:98.6%
epoch:21, train_loss:0.05027, train_acc:98.6%, test_loss:0.04645,test_acc:98.5%
epoch:22, train_loss:0.04849, train_acc:98.6%, test_loss:0.04274,test_acc:98.7%
epoch:23, train_loss:0.04600, train_acc:98.6%, test_loss:0.04739,test_acc:98.5%
epoch:24, train_loss:0.04449, train_acc:98.7%, test_loss:0.04360,test_acc:98.7%
epoch:25, train_loss:0.04359, train_acc:98.7%, test_loss:0.04198,test_acc:98.7%
epoch:26, train_loss:0.04115, train_acc:98.8%, test_loss:0.04209,test_acc:98.7%
epoch:27, train_loss:0.03978, train_acc:98.8%, test_loss:0.04147,test_acc:98.7%
epoch:28, train_loss:0.03866, train_acc:98.9%, test_loss:0.03845,test_acc:98.8%
epoch:29, train_loss:0.03721, train_acc:98.9%, test_loss:0.04142,test_acc:98.7%
epoch:30, train_loss:0.03632, train_acc:98.9%, test_loss:0.03916,test_acc:98.8%
epoch:31, train_loss:0.03525, train_acc:98.9%, test_loss:0.04137,test_acc:98.7%
epoch:32, train_loss:0.03364, train_acc:99.0%, test_loss:0.03829,test_acc:98.8%
epoch:33, train_loss:0.03323, train_acc:99.0%, test_loss:0.04090,test_acc:98.7%
epoch:34, train_loss:0.03179, train_acc:99.0%, test_loss:0.03660,test_acc:98.9%
epoch:35, train_loss:0.03125, train_acc:99.1%, test_loss:0.03698,test_acc:98.9%
epoch:36, train_loss:0.03009, train_acc:99.1%, test_loss:0.03624,test_acc:98.8%
epoch:37, train_loss:0.02958, train_acc:99.1%, test_loss:0.03525,test_acc:98.9%
epoch:38, train_loss:0.02902, train_acc:99.1%, test_loss:0.03705,test_acc:98.9%
epoch:39, train_loss:0.02789, train_acc:99.2%, test_loss:0.03579,test_acc:98.9%
epoch:40, train_loss:0.02741, train_acc:99.2%, test_loss:0.03896,test_acc:98.9%
epoch:41, train_loss:0.02604, train_acc:99.2%, test_loss:0.03572,test_acc:98.9%
epoch:42, train_loss:0.02518, train_acc:99.2%, test_loss:0.03741,test_acc:98.7%
epoch:43, train_loss:0.02471, train_acc:99.3%, test_loss:0.03319,test_acc:98.9%
epoch:44, train_loss:0.02413, train_acc:99.3%, test_loss:0.03753,test_acc:98.8%
epoch:45, train_loss:0.02340, train_acc:99.3%, test_loss:0.03333,test_acc:98.9%
epoch:46, train_loss:0.02272, train_acc:99.3%, test_loss:0.03303,test_acc:99.0%
epoch:47, train_loss:0.02188, train_acc:99.3%, test_loss:0.03451,test_acc:98.9%
epoch:48, train_loss:0.02169, train_acc:99.4%, test_loss:0.03433,test_acc:98.9%
epoch:49, train_loss:0.02068, train_acc:99.4%, test_loss:0.03331,test_acc:98.9%

Comparing the two: the CNN reaches 98.9% test accuracy, while the MLP only reaches 97.x%.
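The gap is not a matter of sheer model size: the CNN actually has fewer parameters than the MLP; it wins because convolution and pooling exploit the 2-D structure of the images. A sketch that counts them (mlp_model and cnn_model are hypothetical instances of the two Model classes above, which share a name in this post, so one would need renaming):

def count_params(m):
    # Sum the element counts of all weight and bias tensors
    return sum(p.numel() for p in m.parameters())

print(count_params(mlp_model))  # 784*120+120 + 120*84+84 + 84*10+10 = 105,214
print(count_params(cnn_model))  # 156 + 2,416 + 65,792 + 2,570       =  70,934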

Calling via the functional API

import torch.nn.functional as F

# Here we are learning a debugging technique: keep every intermediate
# activation in its own variable so each shape can be inspected.
class _Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)

    def forward(self, input):
        a1 = self.conv1(input)    # 1x28x28 -> 6x24x24
        a2 = F.max_pool2d(a1, 2)  # -> 6x12x12
        a3 = self.conv2(a2)       # -> 16x8x8
        a4 = F.max_pool2d(a3, 2)  # -> 16x4x4
        # print(a1.shape, a2.shape, a3.shape, a4.shape)
        return a4

class Model1(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)   # 1x28x28 -> 6x24x24, pooled to 6x12x12
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)  # 6x12x12 -> 16x8x8, pooled to 16x4x4
        self.liner_1 = nn.Linear(16*4*4, 256)
        self.liner_2 = nn.Linear(256, 10)

    def forward(self, input):
        x = F.max_pool2d(F.relu(self.conv1(input)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten
        x = x.view(-1, 16*4*4)
        x = F.relu(self.liner_1(x))
        x = self.liner_2(x)
        return x
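Since F.relu and F.max_pool2d compute exactly the same operations as torch.relu and torch.max_pool2d, Model and Model1 are interchangeable. A sketch that checks this by copying the weights from one into the other (both define the same layer names, so their state dicts match):

m_tensor = Model()   # the torch.* variant above
m_func = Model1()    # the F.* variant
m_func.load_state_dict(m_tensor.state_dict())

x = torch.randn(4, 1, 28, 28)
with torch.no_grad():
    print(torch.allclose(m_tensor(x), m_func(x)))  # True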
