当前位置: 首页 > news >正文

机器学习笔记 - 使用TensorFlow Lite从头创建模型

一、概述

        目标是使用 TF Lite Model Maker Library 创建 TensorFlow Lite 模型。将在自定义数据集上微调预训练的图像分类模型,并进一步探索该库当前支持的不同类型的模型优化技术,并将其导出到 TF Lite 模型。最后对创建的 TF Lite 模型和转换后的模型进行性能比较。

        TensorFlow Lite 模型制作库使我们能够在自定义数据集上训练预训练或自定义 TensorFlow Lite 模型。在为设备上的 ML 应用程序部署 TensorFlow 神经网络模型时,它简化了调整模型并将其转换为特定输入数据的过程。目前,它支持图像分类、物体检测、文本分类、BERT 问题解答、音频分类、推荐系统等。

        安装方法1

pip install tflite-model-maker

        安装方法2

git clone https://github.com/tensorflow/examples
cd examples/tensorflow_examples/lite/model_maker/pip_package
pip install -e .

二、准备数据集

        这里使用DataLoader来加载数据集。

#Importing libraries
from PIL import Image
import glob
import os
from pathlib import Path
 
#Converting images in cat folder to png format
current_dir = Path('/content/PetImages/Cat').resolve()
outputdir = Path('/content/Dataset').resolve()
out_dir = outputdir / "Cat"
os.mkdir(out_dir)
cnt = 0
 
for img in glob.glob(str(current_dir / "*.jpg")):
    filename = Path(img).stem
    Image.open(img).save(str(out_dir / f'{filename}.png'))
    cnt = cnt + 1
    print(cnt)
 
#Converting images in dog folder to png format
current_dir = Path('/content/PetImages/Dog/').resolve()
outputdir = Path('/content/Dataset/').resolve()
out_dir = outputdir / "Dog"
os.mkdir(out_dir)
cnt = 0
 
for img in glob.glob(str(current_dir / "*.jpg")):
    filename = Path(img).stem
    Image.open(img).convert('RGB').save(str(out_dir / f'{filename}.png'))
    cnt = cnt + 1
    print(cnt)

        加载数据集

#Loading dataset using the Dataloader
data = DataLoader.from_folder('/content/Dataset')

        将数据集以 7:2:1 的比例分别拆分为训练集、验证集和测试集。

#Splitting dataset into training, validation and testing data
train_data, rest_data = data.split(0.7)
validation_data, test_data = rest_data.split(0.67)

三、模型训练

        这里重新训练EfficientNet Lite 0 模型。它在 Imagenet (ILSVRC-2012-CLS) 上进行了训练,针对 TFLite 进行了优化,并针对移动 CPU、GPU 和 EdgeTPU 的性能而设计。由于边缘设备的要求,对原始 EfficientNets 进行了以下更改:

        删除了squeeze-and-excite  blocks(SE),因为某些移动加速器不能很好地支持 SE。

        用 RELU6 替换了所有的 swish,以便于后量化。

        在放大模型时固定主干和头部,以保持模型足够小。

        使用image_classifier.create()函数创建模型。这model_spec() 有助于我们指定我们将使用model_spec.get()函数来导入预训练模型的图像模型。我们将分别传递 train_data 和 validation_data 作为训练和验证数据集。此外,我们将train_whole_model设置为 true 以便重新训练整个模型。其他各种参数image_classifier.create() 根据要求。我们让其余的参数保持默认值。

#Training the model
model = image_classifier.create(train_data, model_spec=model_spec.get('efficientnet_lite0'), validation_data=validation_data, train_whole_model=True)

        进行训练

INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
Model: "sequential_1"
_________________________________________________________________
Layer (type)                Output Shape              Param #  
=================================================================
hub_keras_layer_v1v2_1 (Hub  (None, 1280)             3413024 
KerasLayerV1V2)                                                
                                                                
dropout_1 (Dropout)         (None, 1280)              0       
                                                                
dense_1 (Dense)             (None, 2)                 2562    
                                                                
=================================================================
Total params: 3,415,586
Trainable params: 2,562
Non-trainable params: 3,413,024
_________________________________________________________________
None
Epoch 1/5
/usr/local/lib/python3.7/dist-packages/keras/optimizer_v2/gradient_descent.py:102: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  super(SGD, self).__init__(name, **kwargs)
546/546 [==============================] - 2586s 5s/step - loss: 0.2463 - accuracy: 0.9812 - val_loss: 0.2281 - val_accuracy: 0.9899
Epoch 2/5
546/546 [==============================] - 151s 277ms/step - loss: 0.2299 - accuracy: 0.9898 - val_loss: 0.2266 - val_accuracy: 0.9900
Epoch 3/5
546/546 [==============================] - 151s 276ms/step - loss: 0.2271 - accuracy: 0.9908 - val_loss: 0.2258 - val_accuracy: 0.9906
Epoch 4/5
546/546 [==============================] - 153s 281ms/step - loss: 0.2264 - accuracy: 0.9916 - val_loss: 0.2243 - val_accuracy: 0.9902
Epoch 5/5
546/546 [==============================] - 153s 280ms/step - loss: 0.2258 - accuracy: 0.9909 - val_loss: 0.2259 - val_accuracy: 0.9904

        使用model.evaluate()函数在测试数据集上评估模型

