当前位置: 首页 > news >正文

基于keras与tensorflow手工实现ResNet50网络

前言

在文章 基于tensorflow的ResNet50V2网络识别动物,我们使用了keras已经提供的神经网络,完成了图像分类的。这个时候,小明同学就问了,那么我怎么自己去写一个神经网络来进行训练呢?
本文就基于tensorflow,自己定一个神经网络。

ResNet50网络

在这里插入图片描述
从结构上看,与我们之前的的区别在于,输入的格式变成(3,224,224)
ResNet50有两个基本的块,分别名为Conv BlockIdentity Block

整体架构

在这里插入图片描述

Conv Block架构

在这里插入图片描述

Identity Block架构

在这里插入图片描述

模型训练

手工实现模型

模型代码(resnet50.py)

# 根据模型进行引入
from keras import layers

from keras.layers import Input,Activation,BatchNormalization,Flatten
from keras.layers import Dense,Conv2D,MaxPooling2D,ZeroPadding2D,AveragePooling2D
from keras.models import Model

def identity_block(input_tensor, kernel_size, filters, stage, block):
    filters1, filters2, filters3 = filters

    name_base = str(stage) + block + '_identity_block_'

    x = Conv2D(filters1, (1, 1), name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    x = layers.add([x, input_tensor], name=name_base + 'add')
    x = Activation('relu', name=name_base + 'relu4')(x)
    return x

def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):

    filters1, filters2, filters3 = filters

    res_name_base = str(stage) + block + '_conv_block_res_'
    name_base = str(stage) + block + '_conv_block_'

    x = Conv2D(filters1, (1, 1), strides=strides, name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    shortcut = Conv2D(filters3, (1, 1), strides=strides, name=res_name_base + 'conv')(input_tensor)
    shortcut = BatchNormalization(name=res_name_base + 'bn')(shortcut)

    x = layers.add([x, shortcut], name=name_base+'add')
    x = Activation('relu', name=name_base+'relu4')(x)
    return x

def ResNet50(input_shape=[224,224,3],classes=1000):
    img_input = Input(shape=input_shape)
    x = ZeroPadding2D((3, 3))(img_input)

    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    x = AveragePooling2D((7, 7), name='avg_pool')(x)

    x = Flatten()(x)
    x = Dense(classes, activation='softmax', name='fc1000')(x)

    model = Model(img_input, x, name='resnet50')
    return model

模型训练(resnet50_model_train.py)

将 基于tensorflow的ResNet50V2网络识别动物的模型训练代码进行少量改造

import os
import pandas as pd

# Model
import keras
from keras.preprocessing.image import ImageDataGenerator

# Callbacks
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Pre-Trained Model
import tensorflow as tf
import resnet50

root_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/'
valid_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/'
test_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Testing Data/'
class_names = sorted(os.listdir(root_path))
n_classes = len(class_names)

print(f"Total Number of Classes : {n_classes} \nClass Names : {class_names}")


class_dis = [len(os.listdir(root_path+name)) for name in class_names]


train_gen = ImageDataGenerator(rescale=1/255., rotation_range=10, horizontal_flip=True)
valid_gen = ImageDataGenerator(rescale=1/255.)
test_gen = ImageDataGenerator(rescale=1/255)

# Load Data
train_ds = train_gen.flow_from_directory(root_path, class_mode='binary', target_size=(224,224), shuffle=True, batch_size=32)
valid_ds = valid_gen.flow_from_directory(valid_path, class_mode='binary', target_size=(224,224), shuffle=True, batch_size=32)
test_ds = test_gen.flow_from_directory(test_path, class_mode='binary', target_size=(224,224), shuffle=True, batch_size=32)



with tf.device("/GPU:0"):
    ## Pre-Trained Model
    model = resnet50.ResNet50()
    model.summary()

    model_file = "ResNet50_V1.h5"
    # 加载预训练模型
    if os.path.exists(model_file):
        model.load_weights(model_file)

    # Callbacks
    cbs = [EarlyStopping(patience=5, restore_best_weights=True), ModelCheckpoint(model_file, save_best_only=True)]

    # Model
    opt = tf.keras.optimizers.Adam(learning_rate=2e-3)
    model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

    # Model Training
    history = model.fit(train_ds, validation_data=valid_ds, callbacks=cbs, epochs=200)

    data = pd.DataFrame(history.history)
    print(data)

模型执行

GPU基本跑满
在这里插入图片描述

Model: "resnet50"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 zero_padding2d (ZeroPadding2D)  (None, 230, 230, 3)  0          ['input_1[0][0]']                
                                                                                                  
 conv1 (Conv2D)                 (None, 112, 112, 64  9472        ['zero_padding2d[0][0]']         
                                )                                                                 
                                                                                                  
 bn_conv1 (BatchNormalization)  (None, 112, 112, 64  256         ['conv1[0][0]']                  
                                )                                                                 
                                                                                                  
 activation (Activation)        (None, 112, 112, 64  0           ['bn_conv1[0][0]']               
                                )                                                                 
                                                                                                  
 max_pooling2d (MaxPooling2D)   (None, 55, 55, 64)   0           ['activation[0][0]']             
                                                                                                  
 2a_conv_block_conv1 (Conv2D)   (None, 55, 55, 64)   4160        ['max_pooling2d[0][0]']          
                                                                                                  
 2a_conv_block_bn1 (BatchNormal  (None, 55, 55, 64)  256         ['2a_conv_block_conv1[0][0]']    
 ization)                                                                                         
                                                                                                  
 2a_conv_block_relu1 (Activatio  (None, 55, 55, 64)  0           ['2a_conv_block_bn1[0][0]']      
 n)                                                                                               
                                                                                                  
 2a_conv_block_conv2 (Conv2D)   (None, 55, 55, 64)   36928       ['2a_conv_block_relu1[0][0]']    
                                                                                                  
 2a_conv_block_bn2 (BatchNormal  (None, 55, 55, 64)  256         ['2a_conv_block_conv2[0][0]']    
 ization)                                                                                         
                                                                                                  
 2a_conv_block_relu2 (Activatio  (None, 55, 55, 64)  0           ['2a_conv_block_bn2[0][0]']      
 n)                                                                                               
                                                                                                  
 2a_conv_block_conv3 (Conv2D)   (None, 55, 55, 256)  16640       ['2a_conv_block_relu2[0][0]']    
                                                                                                  
 2a_conv_block_res_conv (Conv2D  (None, 55, 55, 256)  16640      ['max_pooling2d[0][0]']          
 )                                                                                                
                                                                                                  
 2a_conv_block_bn3 (BatchNormal  (None, 55, 55, 256)  1024       ['2a_conv_block_conv3[0][0]']    
 ization)                                                                                         
                                                                                                  
 2a_conv_block_res_bn (BatchNor  (None, 55, 55, 256)  1024       ['2a_conv_block_res_conv[0][0]'] 
 malization)                                                                                      
                                                                                                  
 2a_conv_block_add (Add)        (None, 55, 55, 256)  0           ['2a_conv_block_bn3[0][0]',      
                                                                  '2a_conv_block_res_bn[0][0]']   
                                                                                                  
 2a_conv_block_relu4 (Activatio  (None, 55, 55, 256)  0          ['2a_conv_block_add[0][0]']      
 n)                                                                                               
                                                                                                  
 2b_identity_block_conv1 (Conv2  (None, 55, 55, 64)  16448       ['2a_conv_block_relu4[0][0]']    
 D)                                                                                               
                                                                                                  
 2b_identity_block_bn1 (BatchNo  (None, 55, 55, 64)  256         ['2b_identity_block_conv1[0][0]']
 rmalization)                                                                                     
                                                                                                  
 2b_identity_block_relu1 (Activ  (None, 55, 55, 64)  0           ['2b_identity_block_bn1[0][0]']  
 ation)                                                                                           
                                                                                                  
 2b_identity_block_conv2 (Conv2  (None, 55, 55, 64)  36928       ['2b_identity_block_relu1[0][0]']
 D)                                                                                               
           
*******略,大致意思是有50层,因为是ResNet50*********                                                                                         
                                                                                                  
 5c_identity_block_bn3 (BatchNo  (None, 7, 7, 2048)  8192        ['5c_identity_block_conv3[0][0]']
 rmalization)                                                                                     
                                                                                                  
 5c_identity_block_add (Add)    (None, 7, 7, 2048)   0           ['5c_identity_block_bn3[0][0]',  
                                                                  '5b_identity_block_relu4[0][0]']
                                                                                                  
 5c_identity_block_relu4 (Activ  (None, 7, 7, 2048)  0           ['5c_identity_block_add[0][0]']  
 ation)                                                                                           
                                                                                                  
 avg_pool (AveragePooling2D)    (None, 1, 1, 2048)   0           ['5c_identity_block_relu4[0][0]']
                                                                                                  
 flatten (Flatten)              (None, 2048)         0           ['avg_pool[0][0]']               
                                                                                                  
 fc1000 (Dense)                 (None, 1000)         2049000     ['flatten[0][0]']                
                                                                                                  
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________

在这里插入图片描述

训练结果

Epoch 1/200
625/625 [==============================] - 245s 376ms/step - loss: 0.3225 - accuracy: 0.8965 - val_loss: 0.7825 - val_accuracy: 0.7670
Epoch 2/200
625/625 [==============================] - 226s 361ms/step - loss: 0.2998 - accuracy: 0.9018 - val_loss: 0.9110 - val_accuracy: 0.7060
Epoch 3/200
625/625 [==============================] - 223s 357ms/step - loss: 0.2644 - accuracy: 0.9113 - val_loss: 1.2283 - val_accuracy: 0.6760
Epoch 4/200
625/625 [==============================] - 223s 357ms/step - loss: 0.2465 - accuracy: 0.9188 - val_loss: 0.9871 - val_accuracy: 0.7500
Epoch 5/200
625/625 [==============================] - 225s 360ms/step - loss: 0.2307 - accuracy: 0.9234 - val_loss: 1.1059 - val_accuracy: 0.6720
Epoch 6/200
625/625 [==============================] - 228s 365ms/step - loss: 0.2016 - accuracy: 0.9341 - val_loss: 0.5819 - val_accuracy: 0.8370
Epoch 7/200
625/625 [==============================] - 227s 363ms/step - loss: 0.1859 - accuracy: 0.9380 - val_loss: 0.8662 - val_accuracy: 0.7740
Epoch 8/200
625/625 [==============================] - 223s 356ms/step - loss: 0.1732 - accuracy: 0.9419 - val_loss: 0.6927 - val_accuracy: 0.8130
Epoch 9/200
625/625 [==============================] - 221s 353ms/step - loss: 0.1631 - accuracy: 0.9446 - val_loss: 0.7033 - val_accuracy: 0.8090
Epoch 10/200
625/625 [==============================] - 220s 352ms/step - loss: 0.1416 - accuracy: 0.9528 - val_loss: 0.8072 - val_accuracy: 0.8130
Epoch 11/200
625/625 [==============================] - 221s 352ms/step - loss: 0.1403 - accuracy: 0.9524 - val_loss: 0.9740 - val_accuracy: 0.7570
        loss  accuracy  val_loss  val_accuracy
0   0.322520   0.89655  0.782460         0.767
1   0.299801   0.90180  0.910993         0.706
2   0.264406   0.91130  1.228291         0.676
3   0.246451   0.91880  0.987062         0.750
4   0.230691   0.92340  1.105917         0.672
5   0.201550   0.93405  0.581927         0.837
6   0.185873   0.93800  0.866195         0.774
7   0.173195   0.94185  0.692714         0.813
8   0.163083   0.94460  0.703307         0.809
9   0.141572   0.95280  0.807219         0.813
10  0.140329   0.95240  0.974021         0.757

训练结果

从输出的训练结果来看,效果没有ResNet50V2的,对参数进行了一些调整,没有太多的效果。各位如果有兴趣可以对这样的网络进行修改,从而提升验证的正确性。

模型验证

模型验证代码

from keras.models import load_model
import tensorflow as tf
from tensorflow.keras.utils import load_img, img_to_array
import numpy as np
import os

import matplotlib.pyplot as plt

root_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/'

class_names = sorted(os.listdir(root_path))

model = load_model('./ResNet50_V1.h5')
model.summary()

def load_image(path):
    '''This function will load the image present at the given location'''
    image = tf.cast(tf.image.resize(img_to_array(load_img(path))/255., (224,224)), tf.float32)
    #image = tf.cast(tf.image.resize(img_to_array(load_img(path)) / 255., (224, 224)), tf.float32)
    return image

i_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/Gorilla/Gorilla (3).jpeg'
image = load_image(i_path)
preds = model.predict(image[np.newaxis, ...])[0]

print(preds)

pred_class = class_names[np.argmax(preds)]

confidence_score = np.round(preds[np.argmax(preds)], 2)

# Configure Title
title = f"Pred : {pred_class}\nConfidence : {confidence_score:.2}"
print(title)

plt.figure(figsize=(25, 8))
plt.title(title)
plt.imshow(image)
plt.show()

while True:
    path =  input("input:")
    if (path == "q!"):
        exit()
    image = load_image(path)

    preds = model.predict(image[np.newaxis, ...])[0]
    print(preds)

    pred_class = class_names[np.argmax(preds)]

    confidence_score = np.round(preds[np.argmax(preds)], 2)

    # Configure Title
    title = f"Pred : {pred_class}\nConfidence : {confidence_score:.2}"
    print(title)

    plt.figure(figsize=(25, 8))
    plt.title(title)
    plt.imshow(image)
    plt.show()

相关文章:

  • c语言中常用的字符串处理函数总结
  • ESP-01S使用AT指令连接阿里云
  • 第十四届蓝桥杯模拟赛第二期部分题答案(C++代码)
  • 面试半年,上个月成功拿到阿里offer,全靠我啃烂了学长给的这份笔记
  • 【RTS】安海波老师:SIP与RTC融合分享笔记
  • 网站都变成灰色了,它是怎么实现的?
  • JavaWeb中文件上传与下载
  • 信奥赛一本通题解目录(未做完)
  • YOLO系列算法改进方法 | 目录一览表
  • 粒子群算法和鲸鱼算法的比较(Matlab代码实现)
  • HTML5期末大作业:HTM+CSS+JS仿安徽开放大学官网(web前端网页制作课作业)
  • C语言:动态内存分配(3)
  • 基于纳芯微产品的尾灯方案介绍
  • 设置程序以管理员权限运行无效问题的排查过程分享
  • MySQL密码不要用0开头!!!
  • 【399天】跃迁之路——程序员高效学习方法论探索系列(实验阶段156-2018.03.11)...
  • Electron入门介绍
  • IIS 10 PHP CGI 设置 PHP_INI_SCAN_DIR
  • Java比较器对数组,集合排序
  • Octave 入门
  • Shadow DOM 内部构造及如何构建独立组件
  • Spring核心 Bean的高级装配
  • vue从创建到完整的饿了么(11)组件的使用(svg图标及watch的简单使用)
  • 阿里云应用高可用服务公测发布
  • 基于Vue2全家桶的移动端AppDEMO实现
  • 如何合理的规划jvm性能调优
  • 实战:基于Spring Boot快速开发RESTful风格API接口
  • 腾讯视频格式如何转换成mp4 将下载的qlv文件转换成mp4的方法
  • [地铁译]使用SSD缓存应用数据——Moneta项目: 低成本优化的下一代EVCache ...
  • Spring Batch JSON 支持
  • 移动端高清、多屏适配方案
  • %3cscript放入php,跟bWAPP学WEB安全(PHP代码)--XSS跨站脚本攻击
  • ( 10 )MySQL中的外键
  • (2)关于RabbitMq 的 Topic Exchange 主题交换机
  • (C#)一个最简单的链表类
  • (Python) SOAP Web Service (HTTP POST)
  • (二)springcloud实战之config配置中心
  • (附源码)springboot 房产中介系统 毕业设计 312341
  • (三) diretfbrc详解
  • (原創) 如何讓IE7按第二次Ctrl + Tab時,回到原來的索引標籤? (Web) (IE) (OS) (Windows)...
  • .NET 2.0中新增的一些TryGet,TryParse等方法
  • .net 获取url的方法
  • .NET 中使用 TaskCompletionSource 作为线程同步互斥或异步操作的事件
  • .NET命名规范和开发约定
  • .NET企业级应用架构设计系列之结尾篇
  • @font-face 用字体画图标
  • @RunWith注解作用
  • @Transactional注解下,循环取序列的值,但得到的值都相同的问题
  • [ SNOI 2013 ] Quare
  • [BPU部署教程] 教你搞定YOLOV5部署 (版本: 6.2)
  • [bzoj2957]楼房重建
  • [bzoj4240] 有趣的家庭菜园
  • [javaSE] 看知乎学习工厂模式
  • [JS] node.js 入门
  • [Linux](16)网络编程:网络概述,网络基本原理,套接字,UDP,TCP,并发服务器编程,守护(精灵)进程
  • 百度搜索URL参数
  • BITMAP 位图
  • 字符串处理算法(六)求2个字符串最长公共部分
  • c json解析/创建详解及代码示例
  • ping 命令说明-查看谁占用ip + ping过程解析
  • 多线程(16) pthread_once
  • Oracle官方文档下载与旧版本下载
  • 多线程(15) pthread_setcancelstate 及 其他线程取消函数
  • 多线程(17) pthread_sigmask 及 线程安全和可重入
  • 免费游戏软件开发工具,做自己的游戏