当前位置: 首页 > news >正文

实验记录 | 点云处理 | K-NN算法3种实现的性能比较

引言

K近邻(K-Nearest Neighbors, KNN)算法作为一种经典的无监督学习算法,在点云处理中的应用尤为广泛。它通过计算点与点之间的距离来寻找数据点的邻居,从而有效进行点云分类、聚类和特征提取。本菜在复现点云文章过程,遇到了三种 KNN 的实现方式,故在此一并对比总结,最后对三种实现方案进行了性能比较

在本文中,我将K近邻(KNN)算法的应用分为两种情况:

  • 全局查询:对整个点云的所有 N 个点进行查询,找到每个点的 K 个最近邻点,最终返回的结果维度为 [B, N, K],B 表示批次大小,N 表示点的总数量,K 表示每个点的邻近点数量。

  • 局部查询:针对已知的 S 个查询点,在整个点云的 N 个点中寻找每个查询点的 K 个最近邻点,最终返回的结果维度为 [B, S, K],其中 S 表示查询点的数量。


全局查询

def knn(x, k):"""Input:x: all points, [B, C, N]k: k nearest points of each pointReturn:idx: grouped points index, [B, N, k]"""inner = -2*torch.matmul(x.transpose(2, 1), x)xx = torch.sum(x**2, dim=1, keepdim=True)pairwise_distance = -xx - inner - xx.transpose(2, 1)idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)return idx

这段代码来源于点云网络的高引之作《Dynamic Graph CNN for Learning on Point Clouds》,实现了一个 KNN(K近邻)查询,目的是计算点云中每个点的 k 个最近邻点的索引。

函数清晰易懂,便不赘述。我一直以为点云学习是需要先采样,再用采样得到的中心点进行 KNN 邻域查询,直到看到这篇 DGCNN 的方法,才打破了我的固有认知:DGCNN没有下采样过程,直接使用 N 个点进行近邻查询和特征更新。

插个题外话,这篇文章真的值得一读,简单高效!不愧是高引之作。


局部查询

(1)knn_point 函数

def square_distance(src, dst):"""Calculate Euclid distance between each two points.src^T * dst = xn * xm + yn * ym + zn * zm;sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dstInput:src: source points, [B, N, C]dst: target points, [B, M, C]Output:dist: per-point square distance, [B, N, M]"""B, N, _ = src.shape_, M, _ = dst.shapedist = -2 * torch.matmul(src, dst.permute(0, 2, 1))dist += torch.sum(src ** 2, -1).view(B, N, 1)dist += torch.sum(dst ** 2, -1).view(B, 1, M)return distdef knn_point(nsample, xyz, new_xyz):"""Input:nsample: max sample number in local regionxyz: all points, [B, N, C]new_xyz: query points, [B, S, C]Return:group_idx: grouped points index, [B, S, nsample]"""sqrdists = square_distance(new_xyz, xyz)_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)return group_idx

这段代码来源于另一个高引之作《Rethinking Network Design and Local Geometry in Point Cloud: A Simple Residual MLP Framework》,代码也是相当眉清目秀,不再赘述。其实这份代码的实现还是比较经典的,很多的模型代码都可以看到它的身影。


(2)knn_cuda 库函数

import torch# Make sure your CUDA is available.
assert torch.cuda.is_available()from knn_cuda import KNN
"""
if transpose_mode is True, ref   is Tensor [bs x nr x dim]query is Tensor [bs x nq x dim]return dist is Tensor [bs x nq x k]indx is Tensor [bs x nq x k]
elseref   is Tensor [bs x dim x nr]query is Tensor [bs x dim x nq]return dist is Tensor [bs x k x nq]indx is Tensor [bs x k x nq]
"""knn = KNN(k=10, transpose_mode=True)ref = torch.rand(32, 1000, 5).cuda()
query = torch.rand(32, 50, 5).cuda()dist, indx = knn(ref, query)  # 32 x 50 x 10

大佬把 KNN 封装为了库函数,来源于 KNN_CUDA 此仓库,可以参考 readme 进行安装。库函数的调用也非常方便。

需要强调的是,这里提到的 knn_point 和 knn_cuda 虽然算局部查询,但其实只要将局部查询点云 [B, S, Dim] 换成全局点云 [B, N, Dim] 作为输入,也就是全局查询了


性能比较

(1)测试代码

import torch
import time
from knn_cuda import KNNdef knn(x, k):inner = -2*torch.matmul(x.transpose(2, 1), x)xx = torch.sum(x**2, dim=1, keepdim=True)pairwise_distance = -xx - inner - xx.transpose(2, 1)idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)return idxdef square_distance(src, dst):B, N, _ = src.shape_, M, _ = dst.shapedist = -2 * torch.matmul(src, dst.permute(0, 2, 1))dist += torch.sum(src ** 2, -1).view(B, N, 1)dist += torch.sum(dst ** 2, -1).view(B, 1, M)return distdef knn_point(nsample, xyz, new_xyz):sqrdists = square_distance(new_xyz, xyz)_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)return group_idx# Custom knn implementation
def test_knn(query, k, times):query = query.permute(0,2,1)start_time = time.time()  # Start timerfor i in range(times):indx = knn(query, k = k)end_time = time.time()  # End timerreturn end_time - start_time  # Return elapsed time# Custom knn_point implementation
def test_knn_point(ref, query, k, times):start_time = time.time()  # Start timerfor i in range(times):indx = knn_point(k, ref, query)end_time = time.time()  # End timerreturn end_time - start_time  # Return elapsed time# knn_cuda implementation
def test_knn_cuda(ref, query, k, times):knn = KNN(k=k, transpose_mode=True)start_time = time.time()  # Start timerfor i in range(times):dist, indx = knn(ref, query)end_time = time.time()  # End timerreturn end_time - start_time  # Return elapsed time# Main testing function
def test_knn_methods(ref, query, k, times):print("Test times: %d" % times)# Test custom knntime_knn = test_knn(query, k, times)print(f"knn      : {time_knn:.6f} seconds")# Test custom knn_pointtime_point = test_knn_point(ref, query, k, times)print(f"knn_point: {time_point:.6f} seconds")# Test knn_cudatime_cuda = test_knn_cuda(ref, query, k, times)print(f"knn_cuda : {time_cuda:.6f} seconds")if __name__ == '__main__':# Sample inputB, N, S, C = 32, 1024, 50, 3      # Batch size, total points, query points, coordinatesk = 24                            # Number of nearest neighborsref = torch.randn(B, N, C).cuda() # Reference points# Test above methodstimes_list = [1,2,3,10,50,100]for times in times_list:test_knn_methods(ref, ref, k, times)

