
Converting a TensorFlow pb model to tflite, with quantization

I. TensorFlow 2.x: converting a Keras h5 model to tflite, with quantization

1. Converting an h5 model to tflite without quantization

import tensorflow as tf
import numpy as np
from pathlib import Path
print("TensorFlow version: ", tf.__version__)

model = tf.keras.models.load_model('model.h5')

# No quantization
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
tflite_model_file = Path("mnist_model_null.tflite")
tflite_model_file.write_bytes(tflite_model)

interpreter = tf.lite.Interpreter(model_content=tflite_model)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)


2. Converting an h5 model to tflite with dynamic range quantization (official reference code)

import tensorflow as tf
import numpy as np
from pathlib import Path
print("TensorFlow version: ", tf.__version__)

model = tf.keras.models.load_model('model.h5')
# Dynamic range quantization
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model_dynamic = converter.convert()
tflite_model_file = Path("mnist_model_dynamic.tflite")
tflite_model_file.write_bytes(tflite_model_dynamic)

interpreter = tf.lite.Interpreter(model_content=tflite_model_dynamic)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)
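
A quick way to see what dynamic range quantization changed is to compare the sizes of the exported files. The helper below is a minimal sketch that assumes `mnist_model_null.tflite` and `mnist_model_dynamic.tflite` from the steps above are in the working directory; `size_report` is a hypothetical name used for illustration, not part of TensorFlow.

```python
from pathlib import Path

def size_report(paths):
    """Return {file name: size in KiB} for every path that exists (hypothetical helper)."""
    report = {}
    for p in map(Path, paths):
        if p.exists():  # silently skip models that have not been exported yet
            report[p.name] = p.stat().st_size / 1024
    return report

# File names are the ones written by the conversion snippets above.
for name, kib in size_report(["mnist_model_null.tflite",
                              "mnist_model_dynamic.tflite"]).items():
    print(f"{name}: {kib:.1f} KiB")
```

Missing files are skipped rather than raising, so the same call can be reused after each later quantization step.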


3. Converting an h5 model to tflite with int8 integer quantization (official reference code)

import tensorflow as tf
import numpy as np
from pathlib import Path
print("TensorFlow version: ", tf.__version__)

model = tf.keras.models.load_model('model.h5')
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
print(type(train_images), train_images.shape)
# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images.astype(np.float32) / 255.0
def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
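# Restrict the model to int8 builtin ops so conversion fails if an op cannot be quantized.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# The model's input and output tensors are exposed as uint8.
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8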
tflite_model_int8 = converter.convert()
tflite_model_file = Path("mnist_model_int8.tflite")
tflite_model_file.write_bytes(tflite_model_int8)

interpreter = tf.lite.Interpreter(model_content=tflite_model_int8)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)
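
Full integer quantization stores each tensor as integers together with a scale and a zero point, so that real_value ≈ scale * (q - zero_point). The sketch below illustrates that arithmetic in plain Python for the uint8 case. The helper names and the example scale and zero point are assumptions for illustration; for a real model these parameters would be read from `interpreter.get_input_details()[0]['quantization']`.

```python
def quantize(real, scale, zero_point, qmin=0, qmax=255):
    """Map a real value to an integer: q = round(real / scale) + zero_point, clamped."""
    q = round(real / scale) + zero_point
    return max(qmin, min(qmax, q))

def dequantize(q, scale, zero_point):
    """Recover the approximate real value: real = scale * (q - zero_point)."""
    return scale * (q - zero_point)

# Example parameters for inputs normalized to [0, 1] (assumed, not read from a model).
scale, zero_point = 1.0 / 255.0, 0
pixel = 0.25
q = quantize(pixel, scale, zero_point)
print("quantized:", q, "dequantized:", dequantize(q, scale, zero_point))
```

Values outside the representable range are clamped, which is why the representative dataset should cover the input range the model will see in practice.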


4. Converting an h5 model to tflite with float16 quantization (official reference code)

import tensorflow as tf
import numpy as np
from pathlib import Path
print("TensorFlow version: ", tf.__version__)

model = tf.keras.models.load_model('model.h5')

# float16 quantization
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_model_float16 = converter.convert()
tflite_model_file = Path("mnist_model_float16.tflite")
tflite_model_file.write_bytes(tflite_model_float16)

interpreter = tf.lite.Interpreter(model_content=tflite_model_float16)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)
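
float16 quantization stores weights as 16-bit floats, which halves their storage at the cost of some precision. The snippet below is a small standard-library illustration of that trade-off for a single value. It does not touch the converter, and the sample weight is made up.

```python
import struct

def to_float16_and_back(x):
    """Round-trip a Python float through IEEE 754 half precision."""
    packed = struct.pack('<e', x)  # 'e' is the half-precision format code
    return struct.unpack('<e', packed)[0], len(packed)

weight = 0.1234567  # made-up sample weight
restored, nbytes = to_float16_and_back(weight)
print(f"float32 uses {struct.calcsize('<f')} bytes, float16 uses {nbytes} bytes")
print(f"original={weight}, after float16 round trip={restored}")
```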


II. TensorFlow 2.x: converting a 1.x pb model to tflite through tf.compat.v1, with quantization (official API)

1. Converting a pb model to tflite without quantization

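import tensorflow as tf

# Load the frozen TF 1.x graph; the tensor names below are specific to this model.
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file = '0824.pb',
        input_arrays = ['x_img_g', 'is_training'],
        input_shapes = {'x_img_g' : [1, 256, 512, 3], 'is_training' : [1]},
        output_arrays = ['encoder_generator/classifier/SINET_output/BiasAdd']
)
tflite_model = converter.convert()
# Use a context manager so the file is closed once it has been written.
with open("model_null.tflite", "wb") as f:
    f.write(tflite_model)

interpreter = tf.lite.Interpreter(model_content=tflite_model)
# Named input_details/output_details to avoid shadowing the built-in input().
input_details = interpreter.get_input_details()
print(input_details)
output_details = interpreter.get_output_details()
print(output_details)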

2. Converting a pb model to tflite with dynamic range quantization

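import tensorflow as tf

# Dynamic range quantization
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file = '0824.pb',
        input_arrays = ['x_img_g', 'is_training'],
        input_shapes = {'x_img_g' : [1, 256, 512, 3], 'is_training' : [1]},
        output_arrays = ['encoder_generator/classifier/SINET_output/BiasAdd']
)
# Optimize.DEFAULT is what enables post-training dynamic range quantization;
# without it the converter emits a plain float model.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Maps each input name to (mean, std). Per the converter docs this is only
# required for int8/uint8 inputs; it is kept here from the original snippet.
converter.quantized_input_stats = {"x_img_g": (0., 1.), "is_training": (0., 1.)}
tflite_model = converter.convert()
with open("model_dynamic.tflite", "wb") as f:
    f.write(tflite_model)

interpreter = tf.lite.Interpreter(model_content=tflite_model)
input_details = interpreter.get_input_details()
print(input_details)
output_details = interpreter.get_output_details()
print(output_details)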

3. Converting a pb model to tflite with int8 integer quantization

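import tensorflow as tf

# Integer quantization
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file = '0824.pb',
        input_arrays = ['x_img_g', 'is_training'],
        input_shapes = {'x_img_g' : [1, 256, 512, 3], 'is_training' : [1]},
        output_arrays = ['encoder_generator/classifier/SINET_output/BiasAdd']
)
# (mean, std) per input, used to map quantized input values back to real values.
converter.quantized_input_stats = {"x_img_g": (0., 1.), "is_training": (0., 1.)}
# Run the whole graph in int8. Note: if the graph carries no min/max ranges
# (for example from fake-quant nodes), the converter may also need
# default_ranges_stats or a representative dataset.
converter.inference_type = tf.int8
tflite_model = converter.convert()
with open("model_int8.tflite", "wb") as f:
    f.write(tflite_model)

interpreter = tf.lite.Interpreter(model_content=tflite_model)
input_details = interpreter.get_input_details()
print(input_details)
output_details = interpreter.get_output_details()
print(output_details)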

4. Converting a pb model to tflite with float16 quantization

import tensorflow as tf

# float16 quantization
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file = '0824.pb',
        input_arrays = ['x_img_g', 'is_training'],
        input_shapes = {'x_img_g' : [1, 256, 512, 3], 'is_training' : [1]},
        output_arrays = ['encoder_generator/classifier/SINET_output/BiasAdd']
)
converter.quantized_input_stats = {"x_img_g": (0., 1.), "is_training": (0., 1.)}
# inference_type is documented for float32/int8/uint8 only, so float16 is
# requested the same way as in section I: optimizations plus supported_types.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()
with open("model_float16.tflite", "wb") as f:
    f.write(tflite_model)

interpreter = tf.lite.Interpreter(model_content=tflite_model)
input_details = interpreter.get_input_details()
print(input_details)
output_details = interpreter.get_output_details()
print(output_details)


III. Running the tflite model

import os
import cv2
import time
import numpy as np
import tensorflow as tf
from PIL import Image

#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"

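# A helper function to evaluate the TF Lite model using a "test" dataset.
# test_images and test_labels are passed in explicitly (e.g. the MNIST test
# split, normalized to float32) because this script does not define them.
def evaluate_model(interpreter, test_images, test_labels):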
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for test_image in test_images:
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  # Compare prediction results with ground truth labels to calculate accuracy.
  accurate_count = 0
  for index in range(len(prediction_digits)):
    if prediction_digits[index] == test_labels[index]:
      accurate_count += 1
  accuracy = accurate_count * 1.0 / len(prediction_digits)

  return accuracy


# interpreter = tf.compat.v1.lite.Interpreter(model_path="model_null.tflite")
interpreter = tf.compat.v1.lite.Interpreter(model_path="model_int8.tflite")
# interpreter = tf.compat.v1.lite.Interpreter(model_path="model_float16.tflite")
# interpreter = tf.compat.v1.lite.Interpreter(model_path="model_dynamic.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)

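test_image = cv2.imread('test.png')                             # (1080, 1920, 3)
r_w, r_h = 512, 256
img_data =  cv2.resize(test_image, (r_w, r_h))                  # (256, 512, 3)
img_data = np.expand_dims(img_data, axis=0).astype(np.float32)
# Match the dtype the loaded model expects instead of hard-coding int8
# (a direct uint8 -> int8 cast wraps pixel values above 127).
in_dtype = input_details[0]['dtype']
scale, zero_point = input_details[0]['quantization']
if scale > 0:  # quantized input: q = real / scale + zero_point
    img_data = np.clip(np.round(img_data / scale + zero_point),
                       np.iinfo(in_dtype).min, np.iinfo(in_dtype).max)
img_data = img_data.astype(in_dtype)

interpreter.set_tensor(input_details[0]['index'], img_data)
# Pass is_training as an array of the dtype the model declares for it.
interpreter.set_tensor(input_details[1]['index'],
                       np.array([False], dtype=input_details[1]['dtype']))
t1 = time.time()
interpreter.invoke()
t2 = time.time()
prediction = interpreter.get_tensor(output_details[0]['index'])
print(t2-t1)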

print(prediction.shape)
prediction = prediction[0]
print(prediction.shape)

prediction1 = prediction[:,:,0]
print(prediction1.shape)
print(np.max(prediction1),np.min(prediction1))
img = Image.fromarray(prediction1)
img.show()

prediction2 = prediction[:,:,1]
print(prediction2.shape)
print(np.max(prediction2),np.min(prediction2))
img = Image.fromarray(prediction2)
img.show()
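
Timing a single `invoke()` as above can be noisy, since the first call may include one-time setup cost. A common alternative is to run a few warm-up calls and then average over several timed runs. The helper below is a hedged sketch using only the standard library; `benchmark` and its parameters are hypothetical names, and in this script `interpreter.invoke` would be passed as `fn`.

```python
import time

def benchmark(fn, warmup=3, runs=20):
    """Call fn() `warmup` times untimed, then return the mean seconds over `runs` timed calls."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs

# Stand-in workload so the sketch runs on its own; replace with interpreter.invoke.
mean_s = benchmark(lambda: sum(range(10000)))
print(f"mean latency: {mean_s * 1000:.3f} ms")
```

Running the same helper against each of the four exported models gives a like-for-like latency comparison between the quantization modes.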


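IV. References

1. Official conversion tutorial
2. tensorflow: quantizing a .pb file into .tflite
3. Fixing "OP not supported" errors when converting tensorflow2 to tflite
4. Tensorflow2 lite model quantization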
