当前位置: 首页 > news >正文

TensorFlow:实战Google深度学习框架 # LeNet-5模型

与前文link~对比


mnist_inference.py

import tensorflow as tf

INPUT_NODE = 784
OUTPUT_NODE = 10

IMAGE_SIZE = 28
NUM_CHANNELS = 1
NUM_LABELS = 10

CONV1_DEEP = 32
CONV1_SIZE = 5

CONV2_DEEP = 64
CONV2_SIZE = 5

FC_SIZE = 512

def inference(input_tensor, train, regularizer):
    with tf.variable_scope('layer1-conv1'):
        conv1_weights = tf.get_variable(
            "weight", [CONV1_SIZE, CONV1_SIZE, NUM_CHANNELS, CONV1_DEEP],
            initializer=tf.truncated_normal_initializer(stddev=0.1))
        conv1_biases = tf.get_variable("bias", [CONV1_DEEP], initializer=tf.constant_initializer(0.0))
        conv1 = tf.nn.conv2d(input_tensor, conv1_weights, strides=[1, 1, 1, 1], padding='SAME')
        relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))

    with tf.name_scope("layer2-pool1"):
        pool1 = tf.nn.max_pool(relu1, ksize = [1,2,2,1],strides=[1,2,2,1],padding="SAME")

    with tf.variable_scope("layer3-conv2"):
        conv2_weights = tf.get_variable(
            "weight", [CONV2_SIZE, CONV2_SIZE, CONV1_DEEP, CONV2_DEEP],
            initializer=tf.truncated_normal_initializer(stddev=0.1))
        conv2_biases = tf.get_variable("bias", [CONV2_DEEP], initializer=tf.constant_initializer(0.0))
        conv2 = tf.nn.conv2d(pool1, conv2_weights, strides=[1, 1, 1, 1], padding='SAME')
        relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))

    with tf.name_scope("layer4-pool2"):
        pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        pool_shape = pool2.get_shape().as_list()
        nodes = pool_shape[1] * pool_shape[2] * pool_shape[3]
        reshaped = tf.reshape(pool2, [pool_shape[0], nodes])

    with tf.variable_scope('layer5-fc1'):
        fc1_weights = tf.get_variable("weight", [nodes, FC_SIZE],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
        if regularizer != None: tf.add_to_collection('losses', regularizer(fc1_weights))
        fc1_biases = tf.get_variable("bias", [FC_SIZE], initializer=tf.constant_initializer(0.1))

        fc1 = tf.nn.relu(tf.matmul(reshaped, fc1_weights) + fc1_biases)
        if train: fc1 = tf.nn.dropout(fc1, 0.5)

    with tf.variable_scope('layer6-fc2'):
        fc2_weights = tf.get_variable("weight", [FC_SIZE, NUM_LABELS],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
        if regularizer != None: tf.add_to_collection('losses', regularizer(fc2_weights))
        fc2_biases = tf.get_variable("bias", [NUM_LABELS], initializer=tf.constant_initializer(0.1))
        logit = tf.matmul(fc1, fc2_weights) + fc2_biases

    return logit

LeNet5_train.py

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import LeNet5_infernece
import os
import numpy as np

BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 30000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH="MNIST_model/"
MODEL_NAME="mnist_model"


def train(mnist):

    # x = tf.placeholder(tf.float32, [None, LeNet5_infernece.INPUT_NODE], name='x-input')
    x = tf.placeholder(tf.float32,
        [
            BATCH_SIZE,
            LeNet5_infernece.IMAGE_SIZE,
            LeNet5_infernece.IMAGE_SIZE,
            LeNet5_infernece.NUM_CHANNELS
        ],
        name="x-input"
    )
    y_ = tf.placeholder(tf.float32, [None, LeNet5_infernece.OUTPUT_NODE], name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    y = LeNet5_infernece.inference(x,True,regularizer)
    global_step = tf.Variable(0, trainable=False)


    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,
        staircase=True)

    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')


    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            reshaped_xs = np.reshape(xs,(
                BATCH_SIZE,
                LeNet5_infernece.IMAGE_SIZE,
                LeNet5_infernece.IMAGE_SIZE,
                LeNet5_infernece.NUM_CHANNELS
            ))
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: reshaped_xs, y_: ys})
            if i % 1000 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)


def main(argv=None):
    mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
    train(mnist)

if __name__ == '__main__':
    tf.app.run()

LeNet5_eval.py

import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import LeNet5_infernece
import LeNet5_train
import numpy as np

