当前位置: 首页 > news >正文

基于Python的机器学习系列(11):K-Nearest Neighbors

简介

        K-Nearest Neighbors (KNN) 算法是所有监督学习算法中最简单、直观的之一。其基本思想是通过计算新数据点到所有训练数据点的距离,找到距离最近的 K 个数据点(即 K 个邻居),然后根据这 K 个邻居的多数类别来决定新数据点的类别。

        例如,给定一个红色的交叉点 X,我们只需要获取其周围最近的邻居,并根据这些邻居的多数类别来为 X 分类。

算法实现步骤

  1. 准备数据

    • 获取训练集和测试集,将数据整理成适当的格式。
    • 数据标准化,以加速算法收敛。
    • 分割训练集和测试集。
  2. 计算点之间的距离

    编写函数来计算测试数据与所有训练数据之间的成对距离。
  3. 找到最近的邻居

    对距离进行排序,选取前 K 个最近的邻居。
  4. 根据多数类别进行分类

    统计最近 K 个邻居的类别,并选取出现频率最高的类别作为预测结果。

实践

1. 准备数据

        我们将从一个二维数据集开始,其中包含 4 个不同的类别。首先,我们将数据集分割成训练集和测试集,并对数据进行标准化处理。

from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler# 生成数据
X, y = make_blobs(n_samples=300, centers=4, random_state=0, cluster_std=1.0)# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)# 标准化数据
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

2. 计算点之间的距离

        接下来,我们将编写一个函数来计算测试集和训练集之间的成对距离。

def find_distance(X_train, X_test):dist = X_test[:, np.newaxis, :] - X_train[np.newaxis, :, :]sq_dist = dist ** 2summed_dist = sq_dist.sum(axis=2)return np.sqrt(summed_dist)

3. 找到最近的邻居

        我们将利用上一步计算的距离矩阵,找到每个测试数据点的 K 个最近邻居。

def find_neighbors(X_train, X_test, k=3):dist = find_distance(X_train, X_test)neighbors_ix = np.argsort(dist)[:, 0:k]return neighbors_ix

4. 根据多数类别进行分类

        最后,我们统计最近 K 个邻居的类别,并选取出现频率最高的类别作为预测结果。

def get_most_common(y):return np.bincount(y).argmax()def predict(X_train, X_test, y_train, k=3):neighbors_ix = find_neighbors(X_train, X_test, k)pred = np.zeros(X_test.shape[0])for ix, y in enumerate(y_train[neighbors_ix]):pred[ix] = get_most_common(y)return predyhat = predict(X_train, X_test, y_train, k=3)

5. 验证模型性能

        我们将使用准确率、平均精度得分和分类报告来评估 KNN 模型的性能。

from sklearn.metrics import average_precision_score, classification_report
from sklearn.preprocessing import label_binarizen_classes = len(np.unique(y_test))print("Accuracy: ", np.sum(yhat == y_test)/len(y_test))y_test_binarized = label_binarize(y_test, classes=[0, 1, 2, 3])
yhat_binarized = label_binarize(yhat, classes=[0, 1, 2, 3])for i in range(n_classes):class_score = average_precision_score(y_test_binarized[:, i], yhat_binarized[:, i])print(f"Class {i} score: ", class_score)print("Classification report: ")
print(classification_report(y_test, yhat))

使用Scikit-Learn实现KNN

        Scikit-Learn 提供了更为简便的实现方式,通过KNeighborsClassifier 可以轻松地创建和调参 KNN 模型。

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import StratifiedShuffleSplit, GridSearchCVmodel = KNeighborsClassifier()
param_grid = {"n_neighbors": np.arange(2, 10)}cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
grid = GridSearchCV(model, param_grid=param_grid, cv=cv, refit=True)
grid.fit(X_train, y_train)print(f"The best parameters are {grid.best_params_} with a score of {grid.best_score_:.2f}")yhat = grid.predict(X_test)print("Classification report: ")
print(classification_report(y_test, yhat))

何时使用KNN

        KNN 算法的实现相对简单,在一些简单的分类问题上表现良好。然而,它的劣势也十分明显:

  • 随着特征数量的增加,计算开销急剧上升。KNN 需要为每个输入点计算到所有其他点的距离,并排序,这在特征数多时代价极高。
  • 无法处理类别型特征,因为难以为类别型特征制定合适的距离公式。
  • 调整邻居数目(K)的过程耗时较长。

