当前位置: 首页 > news >正文

keras卷积处理rgb输入_用Keras搭建卷积神经网络 以及 使用Colaboratory

本文我们将学习使用Keras一步一步搭建一个卷积神经网络。具体来说,我们将使用卷积神经网络对手写数字(MNIST数据集)进行识别,并达到99%以上的正确率。

v2-9a670cc6f4c09c02d7a17a53b1524d33_b.jpg

我们还将介绍Colaboratory——一个免费的Jupyter notebook环境,关键的是可以免费使用GPU(学生党买不起呀)!

v2-b17b374ef863bd771509395d2a7c0ca4_b.jpg

为什么选择Keras呢?

主要是因为简单方便。更多细节请看:https://keras.io/

什么卷积神经网络?

简单地说,卷积神经网络(CNNs)是一种多层神经网络,它可以有效地减少全连接神经网络参数量太大的问题。如果对其背后的原理感兴趣的话,斗胆推荐一些学习资料:

深度学习入门:基于Python的理论与实现

Neural Networks and Deep Learning

CS231n: Convolutional Neural Networks for Visual Recognition

下面就直接进入主题吧!

1. 搭建环境

如果想要在个人电脑上搭建的话,我们需要先安装好Python,进入:https://www.python.org/

v2-ab8273a4af5b90acb5924c71eabac693_b.jpg

下载安装就好。

之后,打开终端输入pip install -i https://pypi.douban.com/simple keras

v2-2a52baf272829fd17e4a862608e2bba9_b.jpg

输入以下命令可以确认正常安装:

python -c "import keras;print(keras.__version__)"

v2-29968e107bd6de088087d8d7af0e7dcc_b.png

当然,如果想直接使用Colaboratory的话,直接打开你的Google云端硬盘

v2-1d56410b037f4c4de9998d6ffd16ef34_b.jpg

为了方便起见,新建一个名为Keras的文件夹,进入Keras文件夹,单击鼠标右键

v2-d82cc8ca4600f0516e730352a2f990d3_b.jpg

选择Colaboratory就可新建一个Jupyter notebok啦!

如果没有看到Colaboratory这一项的话,就选择关联更多应用

v2-56de8543df1b1503a927a53da0231b35_b.jpg

搜索Colaboratory,并关联即可。

2. 导入库和模块

我们导入Sequential模型(相当于放积木的桌子)

from keras.models import Sequential

接下来,我们导入各种层(各种形状各异积木)

from keras.layers import Conv2D, MaxPool2D
from keras.layers import Dense, Flatten

最后,我们导入to_categorical函数,以便之后对数据进行转换

from keras.utils import to_categorical

3. 加载数据

MNIST是一个非常有名的手写数字数据集,我们可以使用Keras轻松加载它。

from keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

查看一下训练集的大小

print(x_train.shape)

# (60000, 28, 28)

可以看到60000个样本,它们都是28像素x28像素的。

看一下这些手写数字长什么样

import matplotlib.pyplot as plt
%matplotlib inline
plt.imshow(x_train[0])

v2-b49bd83165aca1eb4c07ca175b0ba1b8_b.jpg

4. 预处理数据

使用Keras是必须显式声明输入图像深度的尺寸。例如,具有所有3个RGB通道的全色图像的深度为3。

我们的MNIST图像的深度为1,但我们必须明确声明。

也就是说,我们希望将数据集从形状(n,rows,cols)转换为(n,rows,cols,channels)

img_x, img_y = 28, 28

x_train = x_train.reshape(x_train.shape[0], img_x, img_y, 1)
x_test = x_test.reshape(x_test.shape[0], img_x, img_y, 1)

除此之外,我们将数据标准化一下:

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

之后,将标记值(y_train, y_test)转换为One-Hot Encode的形式,至于为什么要这么做?请查看:https://machinelearningmastery.com/why-one-hot-encode-data-in-machine-learning/

y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
print(y_train.shape)
# (60000, 10)

5. 定义模型结构

我们参照下图定义一个模型结构

v2-298160770eedc4bfd48673b1a449fded_b.jpg

代码如下:

model = Sequential()
model.add(Conv2D(32, kernel_size=(5,5), activation='relu', input_shape=(img_x, img_y, 1)))
model.add(MaxPool2D(pool_size=(2,2), strides=(2,2)))
model.add(Conv2D(64, kernel_size=(5,5), activation='relu'))
model.add(MaxPool2D(pool_size=(2,2), strides=(2,2)))
model.add(Flatten())
model.add(Dense(1000, activation='relu'))
model.add(Dense(10, activation='softmax'))

是不是有点搭积木的既视感?

6. 编译

现在,只需要编译模型,就可以开始训练了。当编译模型时,我们声明了损失函数和优化器(SGD,Adam等)。

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

Keras有很多损失函数和优化器供你选择。

7. 训练