loss, accuracy = model.evaluate(test_data)
78/78 [==============================] - 341s 4s/step - loss: 0.2246 - accuracy: 0.9911

四、模型优化

1、FP 16 量化

        可以使用该model.export()函数将模型导出到 Float-16 TF Lite 模型。在这里,我们将定义量化为 Float 16 的配置。然后我们将在测试数据集上评估导出的量化模型。

#Defining Config
config = QuantizationConfig.for_float16()
 
#Exporting Model
model.export(export_dir='/content/Models/', tflite_filename='model_fp16.tflite', quantization_config=config)
 
#Evaluating Exported Model
model.evaluate_tflite('/content//Models/model_fp16.tflite', test_data)
{'accuracy': 0.9911111111111112}

2、动态量化

#Defining Config
config = QuantizationConfig.for_dynamic()
#Exporting Model
model.export(export_dir='/content/Models/', tflite_filename='model_dynamic.tflite', quantization_config=config)
#Evaluating Exported Model
model.evaluate_tflite('/content/Models/model_dynamic.tflite', test_data)
{'accuracy': 0.9919191919191919}

3、整数量化

#Defining Config
config = QuantizationConfig.for_int8(test_data)
#Exporting Model
model.export(export_dir='/content/Models/', tflite_filename='model_int8.tflite', quantization_config=config)
#Evaluating Exported Model
model.evaluate_tflite('/content/model_int8.tflite', test_data)
{'accuracy': 0.9915151515151515}

五、性能比较

        可以看到 FP-16 量化模型的准确性略有提高。但是可以看到在整数量化模型的情况下准确度有了显着提高。 

        创建的 TF Lite 模型比转换后的模型更小。 

 

         推理时间也有不同程度的减少。

相关文章:

  • 直流信号隔离采样
  • 锐捷——RIP基础配置
  • 一文搞懂【知识蒸馏】【Knowledge Distillation】算法原理
  • 啃完这些 Spring 知识点,我竟吊打了阿里面试官(附面经 + 笔记)
  • Java集合List接口详解——含源码分析
  • 自动化测试怎么做?python自动化测试断言详细实战代码(看这一篇就够了)
  • 100天精通Python(数据分析篇)——第61天:Pandas.to_datetime函数(处理时间)
  • C#多线程学习总结
  • 网络安全漏洞原理利用与渗透
  • 【DDR3 控制器设计】(4)DDR3 的读操作设计
  • 基于OpenCV的单目相机标定与三维定位(推广)
  • Java数据结构:单链表的实现与面试题汇总
  • 2022年都说软件测试不香了?在职3年月薪16k我满意了,你们觉得前景怎么样?
  • python做了个自动关机工具,再也不会耽误我下班啦
  • BUUCTF NewStarCTF 公开赛赛道Week5 Writeup
  • IE9 : DOM Exception: INVALID_CHARACTER_ERR (5)
  • 【Linux系统编程】快速查找errno错误码信息
  • Centos6.8 使用rpm安装mysql5.7
  • java小心机(3)| 浅析finalize()
  • node-sass 安装卡在 node scripts/install.js 解决办法
  • zookeeper系列(七)实战分布式命名服务
  • 你学不懂C语言,是因为不懂编写C程序的7个步骤 ...
  • ​决定德拉瓦州地区版图的关键历史事件
  • #define
  • #Spring-boot高级
  • (附源码)spring boot球鞋文化交流论坛 毕业设计 141436
  • (附源码)计算机毕业设计SSM智能化管理的仓库管理
  • (转)EXC_BREAKPOINT僵尸错误
  • .NET Framework 4.6.2改进了WPF和安全性
  • .NET/C# 利用 Walterlv.WeakEvents 高性能地定义和使用弱事件
  • @Conditional注解详解
  • @DataRedisTest测试redis从未如此丝滑
  • @data注解_一枚 架构师 也不会用的Lombok注解,相见恨晚
  • @PreAuthorize注解
  • [ 攻防演练演示篇 ] 利用通达OA 文件上传漏洞上传webshell获取主机权限
  • [1181]linux两台服务器之间传输文件和文件夹
  • [AutoSar]BSW_Com02 PDU详解
  • [C++]AVL树怎么转
  • [codeforces]Levko and Permutation
  • [codevs1288] 埃及分数
  • [DM复习]关联规则挖掘(下)
  • [excel与dict] python 读取excel内容并放入字典、将字典内容写入 excel文件
  • [ICCV2017]Neural Person Search Machines
  • [jQuery]使用jQuery.Validate进行客户端验证(中级篇-上)——不使用微软验证控件的理由...
  • [python]基本输出输入函数
  • [Redis]Redis的数据类型
  • [Redis]基础入门
  • [UML]UML系列——类图class的实现关系Realization
  • [vue3] 富文本
  • [word] word艺术字体如何设置? #知识分享#职场发展#媒体
  • [单片机框架][drivers层][cw2015] fuelgauge 硬件电量计(二)
  • [导入]上传大文件时,找不到服务器的错误问题!
  • [更新]ARCGIS之土地耕地占补平衡、进出平衡系统报备坐标txt格式批量导出工具(定制开发版)
  • [国嵌攻略][051][NandFlash原理解析]
  • [黑马程序员Pandas教程]——Pandas常用计算函数