加深理解

  1. 比较 Naive Bayes、Logistic Regression 和 K-Nearest Neighbors 三种算法

    • Naive Bayes:假设特征之间相互独立,适用于高维数据和文本分类任务,计算速度快,但假设过于简单。
    • Logistic Regression:适用于线性可分数据,有明确的概率输出,适合特征数量适中的数据集。
    • K-Nearest Neighbors:直观简单,适合少量特征和数据分布规则的情况,但计算成本高,特征多时效果差。
  2. 欧几里得距离失效的情况

    当特征具有不同的量纲或尺度时,欧几里得距离可能会失效,因为它会受到大尺度特征的主导。此时需要进行特征标准化或使用其他距离度量。
  3. 分类平局时的处理方法

    若出现平局情况,可以选择:
    • 随机选择一个类别
    • 使用权重距离,即离得越近的点权重越大
    • 减少 K 的值
  4. K-Nearest Neighbors 分类计算

    给定测试数据,使用 K=3 计算其属于类 0 或类 1 的可能性,并给出预测类别。

结语

        K-Nearest Neighbors 是一种简单但功能强大的分类算法,特别适合初学者学习机器学习的基本概念。然而,它的计算复杂度较高,随着数据量和特征维度的增加,可能并不是最优选择。因此,在实际应用中,需要根据数据集的具体情况选择合适的算法。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Spark2.x 入门:DStream 输出操作
  • 鹏哥C语言自定义笔记重点(29-)
  • Oracle问题笔记
  • 跟李沐学AI:语义分割
  • Leetcode-day30-动态规划-不同路径
  • STM32G474的HRTIM用作时基定时器
  • R语言统计分析——回归分析的改进措施
  • 【机器学习】YOLO 关闭控制台推理日志
  • 2024前端面试题-js篇
  • ffmpeg6.1集成Plus-OpenGL-Patch滤镜
  • Java二十三种设计模式-解释器模式(23/23)
  • web开发html前端使用javascript脚本库JsBarcode生成条形码(条码)
  • Vue的生命周期了解
  • 数学基础 -- 线性代数之行列式不变性推导
  • linux文本分析工具grep、sed和awk打印输出文本的单双奇偶行(grep也可以打印奇偶行)以及熟悉的ssh命令却有你不知道的一些用法
  • 【MySQL经典案例分析】 Waiting for table metadata lock
  • 【跃迁之路】【699天】程序员高效学习方法论探索系列(实验阶段456-2019.1.19)...
  • Django 博客开发教程 8 - 博客文章详情页
  • hadoop集群管理系统搭建规划说明
  • JavaSE小实践1:Java爬取斗图网站的所有表情包
  • Linux快速配置 VIM 实现语法高亮 补全 缩进等功能
  • MySQL QA
  • PHP 程序员也能做的 Java 开发 30分钟使用 netty 轻松打造一个高性能 websocket 服务...
  • Swoft 源码剖析 - 代码自动更新机制
  • 阿里云容器服务区块链解决方案全新升级 支持Hyperledger Fabric v1.1
  • 笨办法学C 练习34:动态数组
  • 学习笔记TF060:图像语音结合,看图说话
  • Java性能优化之JVM GC(垃圾回收机制)
  • #includecmath
  • #考研#计算机文化知识1(局域网及网络互联)
  • $L^p$ 调和函数恒为零
  • (C++20) consteval立即函数
  • (day18) leetcode 204.计数质数
  • (rabbitmq的高级特性)消息可靠性
  • (solr系列:一)使用tomcat部署solr服务
  • (创新)基于VMD-CNN-BiLSTM的电力负荷预测—代码+数据
  • (附源码)ssm学生管理系统 毕业设计 141543
  • (转)C#开发微信门户及应用(1)--开始使用微信接口
  • (转)Scala的“=”符号简介
  • (转)为C# Windows服务添加安装程序
  • (转)详解PHP处理密码的几种方式
  • *2 echo、printf、mkdir命令的应用
  • .bat批处理(七):PC端从手机内复制文件到本地
  • .cn根服务器被攻击之后
  • .NET 8 跨平台高性能边缘采集网关
  • .NET/C# 的字符串暂存池
  • .NET开源项目介绍及资源推荐:数据持久层
  • .net流程开发平台的一些难点(1)
  • /dev下添加设备节点的方法步骤(通过device_create)
  • /tmp目录下出现system-private文件夹解决方法
  • @DependsOn:解析 Spring 中的依赖关系之艺术
  • @RequestParam,@RequestBody和@PathVariable 区别
  • @value 静态变量_Python彻底搞懂:变量、对象、赋值、引用、拷贝
  • [ 攻防演练演示篇 ] 利用通达OA 文件上传漏洞上传webshell获取主机权限
  • [【JSON2WEB】 13 基于REST2SQL 和 Amis 的 SQL 查询分析器