当前位置: 首页 > news >正文

最基础的tensorflow代码学习方法———多敲多理解

我最近在学习tensorflow实战深度学习框架,发现书上的部分代码看起来很简单,理论部分也比较基础。可是当我自己想要敲点代码的时候,发现怎么也敲不出来,于是我反思,到底哪里出了问题。原来我为了求快,迅速的把tensorflow的教程和深度学习的书本看了一遍,只是基本的知道了里面的理论和原理,具体的代码细节并不理解。这犯了初学者的大忌:只看不动手实践。
程序员的代码学习方法——多敲多理解,每学一件新的东西(算法、程序),必须要把它敲出来,只有敲出来,理解了才是你的东西。不然哪怕你copy过来实现了基本功能,你也不理解细节。
像以下最为基础的猫狗分类代码,如果不自己亲自写一遍,对于其中的细节也只能停留在表面的理解上,永远得不到具体的理解。

# -*- coding: utf-8 -*-
"""
Created on Fri Mar 19 19:27:26 2021
cat_VS_dag  猫狗分类识别
@author: chenfeng
"""

import tensorflow as tf	
from tensorflow import keras
import numpy as np
import os
import random

num_epochs = 100					#遍历数据集的次数
num_class = 2						#类别数
num_train = 23000					#训练样本总数
num_test = 2000						#测试样本总数
input_rows,input_cols = 100,100		#输入图片尺寸
batch_size = 32						#批次大小
learning_rate = 0.01				#初始学习率
loss_value = []						#记录损失值

num_parallel_calls = 4
#================================================================#
#=注意!!!num_parallel_calls是数据增强使使用的CPU核心数,根据实际调整=#
#================================================================#

DATA_DIR = 'E:/UserData/Desktop/catVSdog_dataSet/'		#数据根目录
LABLE = {0:'cats',1:'dogs'}				#分类标签
#'F:/data/cats-and-dogs/train/cats'		#猫的训练样本路径,其余类似


#============================  model  ============================#
#=================================================================#

def CNN():
    model = keras.Sequential()

    #block1
    
    model.add(keras.layers.Conv2D(32,[3,3],padding = 'same',kernel_initializer='he_normal',activation = tf.nn.elu,input_shape = (input_rows,input_cols,3)))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.Conv2D(32,[3,3],padding = 'same',kernel_initializer='he_normal',activation = tf.nn.elu))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.MaxPool2D(pool_size = [2,2]))
    model.add(keras.layers.Dropout(0.2))
    
    #block2
    
    model.add(keras.layers.Conv2D(64,[3,3],padding = 'same',kernel_initializer='he_normal',activation = tf.nn.elu))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.Conv2D(64,[3,3],padding = 'same',kernel_initializer='he_normal',activation = tf.nn.elu))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.MaxPool2D(pool_size = [2,2]))
    model.add(keras.layers.Dropout(0.2))
    
    #block3
    
    model.add(keras.layers.Conv2D(128,[3,3],padding = 'same',kernel_initializer='he_normal',activation = tf.nn.elu))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.Conv2D(128,[3,3],padding = 'same',kernel_initializer='he_normal',activation = tf.nn.elu))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.MaxPool2D())
    model.add(keras.layers.Dropout(0.2))
    
    #block4
    
    model.add(keras.layers.Conv2D(256,[3,3],padding = 'same',kernel_initializer='he_normal',activation = tf.nn.elu))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.Conv2D(256,[3,3],padding = 'same',kernel_initializer='he_normal',activation = tf.nn.elu))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.MaxPool2D())
    model.add(keras.layers.Dropout(0.5))
    
    #block5
    
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(64,activation = tf.nn.elu,kernel_initializer='he_normal'))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.Dropout(0.5))
    
    #block6
    
    model.add(keras.layers.Dense(32,activation = tf.nn.elu,kernel_initializer='he_normal'))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.Dropout(0.5))

    #block7
    
    model.add(keras.layers.Dense(num_class,activation = tf.nn.softmax,kernel_initializer='he_normal'))
    print(model.summary(),'\n')
    return model

#=============================  data  ============================#
#=================================================================#


 

