李沐d2l(十一)--目标检测
文章目录
- 一、概念
- 二、代码
- 1 获取待检测图片
- 2 定义两种边缘框表示之间的转换函数
- 3 定义图像中狗和猫的边界框
- 4 画出边界框
- 三、目标检测数据集
- 1 下载数据集
- 2 读取数据集
- 3 自定义Dataset实例
- 4 返回训练集合测试集加载器实例
- 5 读取一个小批量
- 6 展示检测结果
一、概念
图片分类是给出指定图片,判断里面的主体类别,往往一张图片只有一个主体。而目标检测是要把含有多个主体的图片,识别出用户感兴趣的主体,一般通过方框(边缘)标注。
边缘框
一个边缘框可以通过4个数字定义
- 左上x,左上y,右下x,右下y
- 左上x,左上y,宽,高
二、代码
1 获取待检测图片
d2l.set_figsize()
img = d2l.plt.imread("catdog.jpg")
d2l.plt.imshow(img)
d2l.plt.show()
2 定义两种边缘框表示之间的转换函数
def box_corner_to_center(boxes):
"""从(左上,右下)转换到(中间,宽度,高度)"""
x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
boxes = torch.stack((cx, cy, w, h), axis=-1)
return boxes
def box_center_to_corner(boxes):
"""从(中间,宽度,高度)转换到(左上,右下)"""
cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
x1 = cx - 0.5 * w
y1 = cy - 0.5 * h
x2 = cx + 0.5 * w
y2 = cy + 0.5 * h
boxes = torch.stack((x1, y1, x2, y2), axis=-1)
return boxes
3 定义图像中狗和猫的边界框
dog_bbox, cat_bbox = [60.0, 45.0, 378.0, 516.0], [400.0, 112.0, 655.0, 493.0]
boxes = torch.tensor((dog_bbox, cat_bbox))
4 画出边界框
def bbox_to_rect(bbox, color):
return d2l.plt.Rectangle(xy=(bbox[0], bbox[1]), width=bbox[2] - bbox[0],
height=bbox[3] - bbox[1], fill=False,
edgecolor=color, linewidth=2)
fig = d2l.plt.imshow(img)
fig.axes.add_patch(bbox_to_rect(dog_bbox, 'blue'))
fig.axes.add_patch(bbox_to_rect(cat_bbox, 'red'))
d2l.plt.show()
三、目标检测数据集
1 下载数据集
import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
d2l.DATA_HUB['banana-detection'] = (
d2l.DATA_URL + 'banana-detection.zip',
'5de26c8fce5ccdea9f91267273464dc968d20d72')
2 读取数据集
def read_data_bananas(is_train=True):
"""读取香蕉检测数据集中的图像和标签。"""
data_dir = d2l.download_extract('banana-detection')
csv_fname = os.path.join(data_dir,
'bananas_train' if is_train else 'bananas_val',
'label.csv')
csv_data = pd.read_csv(csv_fname)
csv_data = csv_data.set_index('img_name')
images, targets = [], []
for img_name, target in csv_data.iterrows():
images.append(
torchvision.io.read_image(
os.path.join(data_dir,
'bananas_train' if is_train else 'bananas_val',
'images', f'{img_name}')))
targets.append(list(target))
return images, torch.tensor(targets).unsqueeze(1) / 256
3 自定义Dataset实例
class BananasDataset(torch.utils.data.Dataset):
"""一个用于加载香蕉检测数据集的自定义数据集。"""
def __init__(self, is_train):
self.features, self.labels = read_data_bananas(is_train)
print('read ' + str(len(self.features)) + (
f' training examples' if is_train else f' validation examples'))
def __getitem__(self, idx):
return (self.features[idx].float(), self.labels[idx])
def __len__(self):
return len(self.features)
4 返回训练集合测试集加载器实例
def load_data_bananas(batch_size):
"""加载香蕉检测数据集。"""
train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
batch_size, shuffle=True)
val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
batch_size)
return train_iter, val_iter
5 读取一个小批量
batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))
6 展示检测结果
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][0:10]):
d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
d2l.plt.show()