当前位置: 首页 > news >正文

机器学习笔记 - 在QT/PyTorch/C++ 中加载 TORCHSCRIPT 模型

一、配置QT+libtorch环境

        1、首先下载pytorch,注意版本要和你的python下的pytorch版本,否则运行脚本时会报错,我这里用的是windows下的cuda11.6,release版本,有debug和release两个版本可以下载。

        2、 将压缩包内的include和lib文件夹放到指定位置,我这里是放到了项目文件夹内。

        3、因为QT项目用的是CMakeLists.txt文件,所以添加如下代码。

//两个包含路径
include_directories(${CMAKE_SOURCE_DIR}/lib/torch/include)
include_directories(${CMAKE_SOURCE_DIR}/lib/torch/include/torch/csrc/api/include)

//这里可以把所有的包都引入,我比较懒,暂时这里只引了几个
target_link_libraries(QtFFmpegApp2 PRIVATE ${CMAKE_SOURCE_DIR}/lib/torch/lib/asmjit.lib
                                           ${CMAKE_SOURCE_DIR}/lib/torch/lib/c10.lib
                                           ${CMAKE_SOURCE_DIR}/lib/torch/lib/torch.lib
                                           ${CMAKE_SOURCE_DIR}/lib/torch/lib/torch_cpu.lib
                                           ${CMAKE_SOURCE_DIR}/lib/torch/lib/torch_cuda.lib
                                       )

        4、测试环境,这里需要注意引入包含文件的时候按照如下写法,应该是slots有冲突。

//torch
#undef slots
#include <torch/torch.h>
#define slots Q_SLOTS


torch::Tensor tensor = torch::rand({ 5,3 });
std::cout << tensor << std::endl;

        输出如下

 0.7113  0.9045  0.9027
 0.2154  0.1694  0.4926
 0.1489  0.6318  0.7042
 0.2531  0.5301  0.2073
 0.6361  0.4789  0.3916
[ CPUFloatType{5,3} ]

二、将PyTorch模型转换为Torch脚本

        这里直接实用官方的alexnet模型,测试过一些自定义的模型,但是总是失败,还没研究明白,后面再试。

import torch
import torchvision


alexnet = torchvision.models.alexnet(pretrained=True)
dummy_input = torch.rand((1, 3, 224, 224))

traced_alexnet = torch.jit.trace(alexnet, dummy_input)
torch.jit.save(traced_alexnet, "alexnet.pt")

三、在 C++ 中加载脚本模块

#undef slots
#include <torch/torch.h>
#include <torch/script.h>
#define slots Q_SLOTS


void MainWindow::on_pushButton_2_clicked()
{
    //实用opencv读取图片
    cv::Mat img = cv::imread("0.jpg");
    //进行图片缩放
    cv::resize(img, img, cv::Size(224, 224), 0, 0, 1);
    //转为tensor_image 
    auto tensor_image = torch::from_blob(img.data, { img.rows, img.cols, img.channels() }, at::kByte);
    //将tensor的维度换位
    tensor_image = tensor_image.permute({ 2,0,1 });


    auto mean1 = torch::tensor({0.485}).repeat({224, 224});
    auto mean2 = torch::tensor({0.456}).repeat({224, 224});
    auto mean3 = torch::tensor({0.406}).repeat({224, 224});
    auto mean = torch::stack({mean1, mean2, mean3});

    auto std1 = torch::tensor({0.229}).repeat({224, 224});
    auto std2 = torch::tensor({0.224}).repeat({224, 224});
    auto std3 = torch::tensor({0.225}).repeat({224, 224});
    auto std = torch::stack({std1, std2, std3});

    tensor_image = tensor_image / 255.0;
    tensor_image = (tensor_image - mean) / std;

    tensor_image.unsqueeze_(0);

    auto image = tensor_image.to(torch::kFloat);
    
    //如果load失败大概率是因为用Python生成.pt文件时所采用的pytorch的版本和C++采用的Libtorch版本不一致
    torch::jit::script::Module alexnet = torch::jit::load("alexnet.pt");

    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(image);

    auto output = alexnet.forward(inputs).toTensor();
    auto prediction_index = output.argmax();
    
    //这里应该是使用对应的imagenet得到分类名
    //std::cout << "Prediction: " << id2class[prediction_index.item<int>()] << std::endl;
    
    //不过懒得弄,直接打印了索引出来
    std::cout << "Prediction: " << prediction_index.item<int>() << std::endl;
}

        这是一只可爱的猫咪,推理完成之后得到的索引是281。在imagenet的分类索引里面是281 n02123045 猫, tabby, tabby cat。分类正确 

        下面的车识别之后是751 n04037443 赛车, racer, race car, racing car ,也还是准确的。

        试过一些其它的也有不准确的。