def data_set(DATA_DIR,option):			#根据option是train还是test加载对应文件夹的数据
    if option == 'train':
        k = num_train
    else:
        k = num_test
    Filenames = []
    Lable = np.zeros(shape = (k,1))
    i = 0
    for lable in LABLE:					#lable是0即进入cats文件夹,是1则进入dogs文件夹
        for filename in os.listdir(DATA_DIR + option + '/' +  LABLE[lable]):
            Filenames.append(DATA_DIR + option + '/' + LABLE[lable] + '/' + filename)
            Lable[i] = (lable)
            i += 1
    return Filenames,Lable 				#返回图片地址list和对应标签np数组

def map_fun(filename,lable):			#数据增强函数,包含旋转,镜像,翻转
    image_string = tf.io.read_file(filename)
    image = tf.image.decode_jpeg(image_string,channels =3)
    									#必须要进行上面两步,将图像地址转换为图像数据
    rand = random.random()
    if rand > 0.5:
        image = tf.image.rot90(image)
        
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    
    image = image / 255					#数据处理为[0,1]区间的值,据说有利于计算和收敛
    
    return image,lable 

#下面是main中的内容,给搬过来了
#========================  funs and main  ========================#
#=================================================================#
TRAN_Filenames,TRAN_Lable = data_set(DATA_DIR,'train')
tran_dataset = tf.data.Dataset.from_tensor_slices((TRAN_Filenames,TRAN_Lable))
										#上面这步需要注意注意再注意!!!一定是(TRAN_Filenames,TRAN_Lable),带上括号作为参数,不然,嘿嘿,报错都找不见在哪
tran_dataset = tran_dataset.map(map_func = map_fun,num_parallel_calls = num_parallel_calls)
tran_dataset = tran_dataset.shuffle(23000)
tran_dataset = tran_dataset.batch(batch_size)
tran_dataset = tran_dataset.prefetch(tf.data.experimental.AUTOTUNE)

TEST_Filenames,TEST_Lable = data_set(DATA_DIR,'test')
test_dataset = tf.data.Dataset.from_tensor_slices((TEST_Filenames,TEST_Lable))
test_dataset = test_dataset.map(map_func = map_fun,num_parallel_calls=num_parallel_calls)
test_dataset = test_dataset.batch(batch_size)


#========================  funs and main  ========================#
#=================================================================#


class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self,logs = {}):
        self.losses = []
        
    def on_batch_end(self,batch,logs = {}):
        self.losses.append(logs.get('loss'))



if __name__ == '__main__':
    TRAN_Filenames,TRAN_Lable = data_set(DATA_DIR,'train')
    tran_dataset = tf.data.Dataset.from_tensor_slices((TRAN_Filenames,TRAN_Lable))
    tran_dataset = tran_dataset.map(map_func = map_fun,num_parallel_calls=num_parallel_calls)
    tran_dataset = tran_dataset.shuffle(23000)
    tran_dataset = tran_dataset.batch(batch_size)
    tran_dataset = tran_dataset.prefetch(tf.data.experimental.AUTOTUNE)

    TEST_Filenames,TEST_Lable = data_set(DATA_DIR,'test')
    test_dataset = tf.data.Dataset.from_tensor_slices((TEST_Filenames,TEST_Lable))
    test_dataset = test_dataset.map(map_func = map_fun,num_parallel_calls=num_parallel_calls)
    test_dataset = test_dataset.batch(batch_size)
