当前位置: 首页 > news >正文

用自己的数据集训练TimeSformer并转ONNX用c++推理

用自己的数据集训练TimeSformer并转ONNX用c++推理

文章目录

  • 用自己的数据集训练TimeSformer并转ONNX用c++推理
    • 下载安装TimeSformer
    • 创建分类文件夹
    • 创建数据集
    • 修改训练配置
    • 运行脚本开始训练
    • 测试模型
    • 模型转为onnx
    • 测试一下生成的onnx模型
    • 转为用c++推理

下载安装TimeSformer

TimeSformer开源地址

按照官方教程安装好环境。
如果报下面这个错误,是因为新版的pytorch已经不支持那种写法了,需要修改一下。

ImportError: cannot import name '_LinearWithBias' from 'torch.nn.modules.linear'

可以参考这个人的fork修改

创建分类文件夹

我这里有61个动作分类,每个分类创建一个文件夹
在这里插入图片描述
将视频文件分割成 每个视频大概10s左右;
然后将视频文件按照分类放到每个文件夹里。

创建数据集

写一个脚本分割数据集,并生成标签文件

import os
import csv
import shutil
from tqdm import tqdm
from sklearn.model_selection import train_test_splitout_dir = "/home/disk/liangbaikai/TimeSformer/mydata/mydatasets"  # 输出路径
video_path = "/home/disk/liangbaikai/TimeSformer/mydata/myvideos" # 数据集路径
file_name = ".csv"
name_list = ["train","test","val"]if not os.path.exists(out_dir):os.mkdir(out_dir)
if not os.path.exists(os.path.join(out_dir, 'train')):os.mkdir(os.path.join(out_dir, 'train'))
if not os.path.exists(os.path.join(out_dir, 'val')):os.mkdir(os.path.join(out_dir, 'val'))
if not os.path.exists(os.path.join(out_dir, 'test')):os.mkdir(os.path.join(out_dir, 'test'))for file in os.listdir(video_path):file_path = os.path.join(video_path, file)video_files = [name for name in os.listdir(file_path)]#将20%的数据分配给testtrain_and_valid, test = train_test_split(video_files, test_size=0.2, random_state=42)#将80%的数据再分配20%出来给val,剩下的给traintrain, val = train_test_split(train_and_valid, test_size=0.2, random_state=42)train_dir = os.path.join(out_dir, 'train', file)val_dir = os.path.join(out_dir, 'val', file)test_dir = os.path.join(out_dir, 'test', file)if not os.path.exists(train_dir):os.mkdir(train_dir)if not os.path.exists(val_dir):os.mkdir(val_dir)if not os.path.exists(test_dir):os.mkdir(test_dir)for video in tqdm(train):shutil.copy(os.path.join(video_path,file,video),os.path.join(train_dir,video))for video in tqdm(test):shutil.copy(os.path.join(video_path,file,video),os.path.join(test_dir,video))for video in tqdm(val):shutil.copy(os.path.join(video_path,file,video),os.path.join(val_dir,video))#输出路径下创建csv文件夹,并在文件夹下创建train.csv val.csv test.csv
csv_path = os.path.join(out_dir,"csv")
if not os.path.exists(csv_path):os.mkdir(csv_path)for name in name_list:with open(os.path.join(csv_path,name+file_name),'wb') as f:print("创建"+os.path.join(csv_path,name+file_name))for ii in os.listdir(csv_path):if ii.split(".")[0] in name_list:path1 = os.path.join(csv_path,ii)with open(path1, 'w', newline='') as f:for dd in os.listdir(out_dir):if dd==ii.split(".")[0]:for zz in os.listdir(os.path.join(out_dir,dd)):for mm in os.listdir(os.path.join(out_dir,dd,zz)):writer = csv.writer(f)writer.writerow([os.path.join(out_dir,dd,zz,mm),zz])## 创建类别label标号文件
labels= []
for label in sorted(os.listdir(video_path)):labels.append(label)
label2index = {label: index for index, label in enumerate(sorted(set(labels)))}
label_file = os.path.join(out_dir, str(len(os.listdir(video_path))) + 'class_labels.txt')
with open(label_file, 'w') as f:for id, label in enumerate(sorted(label2index)):f.writelines(str(id) + ' ' + label +'\n')#替换csv文件中类别名为数字
csv_file = os.path.join(out_dir,"csv")
def txt_read(files):txt_dict = {}fopen = open(files)for line in fopen.readlines():line = str(line).replace('\n','')txt_dict[line.split(' ',1)[1]] = line.split(' ',1)[0]      fopen.close()return txt_dict
txt_dict = txt_read(label_file)
print(txt_dict)for ii in os.listdir(csv_file):path1 = os.path.join(csv_file,ii)r = csv.reader(open(path1))lines = [l for l in r]for i in range(

相关文章:

  • 2024广东省职业技能大赛云计算赛项实战——容器云平台搭建
  • python watchdog 配置文件热更新
  • BP神经网络的反向传播(Back Propagation)
  • 方法区讲解
  • EasyExcel 导出批注信息
  • 【Go】十四、图形验证码、短信验证码、注册接口与redis的简单使用
  • 单片机练习题3
  • 每日优秀影视分享❗❗
  • WPF文本绑定显示格式StringFormat设置-特殊格式时间日期和多数据绑定
  • 原生dom操作快速写入html渲染(insertAdjacentHTML)
  • Cadence:Conformal系列形式验证工具
  • 深入解析Netty的Reactor模型及其实现:详解与代码示例
  • Pikachu靶场--XSS
  • excel数据透视
  • Ubuntu常见命令解释
  • 《Java8实战》-第四章读书笔记(引入流Stream)
  • ES6系统学习----从Apollo Client看解构赋值
  • golang中接口赋值与方法集
  • iOS 系统授权开发
  • js作用域和this的理解
  • Lucene解析 - 基本概念
  • magento2项目上线注意事项
  • VirtualBox 安装过程中出现 Running VMs found 错误的解决过程
  • vue 配置sass、scss全局变量
  • Vue小说阅读器(仿追书神器)
  • 区块链将重新定义世界
  • 如何胜任知名企业的商业数据分析师?
  • 数据可视化之 Sankey 桑基图的实现
  • 提升用户体验的利器——使用Vue-Occupy实现占位效果
  • 为什么要用IPython/Jupyter?
  • 深度学习之轻量级神经网络在TWS蓝牙音频处理器上的部署
  • #07【面试问题整理】嵌入式软件工程师
  • #1015 : KMP算法
  • #LLM入门|Prompt#2.3_对查询任务进行分类|意图分析_Classification
  • #pragam once 和 #ifndef 预编译头
  • #快捷键# 大学四年我常用的软件快捷键大全,教你成为电脑高手!!
  • $con= MySQL有关填空题_2015年计算机二级考试《MySQL》提高练习题(10)
  • $var=htmlencode(“‘);alert(‘2“); 的个人理解
  • %3cli%3e连接html页面,html+canvas实现屏幕截取
  • (¥1011)-(一千零一拾一元整)输出
  • (14)学习笔记:动手深度学习(Pytorch神经网络基础)
  • (2/2) 为了理解 UWP 的启动流程,我从零开始创建了一个 UWP 程序
  • (Redis使用系列) SpringBoot 中对应2.0.x版本的Redis配置 一
  • (SpringBoot)第二章:Spring创建和使用
  • (二) Windows 下 Sublime Text 3 安装离线插件 Anaconda
  • (更新)A股上市公司华证ESG评级得分稳健性校验ESG得分年均值中位数(2009-2023年.12)
  • (十七)Flask之大型项目目录结构示例【二扣蓝图】
  • (五)Python 垃圾回收机制
  • (原創) 如何解决make kernel时『clock skew detected』的warning? (OS) (Linux)
  • (自适应手机端)行业协会机构网站模板
  • .net 简单实现MD5
  • .NET/C# 在 64 位进程中读取 32 位进程重定向后的注册表
  • .net下的富文本编辑器FCKeditor的配置方法
  • .net与java建立WebService再互相调用
  • @AliasFor注解