当前位置: 首页 > news >正文

pytorch——保存‘类别名与类别数量’到权值文件中

前言

不知道大家有没有像我一样,每换一次不一样的模型,就要输入不同的num_classes和name_classes,反正我是很头疼诶,尤其是项目里面不止一个模型的时候,更新的时候看着就很头疼,然后就想着直接输入模型权值文件的path该多好,然后我就搞起来了。

在自己的类中加入想要加入数据信息

class your_nets(nn.Module):def __init__(self, num_classes = 21,name_classes=None):super(your_nets, self).__init__()self.num_classes = num_classesself.name_classes = name_classes

训练过程之保存文件

      
model = your_nets(num_classes=num_classes, name_classes=name_classes)save_dict = {'state_dict': model.state_dict(),'num_classes': model.num_classes,'name_classes': model.name_classes}torch.save(save_dict, os.path.join(save_dir, "best_epoch_weights.pth"))

使用 

model = get_nets_class(model_path=model_path)class get_nets_class(object):def __init__(self ,**kwargs):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')load_dict  = torch.load(self.model_path, map_location=device)state_dict =load_dict['state_dict']num_classes = load_dict['num_classes']name_classes = load_dict['name_classes']if num_classes is not None and name_classes is not None:self.num_classes =num_classesself.name_classes = name_classesself.net = your_nets(num_classes=self.num_classes,name_classes=name_classe)self.net.load_state_dict(state_dict)else:self.net = your_nets(num_classes=self.num_classes, backbone=self.backbone)self.net.load_state_dict(load_dict)self.net = self.net.eval()def predict(self,image,name_classes,object_list):#你的预处理操作,没有就忽略image_data = preprocess(image)with torch.no_grad():# 推理pr = self.net(images)[0]# softmax 得出概率 pr.permute(1, 2, 0), dim=-1为我自己的操作,没有请忽略pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()#你的后处理操作,没有就忽略pr = postprocess(pr)#这一步与object_list有关 object_list是你想要模型去预测的内容# 例如你训练了识别cat、dog、pig、person的类别 那么你想只识别人,那么就object_list=['person'] if object_list is not None:model_object_list = [name_classes.index(i) for i in object_list if i in name_classes]temp_list = [i for i in range(len(name_classes))]remove_list = [i for i in temp_list if i not in model_object_list]for i in remove_list:pr[pr==i] = 0retuen pr

我是觉得已经很详细了,大家要是不懂可以再问,我可以慢慢改进,每个人的写法都不一样 。

欢迎大家点赞加收藏哟~

相关文章:

  • 华为mpls vpn hubspoke经典案例组网
  • Linux的7个运行级别
  • No matching client found for package name ‘com.unity3d.player‘
  • docker部署自己的网站wordpress
  • [Vue3]父子组件相互传值数据同步
  • 【linux】通过脚本、系统服务监控开机时间和 cpu 温度
  • wins 安装 tensorflow keras
  • HuTool工具使用(JSONUtil+JSONObject+JSONArray)
  • 3593 蓝桥杯 查找最大元素 简单
  • Leetcode—42. 接雨水【困难】
  • 项目02《游戏-08-开发》Unity3D
  • HarmonyOS鸿蒙ArkTS证件照生成模板(适合二次开发,全套源码版)
  • 面试复盘6——后端开发
  • 进程控制(Linux)
  • 【蓝桥杯冲冲冲】[NOIP2003 普及组] 栈
  • 9月CHINA-PUB-OPENDAY技术沙龙——IPHONE
  • @jsonView过滤属性
  • [case10]使用RSQL实现端到端的动态查询
  • create-react-app项目添加less配置
  • go append函数以及写入
  • javascript 总结(常用工具类的封装)
  • JS变量作用域
  • Spark学习笔记之相关记录
  • Vue2.x学习三:事件处理生命周期钩子
  • Xmanager 远程桌面 CentOS 7
  • 腾讯优测优分享 | 你是否体验过Android手机插入耳机后仍外放的尴尬?
  • 原生js练习题---第五课
  • python最赚钱的4个方向,你最心动的是哪个?
  • 如何正确理解,内页权重高于首页?
  • 新海诚画集[秒速5センチメートル:樱花抄·春]
  • ​Distil-Whisper:比Whisper快6倍,体积小50%的语音识别模型
  • ​草莓熊python turtle绘图代码(玫瑰花版)附源代码
  • #使用清华镜像源 安装/更新 指定版本tensorflow
  • (20)目标检测算法之YOLOv5计算预选框、详解anchor计算
  • (android 地图实战开发)3 在地图上显示当前位置和自定义银行位置
  • (SpringBoot)第七章:SpringBoot日志文件
  • (安全基本功)磁盘MBR,分区表,活动分区,引导扇区。。。详解与区别
  • (动手学习深度学习)第13章 计算机视觉---微调
  • (附源码)小程序 交通违法举报系统 毕业设计 242045
  • (亲测)设​置​m​y​e​c​l​i​p​s​e​打​开​默​认​工​作​空​间...
  • (十五)Flask覆写wsgi_app函数实现自定义中间件
  • (十一)JAVA springboot ssm b2b2c多用户商城系统源码:服务网关Zuul高级篇
  • (推荐)叮当——中文语音对话机器人
  • (一)硬件制作--从零开始自制linux掌上电脑(F1C200S) <嵌入式项目>
  • .Net Core 中间件验签
  • .net mvc部分视图
  • .net redis定时_一场由fork引发的超时,让我们重新探讨了Redis的抖动问题
  • .NET 编写一个可以异步等待循环中任何一个部分的 Awaiter
  • .Net 垃圾回收机制原理(二)
  • .NET/C# 阻止屏幕关闭,阻止系统进入睡眠状态
  • .NET:自动将请求参数绑定到ASPX、ASHX和MVC(菜鸟必看)
  • .vimrc php,修改home目录下的.vimrc文件,vim配置php高亮显示
  • /etc/sudoer文件配置简析
  • /proc/interrupts 和 /proc/stat 查看中断的情况
  • @Transactional注解下,循环取序列的值,但得到的值都相同的问题