相关文章:

  • redis 技术分享
  • 怎么让面试官喜欢你?
  • 深度学习模型理解-CNN-手写数据字代码
  • C# ZBar解码测试(QRCode、一维码条码)并记录里面隐藏的坑
  • 【技术美术图形部分】图形渲染管线3.0-光栅化和像素处理阶段
  • css:一个容器(页面),里面有两个div左右摆放并且高度和容器高度一致,左div不会随着页面左右伸缩而变化,右div随页面左右伸缩宽度自适应(手写)
  • Kubernetes 1.25 集群搭建
  • 【每周CV论文推荐】GAN在医学图像生成与增强中的典型应用
  • python毕业设计项目源码选题(16)跳蚤市场二手物品交易系统毕业设计毕设作品开题报告开题答辩PPT
  • C# 连接 SqlServer 数据库
  • 【408计算机组成原理】—进位计数制(二)
  • 拆解一下汽车电子软件开发工具链
  • 2021年华中杯数学建模挑战赛B题技术问答社区重复问题识别求解全过程文档及程序
  • 【云原生 | 32】Docker运行数据采集和分析引擎Elasticsearch
  • 有符号数四舍五入的verilog实现
  • 分享的文章《人生如棋》
  • 【附node操作实例】redis简明入门系列—字符串类型
  • 【前端学习】-粗谈选择器
  • Apache的基本使用
  • C++入门教程(10):for 语句
  • Facebook AccountKit 接入的坑点
  • iOS筛选菜单、分段选择器、导航栏、悬浮窗、转场动画、启动视频等源码
  • nginx(二):进阶配置介绍--rewrite用法,压缩,https虚拟主机等
  • ReactNativeweexDeviceOne对比
  • React系列之 Redux 架构模式
  • Stream流与Lambda表达式(三) 静态工厂类Collectors
  • Yeoman_Bower_Grunt
  • 从setTimeout-setInterval看JS线程
  • 对话:中国为什么有前途/ 写给中国的经济学
  • 汉诺塔算法
  • 基于OpenResty的Lua Web框架lor0.0.2预览版发布
  • 那些年我们用过的显示性能指标
  • 深入 Nginx 之配置篇
  • 深入浏览器事件循环的本质
  • 实战:基于Spring Boot快速开发RESTful风格API接口
  • 使用Envoy 作Sidecar Proxy的微服务模式-4.Prometheus的指标收集
  • 算法-插入排序
  • 突破自己的技术思维
  • 网络应用优化——时延与带宽
  • 运行时添加log4j2的appender
  • Spring Batch JSON 支持
  • 蚂蚁金服CTO程立:真正的技术革命才刚刚开始
  • ​MySQL主从复制一致性检测
  • #预处理和函数的对比以及条件编译
  • (2/2) 为了理解 UWP 的启动流程,我从零开始创建了一个 UWP 程序
  • (3)llvm ir转换过程
  • (3)nginx 配置(nginx.conf)
  • (33)STM32——485实验笔记
  • (BFS)hdoj2377-Bus Pass
  • (C语言)球球大作战
  • (poj1.3.2)1791(构造法模拟)
  • (ZT) 理解系统底层的概念是多么重要(by趋势科技邹飞)
  • (动手学习深度学习)第13章 计算机视觉---微调
  • (七)c52学习之旅-中断
  • (四)Linux Shell编程——输入输出重定向