
Implementing ResNet50 by Hand with Keras and TensorFlow

Preface

In the article "Animal Recognition with a TensorFlow-based ResNet50V2 Network", we used a network that Keras already provides to perform image classification. At that point Xiao Ming asked: how do I write a neural network of my own and train it?
This article defines such a network by hand on top of TensorFlow and Keras.

The ResNet50 Network

(Figure: ResNet50 network structure)
Structurally, the difference from what we used before is that the input format becomes (3, 224, 224); in the channels-last Keras code below this corresponds to (224, 224, 3).
ResNet50 is built from two basic blocks, named the Conv Block and the Identity Block.
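
Before the architecture diagrams and the full code, here is a minimal illustration of the residual connection that both blocks are built around (a sketch with arbitrary example shapes, not part of the model below): the Identity Block adds the input directly onto the output of its convolution branch, while the Conv Block first passes the input through a strided 1x1 convolution so that the channel count and spatial size match.

from keras.layers import Input, Conv2D, BatchNormalization, Activation, add

# y = relu(F(x) + x): the convolution branch F(x) is added back onto its own input.
inp = Input(shape=(56, 56, 64))
f = Conv2D(64, 3, padding='same')(inp)     # convolution branch F(x)
f = BatchNormalization()(f)
out = Activation('relu')(add([f, inp]))    # element-wise add, then ReLU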

Overall Architecture

(Figure: overall ResNet50 architecture)

Conv Block Architecture

(Figure: Conv Block structure)

Identity Block Architecture

(Figure: Identity Block structure)

Model Training

Implementing the Model by Hand

Model code (resnet50.py)

# Import the Keras building blocks used by the model
from keras import layers

from keras.layers import Input,Activation,BatchNormalization,Flatten
from keras.layers import Dense,Conv2D,MaxPooling2D,ZeroPadding2D,AveragePooling2D
from keras.models import Model

# Identity Block: the shortcut is the input tensor itself, so input and output shapes must already match.
def identity_block(input_tensor, kernel_size, filters, stage, block):
    filters1, filters2, filters3 = filters

    name_base = str(stage) + block + '_identity_block_'

    x = Conv2D(filters1, (1, 1), name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    x = layers.add([x, input_tensor], name=name_base + 'add')
    x = Activation('relu', name=name_base + 'relu4')(x)
    return x

# Conv Block: the shortcut passes through a (strided) 1x1 convolution, so this block can
# change the number of channels and downsample; it is used at the start of each stage.
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):

    filters1, filters2, filters3 = filters

    res_name_base = str(stage) + block + '_conv_block_res_'
    name_base = str(stage) + block + '_conv_block_'

    x = Conv2D(filters1, (1, 1), strides=strides, name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    shortcut = Conv2D(filters3, (1, 1), strides=strides, name=res_name_base + 'conv')(input_tensor)
    shortcut = BatchNormalization(name=res_name_base + 'bn')(shortcut)

    x = layers.add([x, shortcut], name=name_base+'add')
    x = Activation('relu', name=name_base+'relu4')(x)
    return x

def ResNet50(input_shape=[224,224,3],classes=1000):
    img_input = Input(shape=input_shape)
    x = ZeroPadding2D((3, 3))(img_input)

    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    x = AveragePooling2D((7, 7), name='avg_pool')(x)

    x = Flatten()(x)
    x = Dense(classes, activation='softmax', name='fc1000')(x)

    model = Model(img_input, x, name='resnet50')
    return model
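
A quick way to check that the network is wired correctly is to build it and compare the output shape and parameter count against the summary shown later (a usage sketch, assuming resnet50.py is on the Python path):

import resnet50

# Build the network and verify its output shape and parameter count.
model = resnet50.ResNet50(input_shape=[224, 224, 3], classes=1000)
print(model.output_shape)    # expected: (None, 1000)
print(model.count_params())  # about 25.6 million parameters, matching the summary below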

Training script (resnet50_model_train.py)

This is the training code from "Animal Recognition with a TensorFlow-based ResNet50V2 Network" with only minor changes.

import os
import pandas as pd

# Model
import keras
from keras.preprocessing.image import ImageDataGenerator

# Callbacks
from keras.callbacks import EarlyStopping, ModelCheckpoint

# TensorFlow and the hand-written ResNet50 defined above (replacing the pre-trained model)
import tensorflow as tf
import resnet50

root_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/'
valid_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/'
test_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Testing Data/'
class_names = sorted(os.listdir(root_path))
n_classes = len(class_names)

print(f"Total Number of Classes : {n_classes} \nClass Names : {class_names}")


class_dis = [len(os.listdir(root_path+name)) for name in class_names]


train_gen = ImageDataGenerator(rescale=1/255., rotation_range=10, horizontal_flip=True)
valid_gen = ImageDataGenerator(rescale=1/255.)
test_gen = ImageDataGenerator(rescale=1/255)

# Load Data
train_ds = train_gen.flow_from_directory(root_path, class_mode='binary', target_size=(224,224), shuffle=True, batch_size=32)
valid_ds = valid_gen.flow_from_directory(valid_path, class_mode='binary', target_size=(224,224), shuffle=True, batch_size=32)
test_ds = test_gen.flow_from_directory(test_path, class_mode='binary', target_size=(224,224), shuffle=True, batch_size=32)



with tf.device("/GPU:0"):
    ## Hand-built ResNet50 (not pre-trained)
    model = resnet50.ResNet50()
    model.summary()

    model_file = "ResNet50_V1.h5"
    # Resume from previously saved weights if they exist
    if os.path.exists(model_file):
        model.load_weights(model_file)

    # Callbacks
    cbs = [EarlyStopping(patience=5, restore_best_weights=True), ModelCheckpoint(model_file, save_best_only=True)]

    # Model
    opt = tf.keras.optimizers.Adam(learning_rate=2e-3)
    model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

    # Model Training
    history = model.fit(train_ds, validation_data=valid_ds, callbacks=cbs, epochs=200)

    data = pd.DataFrame(history.history)
    print(data)
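
To look at curves rather than only the printed DataFrame, the history can also be plotted after training finishes; an optional addition, assuming matplotlib is installed:

import matplotlib.pyplot as plt

# Plot training vs. validation accuracy per epoch from the history DataFrame above.
data[['accuracy', 'val_accuracy']].plot(title='ResNet50 training history')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.show()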

Running the Model

The GPU is essentially fully utilized during training.
(Figure: GPU utilization)

Model: "resnet50"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 zero_padding2d (ZeroPadding2D)  (None, 230, 230, 3)  0          ['input_1[0][0]']                
                                                                                                  
 conv1 (Conv2D)                 (None, 112, 112, 64  9472        ['zero_padding2d[0][0]']         
                                )                                                                 
                                                                                                  
 bn_conv1 (BatchNormalization)  (None, 112, 112, 64  256         ['conv1[0][0]']                  
                                )                                                                 
                                                                                                  
 activation (Activation)        (None, 112, 112, 64  0           ['bn_conv1[0][0]']               
                                )                                                                 
                                                                                                  
 max_pooling2d (MaxPooling2D)   (None, 55, 55, 64)   0           ['activation[0][0]']             
                                                                                                  
 2a_conv_block_conv1 (Conv2D)   (None, 55, 55, 64)   4160        ['max_pooling2d[0][0]']          
                                                                                                  
 2a_conv_block_bn1 (BatchNormal  (None, 55, 55, 64)  256         ['2a_conv_block_conv1[0][0]']    
 ization)                                                                                         
                                                                                                  
 2a_conv_block_relu1 (Activatio  (None, 55, 55, 64)  0           ['2a_conv_block_bn1[0][0]']      
 n)                                                                                               
                                                                                                  
 2a_conv_block_conv2 (Conv2D)   (None, 55, 55, 64)   36928       ['2a_conv_block_relu1[0][0]']    
                                                                                                  
 2a_conv_block_bn2 (BatchNormal  (None, 55, 55, 64)  256         ['2a_conv_block_conv2[0][0]']    
 ization)                                                                                         
                                                                                                  
 2a_conv_block_relu2 (Activatio  (None, 55, 55, 64)  0           ['2a_conv_block_bn2[0][0]']      
 n)                                                                                               
                                                                                                  
 2a_conv_block_conv3 (Conv2D)   (None, 55, 55, 256)  16640       ['2a_conv_block_relu2[0][0]']    
                                                                                                  
 2a_conv_block_res_conv (Conv2D  (None, 55, 55, 256)  16640      ['max_pooling2d[0][0]']          
 )                                                                                                
                                                                                                  
 2a_conv_block_bn3 (BatchNormal  (None, 55, 55, 256)  1024       ['2a_conv_block_conv3[0][0]']    
 ization)                                                                                         
                                                                                                  
 2a_conv_block_res_bn (BatchNor  (None, 55, 55, 256)  1024       ['2a_conv_block_res_conv[0][0]'] 
 malization)                                                                                      
                                                                                                  
 2a_conv_block_add (Add)        (None, 55, 55, 256)  0           ['2a_conv_block_bn3[0][0]',      
                                                                  '2a_conv_block_res_bn[0][0]']   
                                                                                                  
 2a_conv_block_relu4 (Activatio  (None, 55, 55, 256)  0          ['2a_conv_block_add[0][0]']      
 n)                                                                                               
                                                                                                  
 2b_identity_block_conv1 (Conv2  (None, 55, 55, 64)  16448       ['2a_conv_block_relu4[0][0]']    
 D)                                                                                               
                                                                                                  
 2b_identity_block_bn1 (BatchNo  (None, 55, 55, 64)  256         ['2b_identity_block_conv1[0][0]']
 rmalization)                                                                                     
                                                                                                  
 2b_identity_block_relu1 (Activ  (None, 55, 55, 64)  0           ['2b_identity_block_bn1[0][0]']  
 ation)                                                                                           
                                                                                                  
 2b_identity_block_conv2 (Conv2  (None, 55, 55, 64)  36928       ['2b_identity_block_relu1[0][0]']
 D)                                                                                               
           
******* (output truncated; the summary continues in the same way through all layers of the 50-layer network, hence ResNet50) *********
                                                                                                  
 5c_identity_block_bn3 (BatchNo  (None, 7, 7, 2048)  8192        ['5c_identity_block_conv3[0][0]']
 rmalization)                                                                                     
                                                                                                  
 5c_identity_block_add (Add)    (None, 7, 7, 2048)   0           ['5c_identity_block_bn3[0][0]',  
                                                                  '5b_identity_block_relu4[0][0]']
                                                                                                  
 5c_identity_block_relu4 (Activ  (None, 7, 7, 2048)  0           ['5c_identity_block_add[0][0]']  
 ation)                                                                                           
                                                                                                  
 avg_pool (AveragePooling2D)    (None, 1, 1, 2048)   0           ['5c_identity_block_relu4[0][0]']
                                                                                                  
 flatten (Flatten)              (None, 2048)         0           ['avg_pool[0][0]']               
                                                                                                  
 fc1000 (Dense)                 (None, 1000)         2049000     ['flatten[0][0]']                
                                                                                                  
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________


Training Results

Epoch 1/200
625/625 [==============================] - 245s 376ms/step - loss: 0.3225 - accuracy: 0.8965 - val_loss: 0.7825 - val_accuracy: 0.7670
Epoch 2/200
625/625 [==============================] - 226s 361ms/step - loss: 0.2998 - accuracy: 0.9018 - val_loss: 0.9110 - val_accuracy: 0.7060
Epoch 3/200
625/625 [==============================] - 223s 357ms/step - loss: 0.2644 - accuracy: 0.9113 - val_loss: 1.2283 - val_accuracy: 0.6760
Epoch 4/200
625/625 [==============================] - 223s 357ms/step - loss: 0.2465 - accuracy: 0.9188 - val_loss: 0.9871 - val_accuracy: 0.7500
Epoch 5/200
625/625 [==============================] - 225s 360ms/step - loss: 0.2307 - accuracy: 0.9234 - val_loss: 1.1059 - val_accuracy: 0.6720
Epoch 6/200
625/625 [==============================] - 228s 365ms/step - loss: 0.2016 - accuracy: 0.9341 - val_loss: 0.5819 - val_accuracy: 0.8370
Epoch 7/200
625/625 [==============================] - 227s 363ms/step - loss: 0.1859 - accuracy: 0.9380 - val_loss: 0.8662 - val_accuracy: 0.7740
Epoch 8/200
625/625 [==============================] - 223s 356ms/step - loss: 0.1732 - accuracy: 0.9419 - val_loss: 0.6927 - val_accuracy: 0.8130
Epoch 9/200
625/625 [==============================] - 221s 353ms/step - loss: 0.1631 - accuracy: 0.9446 - val_loss: 0.7033 - val_accuracy: 0.8090
Epoch 10/200
625/625 [==============================] - 220s 352ms/step - loss: 0.1416 - accuracy: 0.9528 - val_loss: 0.8072 - val_accuracy: 0.8130
Epoch 11/200
625/625 [==============================] - 221s 352ms/step - loss: 0.1403 - accuracy: 0.9524 - val_loss: 0.9740 - val_accuracy: 0.7570
        loss  accuracy  val_loss  val_accuracy
0   0.322520   0.89655  0.782460         0.767
1   0.299801   0.90180  0.910993         0.706
2   0.264406   0.91130  1.228291         0.676
3   0.246451   0.91880  0.987062         0.750
4   0.230691   0.92340  1.105917         0.672
5   0.201550   0.93405  0.581927         0.837
6   0.185873   0.93800  0.866195         0.774
7   0.173195   0.94185  0.692714         0.813
8   0.163083   0.94460  0.703307         0.809
9   0.141572   0.95280  0.807219         0.813
10  0.140329   0.95240  0.974021         0.757

Discussion of Training Results

Judging from the training output, the results are not as good as those of ResNet50V2. I adjusted some of the hyperparameters, but it did not help much. Readers who are interested can modify this network to improve validation accuracy.
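
One possible starting point (not something tested in this article) is to match the classification head to the actual number of classes in the dataset instead of the default 1000 outputs, and to lower the learning rate, which tends to be more stable when training from scratch; a minimal sketch reusing the resnet50.py and data layout above:

import os
import tensorflow as tf
import resnet50

root_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/'
n_classes = len(os.listdir(root_path))  # number of class folders in the training directory

# Smaller head and smaller learning rate; everything else stays as in the training script.
model = resnet50.ResNet50(input_shape=[224, 224, 3], classes=n_classes)
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])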

Model Validation

Validation code

from keras.models import load_model
import tensorflow as tf
from tensorflow.keras.utils import load_img, img_to_array
import numpy as np
import os

import matplotlib.pyplot as plt

root_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/'

class_names = sorted(os.listdir(root_path))

model = load_model('./ResNet50_V1.h5')
model.summary()

def load_image(path):
    '''This function will load the image present at the given location'''
    image = tf.cast(tf.image.resize(img_to_array(load_img(path))/255., (224,224)), tf.float32)
    return image

i_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/Gorilla/Gorilla (3).jpeg'
image = load_image(i_path)
preds = model.predict(image[np.newaxis, ...])[0]

print(preds)

pred_class = class_names[np.argmax(preds)]

confidence_score = np.round(preds[np.argmax(preds)], 2)

# Configure Title
title = f"Pred : {pred_class}\nConfidence : {confidence_score:.2}"
print(title)

plt.figure(figsize=(25, 8))
plt.title(title)
plt.imshow(image)
plt.show()

while True:
    path = input("input:")
    if path == "q!":
        exit()
    image = load_image(path)

    preds = model.predict(image[np.newaxis, ...])[0]
    print(preds)

    pred_class = class_names[np.argmax(preds)]

    confidence_score = np.round(preds[np.argmax(preds)], 2)

    # Configure Title
    title = f"Pred : {pred_class}\nConfidence : {confidence_score:.2}"
    print(title)

    plt.figure(figsize=(25, 8))
    plt.title(title)
    plt.imshow(image)
    plt.show()
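
Beyond single images, the saved model can also be scored on the whole test directory, for example in a separate evaluation script; a minimal sketch, assuming the same directory layout and preprocessing as in training:

from keras.models import load_model
from keras.preprocessing.image import ImageDataGenerator

test_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Testing Data/'

# Integer labels ('sparse') match the sparse_categorical_crossentropy loss the model was compiled with.
test_gen = ImageDataGenerator(rescale=1/255.)
test_ds = test_gen.flow_from_directory(test_path, class_mode='sparse',
                                       target_size=(224, 224), shuffle=False, batch_size=32)

model = load_model('./ResNet50_V1.h5')
loss, acc = model.evaluate(test_ds)
print(f"Test loss: {loss:.4f}, test accuracy: {acc:.4f}")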