这段代码测试了三种 K 近邻(KNN)算法的实现效率,分别是自定义的 knnknn_point 以及基于 knn_cuda 库的实现。分别对每种方法运行多次,记录每种方法在不同重复次数(如 1、2、3、10、50、100 次)的运行时间,最终输出各方法的执行时间。

图注:三种实现方法的性能测评结果

上图展示了测试代码的结果,可以看到 knn_cuda 的实现方式表现最差的(我也表示非常不理解);knn 和 knn_point 性能表现相当。或许这也是为什么很多较新的模型使用的也是 knn_point,而不是 knn_cuda。

当然,这份测试代码实际是在一个小规模数据的单卡上进行的,或许无法很好地展现出他们在实际训练的性能,因此我又分别将他们部署在 DGCNN 模型上进行训练,对比性能。


(2)模型训练

图注:使用 knn 函数的训练时间
图注:使用 knn 函数的训练时间

图注:使用 knn_point 的训练时间

图注:使用 knn_cuda 库的训练时间

 

直接将他们部署在模型的训练中,能够最真实反映出他们的性能。这次实验,Batchsize 设置为了32,epoch 设置为256,选择前2个epoch观察。从训练状态可以看到,红色框选区域表示训练和测试的时间,knn_cuda 依然稳定发挥,表现最差哈哈哈哈,knn 和 knn_point 的函数实现表现相当。


总结

我原以为 knn_cuda 会很厉害,毕竟是直接封装起来了,但实际表现不尽人意。看似很小的性能差异,放在规模较大的数据集上,训练成本可是指数级倍增的。所以,还是尽可能使用 knn 和 knn_point 来实现全局/局部的邻近查询。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Android11 MTK 安装apk时进行密码验证
  • 在Unity环境中使用UTF-8编码
  • SQL COUNT() 函数深入解析
  • MapSet之二叉搜索树
  • InfiniBand (IB) 和 RDMA over Converged Ethernet (RoCE)
  • ARM基础知识---CPU---处理器
  • QT Creator在线安装包、离线包下载链接
  • Java并发:互斥锁,读写锁,Condition,StampedLock
  • 在Spring Boot中通过自定义注解、反射以及AOP(面向切面编程)
  • vite+vue3+typescript+elementPlus前端实现电子证书查询系统
  • RabbitMQ 基础架构流程 数据隔离 创建用户
  • Java高级Day38-网络编程作业
  • 如何打造高校实验室教学管理系统?Java SpringBoot助力,MySQL存储优化,2025届必备设计指南
  • 【Linux】Linux 管道:进程间通信的利器
  • 【微信小程序】搭建项目步骤 + 引入Tdesign UI
  • C++类中的特殊成员函数
  • Druid 在有赞的实践
  • Idea+maven+scala构建包并在spark on yarn 运行
  • JavaScript学习总结——原型
  • Java编程基础24——递归练习
  • js面向对象
  • js正则,这点儿就够用了
  • Linux下的乱码问题
  • node学习系列之简单文件上传
  • PV统计优化设计
  • Spring Cloud中负载均衡器概览
  • -- 查询加强-- 使用如何where子句进行筛选,% _ like的使用
  • 浅谈Kotlin实战篇之自定义View图片圆角简单应用(一)
  • 微信开放平台全网发布【失败】的几点排查方法
  • 小程序01:wepy框架整合iview webapp UI
  • 责任链模式的两种实现
  • media数据库操作,可以进行增删改查,实现回收站,隐私照片功能 SharedPreferences存储地址:
  • scrapy中间件源码分析及常用中间件大全
  • 阿里云移动端播放器高级功能介绍
  • 关于Android全面屏虚拟导航栏的适配总结
  • 正则表达式-基础知识Review
  • # C++之functional库用法整理
  • (6)添加vue-cookie
  • (ibm)Java 语言的 XPath API
  • (web自动化测试+python)1
  • (二)fiber的基本认识
  • (二)斐波那契Fabonacci函数
  • (附源码)基于SpringBoot和Vue的厨到家服务平台的设计与实现 毕业设计 063133
  • (回溯) LeetCode 78. 子集
  • (九)c52学习之旅-定时器
  • (论文阅读笔记)Network planning with deep reinforcement learning
  • (十)Flink Table API 和 SQL 基本概念
  • (一)Dubbo快速入门、介绍、使用
  • (转)3D模板阴影原理
  • (转)EOS中账户、钱包和密钥的关系
  • (转载)微软数据挖掘算法:Microsoft 时序算法(5)
  • .net core 控制台应用程序读取配置文件app.config
  • .Net 代码性能 - (1)
  • .net6 webapi log4net完整配置使用流程
  • .sdf和.msp文件读取