当前位置: 首页 > news >正文

PyTorch计算机视觉入门:测试模型与评估,对单帧图片进行推理

在完成模型的训练之后,对模型进行测试与评估是至关重要的一步,它能帮助我们理解模型在未知数据上的泛化能力。本篇指南将带您了解如何使用PyTorch进行模型测试,并对测试结果进行分析。我们将基于之前训练好的模型,演示如何加载数据、进行预测、计算指标以及可视化结果。

准备工作

假设您已经有一个训练好的模型,保存在.pth文件中,以及一个用于测试的自定义数据集。我们将继续使用前文提到的自定义数据集CustomDataset类,并引入一些新的概念和代码。

加载测试数据集

与训练过程类似,首先需要加载测试数据集,并对其进行适当的预处理。确保您的测试集遵循与训练集相同的数据结构和预处理步骤。

test_dataset = CustomImageDataset(data_path="./data/", model= "test", transform = transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

测试模型

训练完成后,使用测试数据集来评估模型的性能

def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target)# max 函数返回两个值,一个是是数值,一个是indexpred = output.max(1, keepdim=True)[1] # 找到概率最大的下标 correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
Test set: Average loss: 0.0018, Accuracy: 9671/10000 (97%)

单帧图片进行测试

# test single image 
img = Image.open("./data/data_test/1.jpg")
img_t = transform(img)
img_t = img_t.unsqueeze(0)  # 变为[1, 1, 28, 28]
img_t = img_t.to(device)
model.eval()
output = model(img_t)
_, predicted_class = torch.max(output, 1)
print(predicted_class)
tensor([2], device='cuda:0')

通过以上步骤,我们可以全面地评估和分析PyTorch模型在计算机视觉任务中的表现,从而确保模型在实际应用中的有效性和可靠性。

关注我的公众号Ai fighting, 第一时间获取更新内容。

相关文章:

  • 【SpringBoot】SpringBoot:构建实时聊天应用
  • Java数据结构与算法(完全背包)
  • Qt 实战(4)信号与槽 | 4.3、信号连接信号
  • 0118__C语言——float.h文件
  • 使用Spyder进行Python编程和代码调试
  • Qt项目天气预报(1) - ui界面搭建
  • 集成学习方法:Bagging与Boosting的应用与优势
  • C++中的结构体——结构体中const的使用场景
  • express入门03增删改查
  • Java 代理模式
  • C语言---------深入理解指针
  • react:handleEdit={() => handleEdit(user)} 和 handleEdit={handleEdit(user)}有啥区别
  • MFC socket编程-服务端和客户端流程
  • Vue43-单文件组件
  • 22.1 正则表达式-定义正则表达式、正则语法
  • @jsonView过滤属性
  • JavaScript-Array类型
  • js对象的深浅拷贝
  • mongodb--安装和初步使用教程
  • Next.js之基础概念(二)
  • php面试题 汇集2
  • Redash本地开发环境搭建
  • Webpack入门之遇到的那些坑,系列示例Demo
  • 算法---两个栈实现一个队列
  • ​​​【收录 Hello 算法】10.4 哈希优化策略
  • ​LeetCode解法汇总2583. 二叉树中的第 K 大层和
  • ###51单片机学习(2)-----如何通过C语言运用延时函数设计LED流水灯
  • #Linux(权限管理)
  • #大学#套接字
  • #多叉树深度遍历_结合深度学习的视频编码方法--帧内预测
  • (1)常见O(n^2)排序算法解析
  • (Oracle)SQL优化技巧(一):分页查询
  • (ResultSet.TYPE_SCROLL_INSENSITIVE,ResultSet.CONCUR_READ_ONLY)讲解
  • (算法)求1到1亿间的质数或素数
  • (一)使用Mybatis实现在student数据库中插入一个学生信息
  • .chm格式文件如何阅读
  • .NET CORE Aws S3 使用
  • .NET Framework杂记
  • .net on S60 ---- Net60 1.1发布 支持VS2008以及新的特性
  • .NET 设计模式—简单工厂(Simple Factory Pattern)
  • .NET 中让 Task 支持带超时的异步等待
  • .net 逐行读取大文本文件_如何使用 Java 灵活读取 Excel 内容 ?
  • .NET大文件上传知识整理
  • .net实现客户区延伸至至非客户区
  • .net实现头像缩放截取功能 -----转载自accp教程网
  • /etc/fstab和/etc/mtab的区别
  • @ResponseBody
  • @Transactional事务注解内含乾坤?
  • [2009][note]构成理想导体超材料的有源THz欺骗表面等离子激元开关——
  • [20170713] 无法访问SQL Server
  • [4.9福建四校联考]
  • [Android] Binder 里的 Service 和 Interface 分别是什么
  • [AutoSAR系列] 1.3 AutoSar 架构
  • [C#]winform部署yolov9的onnx模型
  • [C++参考]拷贝构造函数的参数必须是引用类型