接下来,我们传入训练集进行训练

model.fit(x_train, y_train, batch_size=128, epochs=10)

以下是在Colaboratory上训练的过程

v2-e8f4eb87158cb60e83d4545f9da417a8_b.jpg

以下是在个人电脑上训练的过程

v2-558dcc47a9ebb70fd4156948bfd086fa_b.jpg

可以看到,花费的时间差别还是很大的!

8. 评估模型

最后,传入测试集对模型模型进行评估

score = model.evaluate(x_test, y_test)
print('acc', score[1])
# acc 0.9926

准确率达到了%99以上!

完整代码如下:

# 2. 导入库和模块
from keras.models import Sequential
from keras.layers import Conv2D, MaxPool2D
from keras.layers import Dense, Flatten
from keras.utils import to_categorical

# 3. 加载数据
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 4. 数据预处理
img_x, img_y = 28, 28
x_train = x_train.reshape(x_train.shape[0], img_x, img_y, 1)
x_test = x_test.reshape(x_test.shape[0], img_x, img_y, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# 5. 定义模型结构
model = Sequential()
model.add(Conv2D(32, kernel_size=(5,5), activation='relu', input_shape=(img_x, img_y, 1)))
model.add(MaxPool2D(pool_size=(2,2), strides=(2,2)))
model.add(Conv2D(64, kernel_size=(5,5), activation='relu'))
model.add(MaxPool2D(pool_size=(2,2), strides=(2,2)))
model.add(Flatten())
model.add(Dense(1000, activation='relu'))
model.add(Dense(10, activation='softmax'))

# 6. 编译
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# 7. 训练
model.fit(x_train, y_train, batch_size=128, epochs=10)

# 8. 评估模型
score = model.evaluate(x_test, y_test)
print('acc', score[1])

参考

[1] https://elitedatascience.com/keras-tutorial-deep-learning-in-python

[2] http://adventuresinmachinelearning.com/keras-tutorial-cnn-11-lines/

相关文章:

  • android sdk manager安装地址_Creator | 配置Android发布环境
  • PHP的header()
  • android 组件化_Android组件化最佳实践ARetrofit
  • Java Servlet 缺点
  • invoke方法_PHP魔术方法
  • PV、UV、IP的区别
  • mysqld:表mysql.plugin不存在_详解MySQL Binlog解析工具--binlog2sql,基于表级别的数据恢复...
  • css的经典三栏式布局
  • mysql text字段导出_看完能涨工资的MySQL性能优化指南
  • python生成list_python 使用循环生成list
  • 关于vue中的nextTick深入理解
  • hashmap 遍历_你一般是怎么遍历HashMap的?
  • 团队名称:极限定理
  • 人工智能的三层基本架构_“人工智能”如何“深度学习”?
  • SQLite3使用详解
  • 【comparator, comparable】小总结
  • 【前端学习】-粗谈选择器
  • 2017-09-12 前端日报
  • 230. Kth Smallest Element in a BST
  • 4. 路由到控制器 - Laravel从零开始教程
  • ES2017异步函数现已正式可用
  • Gradle 5.0 正式版发布
  • Java比较器对数组,集合排序
  • jQuery(一)
  • js
  • Laravel5.4 Queues队列学习
  • ng6--错误信息小结(持续更新)
  • node-sass 安装卡在 node scripts/install.js 解决办法
  • Python - 闭包Closure
  • Sequelize 中文文档 v4 - Getting started - 入门
  • spring boot 整合mybatis 无法输出sql的问题
  • Theano - 导数
  • 编写高质量JavaScript代码之并发
  • 理解 C# 泛型接口中的协变与逆变(抗变)
  • 力扣(LeetCode)22
  • 聊聊flink的BlobWriter
  • 小程序开发之路(一)
  • 携程小程序初体验
  • 赢得Docker挑战最佳实践
  • Hibernate主键生成策略及选择
  • ​软考-高级-信息系统项目管理师教程 第四版【第14章-项目沟通管理-思维导图】​
  • (11)工业界推荐系统-小红书推荐场景及内部实践【粗排三塔模型】
  • (HAL)STM32F103C6T8——软件模拟I2C驱动0.96寸OLED屏幕
  • (Mac上)使用Python进行matplotlib 画图时,中文显示不出来
  • (附源码)springboot宠物医疗服务网站 毕业设计688413
  • (六) ES6 新特性 —— 迭代器(iterator)
  • (转)fock函数详解
  • (转)VC++中ondraw在什么时候调用的
  • (转)程序员技术练级攻略
  • (转)关于多人操作数据的处理策略
  • ***检测工具之RKHunter AIDE
  • **PHP二维数组遍历时同时赋值
  • .naturalWidth 和naturalHeight属性,
  • .NET Core 网络数据采集 -- 使用AngleSharp做html解析
  • .NET MVC第三章、三种传值方式