当前位置: 首页 > news >正文

chap5 CNN

卷积神经网络(CNN)

问题描述:

利用卷积神经网络,实现对MNIST数据集的分类问题

数据集:

MNIST数据集包括60000张训练图片和10000张测试图片。图片样本的数量已经足够训练一个很复杂的模型(例如 CNN的深层神经网络)。它经常被用来作为一个新 的模式识别模型的测试用例。而且它也是一个方便学生和研究者们执行用例的数据集。除此之外,MNIST数据集是一个相对较小的数据集,可以在你的笔记本CPUs上面直接执行

题目要求

Pytorch版本的卷积神经网络需要补齐self.conv1中的nn.Conv2d()self.conv2()的参数,还需要填写x=x.view()中的内容。
训练精度应该在96%以上。

import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import torch.nn.functional as F
import numpy as np
learning_rate = 1e-4
keep_prob_rate = 0.7 #
max_epoch = 3
BATCH_SIZE = 50DOWNLOAD_MNIST = False
if not(os.path.exists('./mnist/')) or not os.listdir('./mnist/'):# not mnist dir or mnist is empyt dirDOWNLOAD_MNIST = Truetrain_data = torchvision.datasets.MNIST(root='./mnist/',train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST,)
train_loader = Data.DataLoader(dataset = train_data ,batch_size= BATCH_SIZE ,shuffle= True)test_data = torchvision.datasets.MNIST(root = './mnist/',train = False)
test_x = Variable(torch.unsqueeze(test_data.test_data,dim  = 1),volatile = True).type(torch.FloatTensor)[:500]/255.
test_y = test_data.test_labels[:500].numpy()class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d( # ???# patch 7 * 7 ; 1  in channels ; 32 out channels ; ; stride is 1# padding style is same(that means the convolution opration's input and output have the same size)in_channels=1,out_channels=32,kernel_size=7,stride=1,padding=3,),nn.ReLU(),        # activation functionnn.MaxPool2d(2),  # pooling operation)self.conv2 = nn.Sequential( # ???# line 1 : convolution function, patch 5*5 , 32 in channels ;64 out channels; padding style is same; stride is 1# line 2 : choosing your activation funciont# line 3 : pooling operation function.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2, stride=1),nn.ReLU(),nn.AvgPool2d(2),)self.out1 = nn.Linear( 7*7*64 , 1024 , bias= True)   # full connection layer oneself.dropout = nn.Dropout(keep_prob_rate)self.out2 = nn.Linear(1024,10,bias=True)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(-1, 7*7*64)  # flatten the output of coonv2 to (batch_size ,32 * 7 * 7)    # ???out1 = self.out1(x)out1 = F.relu(out1)out1 = self.dropout(out1)out2 = self.out2(out1)output = F.softmax(out2)return outputdef test(cnn):global predictiony_pre = cnn(test_x)_,pre_index= torch.max(y_pre,1)pre_index= pre_index.view(-1)prediction = pre_index.data.numpy()correct  = np.sum(prediction == test_y)return correct / 500.0def train(cnn):optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate )loss_func = nn.CrossEntropyLoss()for epoch in range(max_epoch):for step, (x_, y_) in enumerate(train_loader):x ,y= Variable(x_),Variable(y_)output = cnn(x)  loss = loss_func(output,y)optimizer.zero_grad()loss.backward()optimizer.step()if step != 0 and step % 20 ==0:print("=" * 10,step,"="*5,"="*5, "test accuracy is ",test(cnn) ,"=" * 10 )if __name__ == '__main__':cnn = CNN()train(cnn)

训练结果为:

========== 20 ===== ===== test accuracy is  0.224 ==========
========== 40 ===== ===== test accuracy is  0.362 ==========
========== 60 ===== ===== test accuracy is  0.402 ==========
========== 80 ===== ===== test accuracy is  0.51 ==========
========== 100 ===== ===== test accuracy is  0.608 ==========
========== 120 ===== ===== test accuracy is  0.624 ==========
========== 140 ===== ===== test accuracy is  0.708 ==========
========== 160 ===== ===== test accuracy is  0.684 ==========
========== 180 ===== ===== test accuracy is  0.738 ==========
========== 200 ===== ===== test accuracy is  0.766 ==========
========== 220 ===== ===== test accuracy is  0.778 ==========
========== 240 ===== ===== test accuracy is  0.796 ==========
========== 260 ===== ===== test accuracy is  0.802 ==========
========== 280 ===== ===== test accuracy is  0.81 ==========
========== 300 ===== ===== test accuracy is  0.812 ==========
========== 320 ===== ===== test accuracy is  0.82 ==========
========== 340 ===== ===== test accuracy is  0.848 ==========
========== 360 ===== ===== test accuracy is  0.83 ==========
========== 380 ===== ===== test accuracy is  0.852 ==========
========== 400 ===== ===== test accuracy is  0.852 ==========
========== 420 ===== ===== test accuracy is  0.856 ==========
========== 440 ===== ===== test accuracy is  0.874 ==========
========== 460 ===== ===== test accuracy is  0.85 ==========
========== 480 ===== ===== test accuracy is  0.874 ==========
========== 500 ===== ===== test accuracy is  0.864 ==========
========== 520 ===== ===== test accuracy is  0.858 ==========
========== 540 ===== ===== test accuracy is  0.884 ==========
========== 560 ===== ===== test accuracy is  0.872 ==========
========== 580 ===== ===== test accuracy is  0.9 ==========
========== 600 ===== ===== test accuracy is  0.88 ==========
========== 620 ===== ===== test accuracy is  0.886 ==========
========== 640 ===== ===== test accuracy is  0.882 ==========
========== 660 ===== ===== test accuracy is  0.886 ==========
========== 680 ===== ===== test accuracy is  0.876 ==========
========== 700 ===== ===== test accuracy is  0.882 ==========
========== 720 ===== ===== test accuracy is  0.886 ==========
========== 740 ===== ===== test accuracy is  0.894 ==========
========== 760 ===== ===== test accuracy is  0.894 ==========
========== 780 ===== ===== test accuracy is  0.9 ==========
========== 800 ===== ===== test accuracy is  0.898 ==========
========== 820 ===== ===== test accuracy is  0.912 ==========
========== 840 ===== ===== test accuracy is  0.894 ==========
========== 860 ===== ===== test accuracy is  0.898 ==========
========== 880 ===== ===== test accuracy is  0.888 ==========
========== 900 ===== ===== test accuracy is  0.896 ==========
========== 920 ===== ===== test accuracy is  0.888 ==========
========== 940 ===== ===== test accuracy is  0.91 ==========
========== 960 ===== ===== test accuracy is  0.908 ==========
========== 980 ===== ===== test accuracy is  0.918 ==========
========== 1000 ===== ===== test accuracy is  0.906 ==========
========== 1020 ===== ===== test accuracy is  0.908 ==========
========== 1040 ===== ===== test accuracy is  0.906 ==========
========== 1060 ===== ===== test accuracy is  0.914 ==========
========== 1080 ===== ===== test accuracy is  0.908 ==========
========== 1100 ===== ===== test accuracy is  0.906 ==========
========== 1120 ===== ===== test accuracy is  0.906 ==========
========== 1140 ===== ===== test accuracy is  0.924 ==========
========== 1160 ===== ===== test accuracy is  0.918 ==========
========== 1180 ===== ===== test accuracy is  0.904 ==========
========== 20 ===== ===== test accuracy is  0.924 ==========
========== 40 ===== ===== test accuracy is  0.908 ==========
========== 60 ===== ===== test accuracy is  0.92 ==========
========== 80 ===== ===== test accuracy is  0.91 ==========
========== 100 ===== ===== test accuracy is  0.926 ==========
========== 120 ===== ===== test accuracy is  0.91 ==========
========== 140 ===== ===== test accuracy is  0.922 ==========
========== 160 ===== ===== test accuracy is  0.932 ==========
========== 180 ===== ===== test accuracy is  0.932 ==========
========== 200 ===== ===== test accuracy is  0.93 ==========
========== 220 ===== ===== test accuracy is  0.94 ==========
========== 240 ===== ===== test accuracy is  0.918 ==========
========== 260 ===== ===== test accuracy is  0.934 ==========
========== 280 ===== ===== test accuracy is  0.93 ==========
========== 300 ===== ===== test accuracy is  0.934 ==========
========== 320 ===== ===== test accuracy is  0.934 ==========
========== 340 ===== ===== test accuracy is  0.93 ==========
========== 360 ===== ===== test accuracy is  0.944 ==========
========== 380 ===== ===== test accuracy is  0.938 ==========
========== 400 ===== ===== test accuracy is  0.92 ==========
========== 420 ===== ===== test accuracy is  0.936 ==========
========== 440 ===== ===== test accuracy is  0.948 ==========
========== 460 ===== ===== test accuracy is  0.934 ==========
========== 480 ===== ===== test accuracy is  0.938 ==========
========== 500 ===== ===== test accuracy is  0.916 ==========
========== 520 ===== ===== test accuracy is  0.916 ==========
========== 540 ===== ===== test accuracy is  0.928 ==========
========== 560 ===== ===== test accuracy is  0.936 ==========
========== 580 ===== ===== test accuracy is  0.942 ==========
========== 600 ===== ===== test accuracy is  0.922 ==========
========== 620 ===== ===== test accuracy is  0.94 ==========
========== 640 ===== ===== test accuracy is  0.94 ==========
========== 660 ===== ===== test accuracy is  0.96 ==========
========== 680 ===== ===== test accuracy is  0.938 ==========
========== 700 ===== ===== test accuracy is  0.936 ==========
========== 720 ===== ===== test accuracy is  0.94 ==========
========== 740 ===== ===== test accuracy is  0.946 ==========
========== 760 ===== ===== test accuracy is  0.946 ==========
========== 780 ===== ===== test accuracy is  0.948 ==========
========== 800 ===== ===== test accuracy is  0.95 ==========
========== 820 ===== ===== test accuracy is  0.948 ==========
========== 840 ===== ===== test accuracy is  0.95 ==========
========== 860 ===== ===== test accuracy is  0.94 ==========
========== 880 ===== ===== test accuracy is  0.956 ==========
========== 900 ===== ===== test accuracy is  0.944 ==========
========== 920 ===== ===== test accuracy is  0.948 ==========
========== 940 ===== ===== test accuracy is  0.95 ==========
========== 960 ===== ===== test accuracy is  0.944 ==========
========== 980 ===== ===== test accuracy is  0.94 ==========
========== 1000 ===== ===== test accuracy is  0.946 ==========
========== 1020 ===== ===== test accuracy is  0.952 ==========
========== 1040 ===== ===== test accuracy is  0.952 ==========
========== 1060 ===== ===== test accuracy is  0.944 ==========
========== 1080 ===== ===== test accuracy is  0.956 ==========
========== 1100 ===== ===== test accuracy is  0.96 ==========
========== 1120 ===== ===== test accuracy is  0.948 ==========
========== 1140 ===== ===== test accuracy is  0.942 ==========
========== 1160 ===== ===== test accuracy is  0.948 ==========
========== 1180 ===== ===== test accuracy is  0.944 ==========
========== 20 ===== ===== test accuracy is  0.952 ==========
========== 40 ===== ===== test accuracy is  0.96 ==========
========== 60 ===== ===== test accuracy is  0.948 ==========
========== 80 ===== ===== test accuracy is  0.954 ==========
========== 100 ===== ===== test accuracy is  0.948 ==========
========== 120 ===== ===== test accuracy is  0.948 ==========
========== 140 ===== ===== test accuracy is  0.958 ==========
========== 160 ===== ===== test accuracy is  0.942 ==========
========== 180 ===== ===== test accuracy is  0.948 ==========
========== 200 ===== ===== test accuracy is  0.952 ==========
========== 220 ===== ===== test accuracy is  0.952 ==========
========== 240 ===== ===== test accuracy is  0.95 ==========
========== 260 ===== ===== test accuracy is  0.966 ==========
========== 280 ===== ===== test accuracy is  0.96 ==========
========== 300 ===== ===== test accuracy is  0.956 ==========
========== 320 ===== ===== test accuracy is  0.96 ==========
========== 340 ===== ===== test accuracy is  0.956 ==========
========== 360 ===== ===== test accuracy is  0.956 ==========
========== 380 ===== ===== test accuracy is  0.954 ==========
========== 400 ===== ===== test accuracy is  0.96 ==========
========== 420 ===== ===== test accuracy is  0.966 ==========
========== 440 ===== ===== test accuracy is  0.96 ==========
========== 460 ===== ===== test accuracy is  0.954 ==========
========== 480 ===== ===== test accuracy is  0.968 ==========
========== 500 ===== ===== test accuracy is  0.958 ==========
========== 520 ===== ===== test accuracy is  0.958 ==========
========== 540 ===== ===== test accuracy is  0.962 ==========
========== 560 ===== ===== test accuracy is  0.968 ==========
========== 580 ===== ===== test accuracy is  0.958 ==========
========== 600 ===== ===== test accuracy is  0.952 ==========
========== 620 ===== ===== test accuracy is  0.95 ==========
========== 640 ===== ===== test accuracy is  0.964 ==========
========== 660 ===== ===== test accuracy is  0.962 ==========
========== 680 ===== ===== test accuracy is  0.96 ==========
========== 700 ===== ===== test accuracy is  0.962 ==========
========== 720 ===== ===== test accuracy is  0.964 ==========
========== 740 ===== ===== test accuracy is  0.958 ==========
========== 760 ===== ===== test accuracy is  0.96 ==========
========== 780 ===== ===== test accuracy is  0.972 ==========
========== 800 ===== ===== test accuracy is  0.962 ==========
========== 820 ===== ===== test accuracy is  0.968 ==========
========== 840 ===== ===== test accuracy is  0.964 ==========
========== 860 ===== ===== test accuracy is  0.96 ==========
========== 880 ===== ===== test accuracy is  0.964 ==========
========== 900 ===== ===== test accuracy is  0.96 ==========
========== 920 ===== ===== test accuracy is  0.96 ==========
========== 940 ===== ===== test accuracy is  0.97 ==========
========== 960 ===== ===== test accuracy is  0.956 ==========
========== 980 ===== ===== test accuracy is  0.966 ==========
========== 1000 ===== ===== test accuracy is  0.964 ==========
========== 1020 ===== ===== test accuracy is  0.964 ==========
========== 1040 ===== ===== test accuracy is  0.97 ==========
========== 1060 ===== ===== test accuracy is  0.974 ==========
========== 1080 ===== ===== test accuracy is  0.962 ==========
========== 1100 ===== ===== test accuracy is  0.97 ==========
========== 1120 ===== ===== test accuracy is  0.974 ==========
========== 1140 ===== ===== test accuracy is  0.978 ==========
========== 1160 ===== ===== test accuracy is  0.976 ==========
========== 1180 ===== ===== test accuracy is  0.974 ==========

在这里插入图片描述

相关文章:

  • 使用 Vue 3 和 vue-print-nb 插件实现复杂申请表的打印
  • 大宋咨询(深圳车主满意度调查)如何开展汽车展会观众满意度问卷调查
  • JVM思维导图
  • java配置文件解析yml/xml/properties文件
  • 成绩发布小程序哪个好用?
  • 【Word】调整列表符号与后续文本的间距
  • 【Linux】常见命令:fping的介绍和用法举例
  • 线程思维导图
  • 【JS重点知识02】栈、堆与数据类型 关系
  • 【前端视野下的数据库概念探秘】——信息化人员必备知识面试宝典:解码“视图”与“游标”
  • Ollama+OpenWebUI+Phi3本地大模型入门
  • 安卓手机在开发者模式下 打开wifi调试功能的相关 adb 命令
  • vue canvas绘制信令图二、
  • Crosslink-NX器件应用连载(10): 图像输入并通过HDMI输出
  • 前端面试问题:子组件的某一个方法调用执行逻辑由父组件的属性状态变化来决定
  • 【译】JS基础算法脚本:字符串结尾
  • 4. 路由到控制器 - Laravel从零开始教程
  • angular2 简述
  • co模块的前端实现
  • Docker入门(二) - Dockerfile
  • IDEA常用插件整理
  • iOS 颜色设置看我就够了
  • Java超时控制的实现
  • Java精华积累:初学者都应该搞懂的问题
  • Redis中的lru算法实现
  • SQL 难点解决:记录的引用
  • SQLServer之索引简介
  • 编写符合Python风格的对象
  • 对象管理器(defineProperty)学习笔记
  • 多线程事务回滚
  • 给初学者:JavaScript 中数组操作注意点
  • 好的网址,关于.net 4.0 ,vs 2010
  • 欢迎参加第二届中国游戏开发者大会
  • 前端技术周刊 2018-12-10:前端自动化测试
  • 前端临床手札——文件上传
  • 驱动程序原理
  • 为物联网而生:高性能时间序列数据库HiTSDB商业化首发!
  • 小试R空间处理新库sf
  • 你对linux中grep命令知道多少?
  • 阿里云服务器如何修改远程端口?
  • 好程序员web前端教程分享CSS不同元素margin的计算 ...
  • ​ ​Redis(五)主从复制:主从模式介绍、配置、拓扑(一主一从结构、一主多从结构、树形主从结构)、原理(复制过程、​​​​​​​数据同步psync)、总结
  • ​​​​​​​​​​​​​​Γ函数
  • ​LeetCode解法汇总2670. 找出不同元素数目差数组
  • #!/usr/bin/python与#!/usr/bin/env python的区别
  • #NOIP 2014# day.1 T3 飞扬的小鸟 bird
  • (1) caustics\
  • (10)ATF MMU转换表
  • (173)FPGA约束:单周期时序分析或默认时序分析
  • (2)空速传感器
  • (6)STL算法之转换
  • (c语言+数据结构链表)项目:贪吃蛇
  • (vue)el-cascader级联选择器按勾选的顺序传值,摆脱层级约束
  • (第二周)效能测试
  • (四)搭建容器云管理平台笔记—安装ETCD(不使用证书)