当前位置: 首页 > news >正文

YOLOv8由pt文件中读取模型信息

Pytorch的pt模型文件中保存了许多模型信息,如模型结构、模型参数、任务类型、批次、数据集等
在先前的YOLOv8实验中,博主发现YOLOv8在预测时并不需要指定任务类型,因为这些信息便保存在pt模型中,那么,今天我们便来看看,其到底是如何加载这些参数的。

我们首先对pt文件进行一个简单介绍:

pt文格式

pt格式文件是PyTorch中用于保存张量数据的文件格式。与pth文件类似,pt文件也常用于模型的保存和加载,但更侧重于保存单个张量或一组张量数据。通过pt文件,我们可以方便地将张量数据持久化,并在需要时重新加载使用。

张量(Tensor)是PyTorch中的核心数据结构,用于表示多维数组。在深度学习中,张量常用于存储模型的参数、输入数据、中间结果等。因此,掌握pt文件的保存和加载方法对于PyTorch的使用者来说至关重要。

pt文件与pth的区别

pt.pth都是PyTorch模型文件的扩展名,但是它们的区别在于.pt文件是保存整个PyTorch模型的,而.pth文件只保存模型的参数。(其实现在似乎并没有区别了)
因此,如果要加载一个,pth文件,需要先定义模型的结构,然后再加载参数;而如果要加载一个,pt文件,则可以直接加载整个模型。

如何保存pt格式文件

PyTorch中,我们可以使用torch.save()函数将张量数据保存到pt文件中。

下面是一个简单的示例:

import torch
# 创建一个张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 将张量保存到pt文件中
torch.save(tensor, 'tensor.pt')

在上面的代码中,我们首先创建了一个二维张量tensor,然后使用torch.save()函数将其保存到名为tensor.pt的文件中。保存的文件将包含张量的数据和元数据,以便在加载时能够准确地恢复张量的结构和内容。

除了保存单个张量外,我们还可以保存多个张量到一个pt文件中。这可以通过将多个张量放入一个字典或列表中,然后将整个字典或列表保存到文件中实现。

例如:

# 创建多个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([[4, 5], [6, 7]])# 将张量放入字典中
tensors_dict = {'tensor1': tensor1, 'tensor2': tensor2}# 将字典保存到pt文件中
torch.save(tensors_dict, 'tensors_dict.pt')

如何加载pt格式文件

加载pt文件同样使用torch.load()函数。

下面是一个加载pt文件的示例:

# 加载单个张量的pt文件
loaded_tensor = torch.load('tensor.pt')
print(loaded_tensor)# 加载包含多个张量的字典的pt文件
loaded_dict = torch.load('tensors_dict.pt')
print(loaded_dict['tensor1'])
print(loaded_dict['tensor2'])

在加载单个张量的pt文件时,我们直接调用torch.load()函数并传入文件名即可。加载得到的loaded_tensor将是一个与原始张量结构和内容相同的张量对象。
当加载包含多个张量的字典的pt文件时,我们同样使用torch.load()函数。加载得到的loaded_dict将是一个字典对象,其中包含了我们在保存时放入的所有张量。我们可以通过字典的键来访问这些张量。

强烈建议只保存模型参数,而非保存整个网络。PyTorch 官方也是这么建议的。

torch.save(net.state_dict(),path2)#只保留模型参数

(只保存模型参数)是官方推荐的方法,运行速度快,且占空间较小。需要注意的是 net.state_dict() 是将网络参数保存为字典形式(OrderedDict)load_state_dict() 加载的并不是网络参数的pth文件,而是字典。

pt文件保存神经网络

在评估时,记住一定要使用model.eval()来固定dropout和归一化层,否则每次推理会生成不同的结果。