# 加载的时间间隔。
EVAL_INTERVAL_SECS = 10

def evaluate(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32,
            [
                mnist.test.num_examples,
                LeNet5_infernece.IMAGE_SIZE,
                LeNet5_infernece.IMAGE_SIZE,
                LeNet5_infernece.NUM_CHANNELS
            ],
            name="x-input"
        )
        y_ = tf.placeholder(tf.float32, [None, LeNet5_infernece.OUTPUT_NODE], name='y-input')
        y = LeNet5_infernece.inference(x,None,None)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        variable_averages = tf.train.ExponentialMovingAverage(LeNet5_train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(LeNet5_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    reshaped_x = np.reshape(mnist.test.images,(
                            mnist.test.num_examples,
                            LeNet5_infernece.IMAGE_SIZE,
                            LeNet5_infernece.IMAGE_SIZE,
                            LeNet5_infernece.NUM_CHANNELS
                        )
                    )
                    accuracy_score = sess.run(accuracy, feed_dict={x:reshaped_x,y_:mnist.test.labels})
                    print("After %s training step(s), validation accuracy = %g" % (global_step, accuracy_score))
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(EVAL_INTERVAL_SECS)

def main(argv=None):
    mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
    evaluate(mnist)

if __name__ == '__main__':
    main()


相关文章:

  • Layui中引入$符
  • vue 监听对象时失效 / 监听对象属性改变
  • 修改elementUI组件的样式
  • vue 如何修改数组中对象的属性?
  • 【转】大型Vuex项目 ,使用module后, 如何调用其他模块的 属性值和方法
  • 跨域问题解决方法: chrome浏览器关闭CORS策略
  • 前后端分离 webpack配置api代理
  • vue基础 笔记
  • 跨域情况下前端获取response header的问题
  • Python 生成二维码
  • Python 图片转字符画
  • CSS 一个炫酷的盒子效果
  • CSS 清除浮动
  • html+css课程综合案例
  • js 两个网页之间通过URL传输数据
  • [PHP内核探索]PHP中的哈希表
  • hexo+github搭建个人博客
  • Android交互
  • Android框架之Volley
  • Django 博客开发教程 8 - 博客文章详情页
  • Intervention/image 图片处理扩展包的安装和使用
  • JavaScript设计模式与开发实践系列之策略模式
  • JavaScript学习总结——原型
  • Java新版本的开发已正式进入轨道,版本号18.3
  • react 代码优化(一) ——事件处理
  • React-flux杂记
  • 从@property说起(二)当我们写下@property (nonatomic, weak) id obj时,我们究竟写了什么...
  • 第2章 网络文档
  • 服务器之间,相同帐号,实现免密钥登录
  • 前端代码风格自动化系列(二)之Commitlint
  • 前端性能优化——回流与重绘
  • 如何胜任知名企业的商业数据分析师?
  • 深度学习入门:10门免费线上课程推荐
  • 网络应用优化——时延与带宽
  • 我有几个粽子,和一个故事
  • 在 Chrome DevTools 中调试 JavaScript 入门
  • 基于django的视频点播网站开发-step3-注册登录功能 ...
  • #Linux(Source Insight安装及工程建立)
  • #微信小程序(布局、渲染层基础知识)
  • #我与Java虚拟机的故事#连载16:打开Java世界大门的钥匙
  • $var=htmlencode(“‘);alert(‘2“); 的个人理解
  • (DenseNet)Densely Connected Convolutional Networks--Gao Huang
  • (附源码)spring boot校园健康监测管理系统 毕业设计 151047
  • (附源码)springboot家庭财务分析系统 毕业设计641323
  • (每日持续更新)信息系统项目管理(第四版)(高级项目管理)考试重点整理第3章 信息系统治理(一)
  • (篇九)MySQL常用内置函数
  • (亲测成功)在centos7.5上安装kvm,通过VNC远程连接并创建多台ubuntu虚拟机(ubuntu server版本)...
  • (三) prometheus + grafana + alertmanager 配置Redis监控
  • (三)centos7案例实战—vmware虚拟机硬盘挂载与卸载
  • (五)IO流之ByteArrayInput/OutputStream
  • (转)编辑寄语:因为爱心,所以美丽
  • (转)用.Net的File控件上传文件的解决方案
  • (轉)JSON.stringify 语法实例讲解
  • * 论文笔记 【Wide Deep Learning for Recommender Systems】
  • *ST京蓝入股力合节能 着力绿色智慧城市服务