当前位置: 首页 > news >正文

李沐d2l(十一)--目标检测

文章目录

    • 一、概念
    • 二、代码
      • 1 获取待检测图片
      • 2 定义两种边缘框表示之间的转换函数
      • 3 定义图像中狗和猫的边界框
      • 4 画出边界框
    • 三、目标检测数据集
      • 1 下载数据集
      • 2 读取数据集
      • 3 自定义Dataset实例
      • 4 返回训练集合测试集加载器实例
      • 5 读取一个小批量
      • 6 展示检测结果

一、概念

图片分类是给出指定图片,判断里面的主体类别,往往一张图片只有一个主体。而目标检测是要把含有多个主体的图片,识别出用户感兴趣的主体,一般通过方框(边缘)标注。

边缘框

一个边缘框可以通过4个数字定义

  1. 左上x,左上y,右下x,右下y
  2. 左上x,左上y,宽,高

二、代码

1 获取待检测图片

d2l.set_figsize()
img = d2l.plt.imread("catdog.jpg")
d2l.plt.imshow(img)
d2l.plt.show()

image-20220902084107627

2 定义两种边缘框表示之间的转换函数

def box_corner_to_center(boxes):
    """从(左上,右下)转换到(中间,宽度,高度)"""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    boxes = torch.stack((cx, cy, w, h), axis=-1)
    return boxes

def box_center_to_corner(boxes):
    """从(中间,宽度,高度)转换到(左上,右下)"""
    cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    boxes = torch.stack((x1, y1, x2, y2), axis=-1)
    return boxes

3 定义图像中狗和猫的边界框

dog_bbox, cat_bbox = [60.0, 45.0, 378.0, 516.0], [400.0, 112.0, 655.0, 493.0]

boxes = torch.tensor((dog_bbox, cat_bbox))

4 画出边界框

def bbox_to_rect(bbox, color):
    return d2l.plt.Rectangle(xy=(bbox[0], bbox[1]), width=bbox[2] - bbox[0],
                             height=bbox[3] - bbox[1], fill=False,
                             edgecolor=color, linewidth=2)

fig = d2l.plt.imshow(img)
fig.axes.add_patch(bbox_to_rect(dog_bbox, 'blue'))
fig.axes.add_patch(bbox_to_rect(cat_bbox, 'red'))
d2l.plt.show()

image-20220902085013350

三、目标检测数据集

1 下载数据集

import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

d2l.DATA_HUB['banana-detection'] = (
    d2l.DATA_URL + 'banana-detection.zip',
    '5de26c8fce5ccdea9f91267273464dc968d20d72')

2 读取数据集

def read_data_bananas(is_train=True):
    """读取香蕉检测数据集中的图像和标签。"""
    data_dir = d2l.download_extract('banana-detection')
    csv_fname = os.path.join(data_dir,
                             'bananas_train' if is_train else 'bananas_val',
                             'label.csv')
    csv_data = pd.read_csv(csv_fname)
    csv_data = csv_data.set_index('img_name')
    images, targets = [], []
    for img_name, target in csv_data.iterrows():
        images.append(
            torchvision.io.read_image(
                os.path.join(data_dir,
                             'bananas_train' if is_train else 'bananas_val',
                             'images', f'{img_name}')))
        targets.append(list(target))
    return images, torch.tensor(targets).unsqueeze(1) / 256

3 自定义Dataset实例

class BananasDataset(torch.utils.data.Dataset):
    """一个用于加载香蕉检测数据集的自定义数据集。"""
    def __init__(self, is_train):
        self.features, self.labels = read_data_bananas(is_train)
        print('read ' + str(len(self.features)) + (
            f' training examples' if is_train else f' validation examples'))

    def __getitem__(self, idx):
        return (self.features[idx].float(), self.labels[idx])

    def __len__(self):
        return len(self.features)

4 返回训练集合测试集加载器实例

def load_data_bananas(batch_size):
    """加载香蕉检测数据集。"""
    train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
                                             batch_size, shuffle=True)
    val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
                                           batch_size)
    return train_iter, val_iter

5 读取一个小批量

batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))

6 展示检测结果

imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][0:10]):
    d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
d2l.plt.show()

image-20220902092345984

相关文章:

  • 美国上周初请人数23.2万人是两个月最低水平 美联储加息75基点稳了
  • 【技术美术知识储备】图形渲染管线1.0-基本概念CPU负责的应用阶段
  • 教你如何使用关键词获取淘宝和天猫的商品信息
  • lodash笔记(语言篇)
  • elasticsearch 6.3.2 集群配置
  • 操作系统复习:线程
  • Web3社交基础设施SBT
  • SAP-ABAP-SELECT语法SQL语法详解
  • 化妆品怎么在百度百科上创建词条,品牌上百度百科的条件和操作
  • DOM--事件
  • 这12款idea插件,能让你代码飞起来
  • 基于springboot+vue的新生宿舍管理系统 elementui
  • spring 事务的传播行为
  • 再写一遍的网络流
  • Linux之 如何查看文件是`硬链接`还是`软链接`
  • 自己简单写的 事件订阅机制
  • 【Under-the-hood-ReactJS-Part0】React源码解读
  • Centos6.8 使用rpm安装mysql5.7
  • classpath对获取配置文件的影响
  • git 常用命令
  • jquery cookie
  • Linux中的硬链接与软链接
  • ucore操作系统实验笔记 - 重新理解中断
  • 第13期 DApp 榜单 :来,吃我这波安利
  • 面试遇到的一些题
  • 算法系列——算法入门之递归分而治之思想的实现
  • 延迟脚本的方式
  • 硬币翻转问题,区间操作
  • SAP CRM里Lead通过工作流自动创建Opportunity的原理讲解 ...
  • 分布式关系型数据库服务 DRDS 支持显示的 Prepare 及逻辑库锁功能等多项能力 ...
  • ​Distil-Whisper:比Whisper快6倍,体积小50%的语音识别模型
  • # Java NIO(一)FileChannel
  • #AngularJS#$sce.trustAsResourceUrl
  • (1/2) 为了理解 UWP 的启动流程,我从零开始创建了一个 UWP 程序
  • (10)工业界推荐系统-小红书推荐场景及内部实践【排序模型的特征】
  • (delphi11最新学习资料) Object Pascal 学习笔记---第8章第2节(共同的基类)
  • (html5)在移动端input输入搜索项后 输入法下面为什么不想百度那样出现前往? 而我的出现的是换行...
  • (二)c52学习之旅-简单了解单片机
  • (附源码)计算机毕业设计SSM智慧停车系统
  • (一)基于IDEA的JAVA基础10
  • (转) 深度模型优化性能 调参
  • (转)EXC_BREAKPOINT僵尸错误
  • * 论文笔记 【Wide Deep Learning for Recommender Systems】
  • *1 计算机基础和操作系统基础及几大协议
  • . ./ bash dash source 这五种执行shell脚本方式 区别
  • .h头文件 .lib动态链接库文件 .dll 动态链接库
  • .NET Core6.0 MVC+layui+SqlSugar 简单增删改查
  • .net mvc 获取url中controller和action
  • .Net6 Api Swagger配置
  • .net访问oracle数据库性能问题
  • .vue文件怎么使用_我在项目中是这样配置Vue的
  • @entity 不限字节长度的类型_一文读懂Redis常见对象类型的底层数据结构
  • [APUE]进程关系(下)
  • [ChromeApp]指南!让你的谷歌浏览器好用十倍!
  • [Flex] PopUpButton系列 —— 控制弹出菜单的透明度、可用、可选择状态