import torch, glob, cv2
from torchvision import transforms
import numpy as np
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):  # 神经网络部分用你自己的def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 2, 1)  # nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)self.conv2 = nn.Conv2d(32, 64, 3, 2, 1)self.conv3 = nn.Conv2d(64, 128, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(6272, 128)  # 6272=128*7*7self.fc2 = nn.Linear(128, 8)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.conv2(x)x = F.relu(x)x = self.conv3(x)x = F.relu(x)x = F.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = F.relu(x)x = self.dropout2(x)x = self.fc2(x)self.output = F.log_softmax(x, dim=1)out1 = xreturn self.output,out1def predict_mine():model=Net()model.load_state_dict(torch.load("model.pt"))print(model)images=torch.rand((1,1,64,64))x=model(images)print(x)
def torch_script_save():model=Net()
if __name__ == '__main__':save_model()predict_mine()

在这里插入图片描述

可以看到,我们可以通过pt文件读取出来下面的信息:

在这里插入图片描述

同时,我们也看到,我们虽然可以使用pt文件保存模型结构,但我们在推理时,依旧需要我们能够生成Net对象才能加载其数据,这其实很不方便,那么,有什么办法可以真正的将模型结构保存进去,让我们在推理过程中不需要再定义相关的类与对象呢,先前博主所使用的ONNX便是其中的一种,但它其实是另一种文件结构了,pt文件真的就不能摆脱环境吗,答案是否定的,TorchScript模型便解决了这个问题。

TorchScript模型

事实上,PyTorch提供了两种主要的模型保存和加载机制,一种是基于Python的序列化,另一种是TorchScript

普通的PyTorch模型(基于Python的序列化):

  • 保存: 使用torch.save(model.state_dict(), 'model_path.pth'),它保存了模型的权重和参数,但不保存模型的结构。(当然也是可以保存的,但我们需要处理一下才能用,比如定义好Net类)
  • 加载: 首先,您需要有模型的类定义。 创建该类的一个实例。 使用model.load_state_dict(torch.load('model_path.pth'))来加载权重。
    特点:
  • 需要Python环境和模型的原始代码来加载和运行模型。 保存的文件是Python特定的,并且依赖于特定的类结构。 主要用于继续训练或在Python环境中进行推断。

TorchScript模型:

TorchScript是PyTorch的一个子集,它创建了一个可以独立于Python运行的序列化模型。 生成方法:

  • Tracing: 使用torch.jit.trace方法。这涉及到通过模型运行一个输入示例,从而跟踪模型的执行路径。
  • Scripting: 使用torch.jit.script方法。这转化Python代码到TorchScript,允许更复杂的模型和控制流。
  • 保存: 使用torch.jit.save(traced_model, 'model_path.pt')
  • 加载: 使用torch.jit.load(‘model_path.pt’)。注意,加载不需要原始的模型类定义。

特点:

  • 可以在没有Python运行时的环境中运行,如C++
  • 提供了一种方法,将模型从Python转移到其他平台或部署环境。
  • 包含模型的完整定义,包括结构、权重和参数。

Tracing方法:

example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
torch.jit.save(traced_model, 'traced_model.pt')

Scripting方法:

scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'scripted_model.pt')

加载模型:

loaded_model = torch.jit.load('model_path.pt')

例程:

import torch
import torch.nn as nnclass SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 10)def forward(self, x):return self.fc(x)model = SimpleModel()# Tracing
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
torch.jit.save(traced_model, 'traced_simple_model.pt')# Scripting
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'scripted_simple_model.pt')# 加载模型
loaded_model = torch.jit.load('traced_simple_model.pt')

我们采用TorchScript结构去执行先前的Net

def torch_script_save():model=Net()example_input =torch.rand((1,1,64,64))traced_model = torch.jit.trace(model, example_input)torch.jit.save(traced_model, 'traced_simple_model.pt')# Scriptingscripted_model = torch.jit.script(model)torch.jit.save(scripted_model, 'scripted_simple_model.pt')
def predict_script():model1=torch.jit.load("traced_simple_model.pt")image =torch.rand((1,1,64,64))print(model1)model1.eval()x=model1(image)print(x)model2=torch.jit.load("scripted_simple_model.pt")image =torch.rand((1,1,64,64))print(model2)model2.eval()x=model2(image)print(x)

在这里插入图片描述

yaml文件内容如下:

{'nc': 1000, 
'scales': {'n': [0.33, 0.25, 1024], 's': [0.33, 0.5, 1024], 'm': [0.67, 0.75, 1024], 'l': [1.0, 1.0, 1024], 'x': [1.0, 1.25, 1024]}, 
'backbone': [[-1, 1, 'Conv', [64, 3, 2]], [-1, 1, 'Conv', [128, 3, 2]], [-1, 3, 'C2f', [128, True]], [-1, 1, 'Conv', [256, 3, 2]], [-1, 6, 'C2f', [256, True]], [-1, 1, 'Conv', [512, 3, 2]], [-1, 6, 'C2f', [512, True]], [-1, 1, 'Conv', [1024, 3, 2]], [-1, 3, 'C2f', [1024, True]]],'head': [[-1, 1, 'Classify', ['nc']]], 'scale': 'n','yaml_file': 'yolov8n-cls.yaml', 'ch': 3}

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • MongoDB 未授权访问漏洞
  • c# 逻辑运算符和条件运算符
  • Spring Boot 参数校验 Validation 使用
  • 反其道而行的SAP商业AI
  • Linux搭建SVN服务器
  • 无法启动此程序,因为计算机中丢失dll的多种解决方法,3分钟修复(dll修复工具详细教程)
  • react中的装饰器
  • FPGA开发——在Quartus中实现对IP核的PLL调用
  • ⌈ 传知代码 ⌋ 基于矩阵乘积态的生成模型
  • HarmonyOS笔记3:从网络数据接口API获取数据
  • 人工智能深度学习系列—深入解析:均方误差损失(MSE Loss)在深度学习中的应用与实践
  • ELK对业务日志进行收集
  • NodeJS 依赖下载及切换下载源
  • 29.Labview界面设计(下篇) --- 自定义控件库、界面布局与外观设计
  • (MTK)java文件添加简单接口并配置相应的SELinux avc 权限笔记2
  • 【Leetcode】104. 二叉树的最大深度
  • Android开发 - 掌握ConstraintLayout(四)创建基本约束
  • canvas 绘制双线技巧
  • ES10 特性的完整指南
  • Java 9 被无情抛弃,Java 8 直接升级到 Java 10!!
  • Java精华积累:初学者都应该搞懂的问题
  • Logstash 参考指南(目录)
  • Netty 框架总结「ChannelHandler 及 EventLoop」
  • NSTimer学习笔记
  • quasar-framework cnodejs社区
  • REST架构的思考
  • SOFAMosn配置模型
  • 蓝海存储开关机注意事项总结
  • 手机app有了短信验证码还有没必要有图片验证码?
  • 算法---两个栈实现一个队列
  • 微服务框架lagom
  • 异常机制详解
  • 译自由幺半群
  • ‌前端列表展示1000条大量数据时,后端通常需要进行一定的处理。‌
  • #LLM入门|Prompt#2.3_对查询任务进行分类|意图分析_Classification
  • #systemverilog# 之 event region 和 timeslot 仿真调度(十)高层次视角看仿真调度事件的发生
  • (android 地图实战开发)3 在地图上显示当前位置和自定义银行位置
  • (void) (_x == _y)的作用
  • (层次遍历)104. 二叉树的最大深度
  • (接上一篇)前端弄一个变量实现点击次数在前端页面实时更新
  • (轉)JSON.stringify 语法实例讲解
  • .form文件_SSM框架文件上传篇
  • .net core webapi 大文件上传到wwwroot文件夹
  • .net framework4与其client profile版本的区别
  • .NET 快速重构概要1
  • .net 无限分类
  • .NET/C# 中设置当发生某个特定异常时进入断点(不借助 Visual Studio 的纯代码实现)
  • :如何用SQL脚本保存存储过程返回的结果集
  • @GetMapping和@RequestMapping的区别
  • [2016.7 day.5] T2
  • [android] 切换界面的通用处理
  • [ARM]ldr 和 adr 伪指令的区别
  • [BT]小迪安全2023学习笔记(第15天:PHP开发-登录验证)
  • [C++]运行时,如何确保一个对象是只读的
  • [Contiki系列论文之2]WSN的自适应通信架构