当前位置: 首页 > news >正文

基于安卓的虫害识别软件设计--(2)模型性能可视化|混淆矩阵、热力图

1.混淆矩阵(Confusion Matrix)

1.1基础理论

(1)在机器学习、深度学习领域中,混淆矩阵常用于监督学习,匹配矩阵常用于无监督学习。主要用来比较分类结果和实际预测值。

(2)图中表达的含义:混淆矩阵的每一列代表了预测类别,每一行代表了数据的真实类别。

1.2 实现代码

import torch
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
from torchvision import transformsclasses = ['bai_xing_hua_jin_gui', 'beetle', 'chui_mian_jie', 'ci_e_ke', 'da_qing_ye_chan','dou_yuan_jing','fan_qie_qian_ye_ying_larva','fan_qie_qian_ye_ying_mature','hong_zhi_zhu','huang_zong_ke']# classes = ['白星化金龟', '甲虫', '吹绵蚧', '刺蛾科', '大青叶蝉','豆芫菁','番茄潜叶蛾幼虫','番茄潜叶蛾成虫','红蜘蛛','蝗总科']def predict_image(model, image_path, true_label):device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")img = Image.open(image_path)val_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])tensor_img = val_transform(img)tensor_img = tensor_img.to(device)tensor_img = tensor_img.unsqueeze(0)output = model(tensor_img)_, pred = output.max(1)pred_label = classes[pred.item()]return pred_label, true_labelif __name__ == '__main__':# 1. 加载模型model_path = r"/kaggle/input/mymodel3/resnet101_final.pth"model = torch.load(model_path)model.eval()# 2. 预测多张图片并记录真实标签和预测结果true_labels = []pred_labels = []images_dir = r"/kaggle/input/insects-new/validation"for label in os.listdir(images_dir):label_dir = os.path.join(images_dir, label)if not os.path.isdir(label_dir):continuefor img_name in os.listdir(label_dir):img_path = os.path.join(label_dir, img_name)true_labels.append(label)pred_label, _ = predict_image(model, img_path, label)pred_labels.append(pred_label)# 3. 计算混淆矩阵cm = confusion_matrix(true_labels, pred_labels, labels=classes)# 4. 计算归一化的混淆矩阵cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]# 5. 绘制混淆矩阵save_path = "/kaggle/working/confusion_matrix.png"plt.figure(figsize=(8, 6))sns.heatmap(cm_normalized, annot=True, cmap='Blues', xticklabels=classes, yticklabels=classes, fmt='.2f')plt.xlabel('预测标签')plt.ylabel('真实标签')plt.tight_layout()  # 自动调整子图参数plt.savefig(save_path)plt.show()

注意:以下数值需要和训练时的数值一样!


2.热力图

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
from torchcam.methods import GradCAMpp
# CAM GradCAM GradCAMpp ISCAM LayerCAM SSCAM ScoreCAM SmoothGradCAMpp XGradCAM
from torchvision import transforms
from torchcam.utils import overlay_mask# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)model = torch.load('/kaggle/input/mymodel3/resnet101_final.pth')
model = model.eval().to(device)cam_extractor = GradCAMpp(model)# 要与训练集保持一致
test_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomGrayscale(),transforms.ToTensor(),transforms.RandomErasing(),transforms.Normalize([0.460, 0.483, 0.396], [0.171, 0.150, 0.190])])# 载入目标图像
img_path = '/kaggle/input/insects-new/train/hong_zhi_zhu/13845.jpg'
img_pil = Image.open(img_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
# 预测标签
pred_logits = model(input_tensor)
pred_id = torch.topk(pred_logits, 1)[1].detach().cpu().numpy().squeeze().item()activation_map = cam_extractor(pred_id, pred_logits)
activation_map = activation_map[0][0].detach().cpu().numpy()
# 矩阵热力图
plt.imshow(activation_map)
plt.show()
plt.savefig('/kaggle/working/activation_map.png')# 将原图重合
result = overlay_mask(img_pil, Image.fromarray(activation_map), alpha=0.7)
result.save('/kaggle/working/result.png')result

 

相关文章:

  • 【制作100个unity游戏之27】使用unity复刻经典游戏《植物大战僵尸》,制作属于自己的植物大战僵尸随机版和杂交版6(附带项目源码)
  • x264 参考帧管理原理:b_ref_reorder 数组变量
  • Vue:路由管理vue-router
  • 信息标记形式 (XML, JSON, YAML)
  • DeepFace ——用于高级人脸识别算法探索与应用
  • 【Python】Python异步编程
  • FFmpeg 中 Filters 使用文档介绍
  • 纯网络的系统能否定级备案?
  • 易基因:RNA免疫共沉淀测序 (RIP-seq) 技术介绍
  • 【Java数据结构】详解Stack与Queue(二)
  • MTK 平台项目security boot 开启/关闭 及 系统签名流程
  • autowired注解底层实现代码
  • Ant Design Vue Pro流程分析记录
  • JMeter源码解析之SplashScreen.java
  • [每日一题]170:分糖果 II
  • 实现windows 窗体的自己画,网上摘抄的,学习了
  • [译] 怎样写一个基础的编译器
  • Android系统模拟器绘制实现概述
  • Java程序员幽默爆笑锦集
  • JS学习笔记——闭包
  • nginx 负载服务器优化
  • node学习系列之简单文件上传
  • Spark学习笔记之相关记录
  • Vim Clutch | 面向脚踏板编程……
  • 阿里云容器服务区块链解决方案全新升级 支持Hyperledger Fabric v1.1
  • 浮现式设计
  • 使用 QuickBI 搭建酷炫可视化分析
  • 使用 Xcode 的 Target 区分开发和生产环境
  • 想使用 MongoDB ,你应该了解这8个方面!
  • 用Node EJS写一个爬虫脚本每天定时给心爱的她发一封暖心邮件
  • 在 Chrome DevTools 中调试 JavaScript 入门
  • #我与Java虚拟机的故事#连载12:一本书带我深入Java领域
  • $(function(){})与(function($){....})(jQuery)的区别
  • (2024,Vision-LSTM,ViL,xLSTM,ViT,ViM,双向扫描)xLSTM 作为通用视觉骨干
  • (Python第六天)文件处理
  • (二刷)代码随想录第15天|层序遍历 226.翻转二叉树 101.对称二叉树2
  • (过滤器)Filter和(监听器)listener
  • (接口封装)
  • (三)c52学习之旅-点亮LED灯
  • (贪心 + 双指针) LeetCode 455. 分发饼干
  • (一) 初入MySQL 【认识和部署】
  • (转)微软牛津计划介绍——屌爆了的自然数据处理解决方案(人脸/语音识别,计算机视觉与语言理解)...
  • (转)为C# Windows服务添加安装程序
  • (轉)JSON.stringify 语法实例讲解
  • .NET CORE Aws S3 使用
  • .NET Micro Framework 4.2 beta 源码探析
  • .NET 简介:跨平台、开源、高性能的开发平台
  • .NET 中的轻量级线程安全
  • .Net6支持的操作系统版本(.net8已来,你还在用.netframework4.5吗)
  • .NET开源快速、强大、免费的电子表格组件
  • .xml 下拉列表_RecyclerView嵌套recyclerview实现二级下拉列表,包含自定义IOS对话框...
  • @ 代码随想录算法训练营第8周(C语言)|Day53(动态规划)
  • @KafkaListener注解详解(一)| 常用参数详解
  • @vue/cli 3.x+引入jQuery
  • []AT 指令 收发短信和GPRS上网 SIM508/548