#================上面是data相关处理================#
	
    model = CNN()	#赋予模型
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        					loss=tf.keras.losses.sparse_categorical_crossentropy,
        					metrics=[tf.keras.metrics.sparse_categorical_accuracy])
    				#使用compile函数定义优化器optimizer,详见资料1
    				
    				#以下四个参数与callbacks的使用有关
    reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor = 'val_loss',
                                                  factor = 0.2,
                                                  patience = 4, 
                                                  verbose = 1,
                                                  min_delta = 0.00001)
                    #动态调整学习率
    earlystop = keras.callbacks.EarlyStopping(monitor = 'val_loss',
                                              min_delta = 0,
                                              patience = 10,
                                              verbose = 1,
                                              restore_best_weights = True)
                    #当模型不在进步时提前结束训练                          
    history = LossHistory()
    				#通过LossHistory将训练loss值保存history中
    checkpoint = keras.callbacks.ModelCheckpoint('best_model.h5',
                                                 monitor = 'val_loss',
                                                 mode = 'min',
                                                 save_best_only = True,
                                                 verbose = 1)
                   	#查看当前模型是否最优,只保存训练中最优的模型
    callbacks = [history,reduce_lr,earlystop,checkpoint]
    				#将需要参数打包成list
    model.fit(tran_dataset,batch_size = batch_size,epochs = num_epochs, callbacks = callbacks,validation_data = test_dataset)
    				#训练模型,设置callbacks,validation_data = test_dataset可以在一个epoch后用测试集进行测试,
    				#上面reduce_lr等中的'val_loss'就是测试集的loss,以此作为保存最优模型,和调整学习率的指标
    #tf.saved_model.save(model, "F:/data/cats-and-dogs")



相关文章:

  • win10下cuda版本升级
  • 如何在tensorflow2.0使用已经移除的tensorflow1.X的模块
  • tensorflow.python之Failed to get convolution algorithm定位错误
  • 毕业后,重装电脑系统,我的资料备份
  • 轻松解决问题Could not load dynamic library cudart64_101.dll
  • c++的多态和虚函数的复习
  • MFC获取主窗口的句柄
  • MFC复制项目界面,代码合并问题(更加全面针对各种小问题)
  • 职业规划心得
  • 解放自我----真实维度的世界
  • 知识连接比知识本身更加重要
  • 人是高维度自定义的机器,自己给自己编辑程序
  • 如何彻底解决浏览器导航被劫持为www.hao123.com
  • 刷题笔记155:最小栈
  • Anaconda之导出/导出配置好的虚拟环境
  • 【附node操作实例】redis简明入门系列—字符串类型
  • JavaScript标准库系列——Math对象和Date对象(二)
  • Js基础知识(一) - 变量
  • markdown编辑器简评
  • Mybatis初体验
  • Nacos系列:Nacos的Java SDK使用
  • Redux 中间件分析
  • Vultr 教程目录
  • Webpack 4x 之路 ( 四 )
  • WinRAR存在严重的安全漏洞影响5亿用户
  • 持续集成与持续部署宝典Part 2:创建持续集成流水线
  • 从tcpdump抓包看TCP/IP协议
  • 关键词挖掘技术哪家强(一)基于node.js技术开发一个关键字查询工具
  • 官方新出的 Kotlin 扩展库 KTX,到底帮你干了什么?
  • 深入体验bash on windows,在windows上搭建原生的linux开发环境,酷!
  • 一个普通的 5 年iOS开发者的自我总结,以及5年开发经历和感想!
  • 掌握面试——弹出框的实现(一道题中包含布局/js设计模式)
  • 最近的计划
  • python最赚钱的4个方向,你最心动的是哪个?
  • #define,static,const,三种常量的区别
  • #我与Java虚拟机的故事#连载02:“小蓝”陪伴的日日夜夜
  • $ git push -u origin master 推送到远程库出错
  • (9)目标检测_SSD的原理
  • (超详细)语音信号处理之特征提取
  • (非本人原创)史记·柴静列传(r4笔记第65天)
  • (附源码)SSM环卫人员管理平台 计算机毕设36412
  • (六)库存超卖案例实战——使用mysql分布式锁解决“超卖”问题
  • (求助)用傲游上csdn博客时标签栏和网址栏一直显示袁萌 的头像
  • (算法二)滑动窗口
  • (万字长文)Spring的核心知识尽揽其中
  • (学习日记)2024.04.04:UCOSIII第三十二节:计数信号量实验
  • (一) springboot详细介绍
  • (转)winform之ListView
  • .bat批处理(十):从路径字符串中截取盘符、文件名、后缀名等信息
  • .NET 8 编写 LiteDB vs SQLite 数据库 CRUD 接口性能测试(准备篇)
  • .NET开发不可不知、不可不用的辅助类(三)(报表导出---终结版)
  • .NET开源快速、强大、免费的电子表格组件
  • .net中的Queue和Stack
  • ?
  • [ 数据结构 - C++] AVL树原理及实现