当前位置: 首页 > news >正文

机器学习第8天:SVM分类

文章目录

机器学习专栏

介绍

特征缩放

示例代码

硬间隔与软间隔分类

主要代码

代码解释

非线性SVM分类

结语


机器学习专栏

机器学习_Nowl的博客-CSDN博客

介绍

作用:判别种类

原理:找出一个决策边界,判断数据所处区域来识别种类

简单介绍一下SVM分类的思想,我们看下面这张图,两种分类都很不错,但是我们可以注意到第二种的决策边界与实例更远(它们之间的距离比较宽),而SVM分类就是一种寻找距每种实例最远的决策边界的算法


特征缩放

SVM算法对特征缩放很敏感(不处理算法效果会受很大影响)

特征缩放是什么意思呢,例如有身高数据和体重数据,若身高是m为单位,体重是g为单位,那么体重就比身高的数值大很多,有些机器学习算法就可能更关注某一个值,这时我们用特征缩放就可以把数据统一到相同的尺度上

示例代码

from sklearn.preprocessing import StandardScaler
import numpy as np# 创建一个示例数据集
data = np.array([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0],[7.0, 8.0, 9.0]])# 创建StandardScaler对象
scaler = StandardScaler()# 对数据进行标准化
scaled_data = scaler.fit_transform(data)print("原始数据:\n", data)
print("\n标准化后的数据:\n", scaled_data)# 结果是
# [[-1.22474487 -1.22474487 -1.22474487]
#  [ 0.          0.          0.        ]
#  [ 1.22474487  1.22474487  1.22474487]]

 StandardScaler是一种数据标准化的方法,它对数据进行线性变换,使得数据的均值变为0,标准差变为1。 

解释上面的数据

在每列上进行标准化,即对每个特征进行独立的标准化。每个数值是通过减去该列的均值,然后除以该列的标准差得到的。

  • 第一列:(1−4)/9=−1.22474487(1−4)/9​=−1.22474487,(4−4)/9=0(4−4)/9​=0,(7−4)/9=1.22474487(7−4)/9​=1.22474487。
  • 第二列:(2−5)/9=−1.22474487(2−5)/9​=−1.22474487,(5−5)/9=0(5−5)/9​=0,(8−5)/9=1.22474487(8−5)/9​=1.22474487。
  • 第三列:(3−6)/9=−1.22474487(3−6)/9​=−1.22474487,(6−6)/9=0(6−6)/9​=0,(9−6)/9=1.22474487(9−6)/9​=1.22474487。

这样,标准化后的数据集就符合标准正态分布,每个特征的均值为0,标准差为1。


硬间隔与软间隔分类

硬间隔分类就是完全将不同的个体区分在不同的区域(不能有一点误差)

软间隔分类就是允许一些偏差(图中绿和红色的点都有一些出现在了对方的分区里)

硬间隔分类往往会出现一些问题,例如有时候模型不可能完全分成两类,同时,硬间隔分类往往可能导致过拟合,而软间隔分类的泛化能力就比硬间隔分类好很多


主要代码

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVCmodel = Pipeline([("scaler", StandardScaler()),("linear_svc", LinearSVC(C=1, loss="hinge"))
])model.fit(x, y)

代码解释

在这里,Pipeline的构造函数接受一个由元组组成的列表。每个元组的第一个元素是该步骤的名称(字符串),第二个元素是该步骤的实例。在这个例子中,第一个步骤是数据标准化,使用StandardScaler,命名为"scaler";第二个步骤是线性支持向量机,使用LinearSVC,命名为"linear_svc"。这两个步骤会按照列表中的顺序依次执行。

参数C是正则程度,hinge是SVM分类算法的损失函数,用来训练模型


非线性SVM分类

上述方法都是在数据集可线性分离时用到的,当数据集呈非线性怎么办,我们在回归任务中讲过一个思想,用PolynomialFeatures来产生多项式,再对每个项进行线性拟合,最后结合在一起得出决策边界

具体代码

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.preprocessing import PolynomialFeatures
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score# 生成非线性数据集
X, y = datasets.make_circles(n_samples=100, factor=0.5, noise=0.1, random_state=42)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 使用多项式特征和线性SVM
degree = 3  # 多项式的次数
svm_classifier = make_pipeline(StandardScaler(), PolynomialFeatures(degree), SVC(kernel='linear', C=1))
svm_classifier.fit(X_train, y_train)# 预测并计算准确率
y_pred = svm_classifier.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)# 绘制决策边界
def plot_decision_boundary(X, y, model, ax):h = .02x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))Z = model.predict(np.c_[xx.ravel(), yy.ravel()])Z = Z.reshape(xx.shape)ax.contourf(xx, yy, Z, alpha=0.8)ax.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', marker='o', s=80, linewidth=0.5)ax.set_xlim(xx.min(), xx.max())ax.set_ylim(yy.min(), yy.max())# 绘制结果
fig, ax = plt.subplots(figsize=(8, 6))
plot_decision_boundary(X_train, y_train, svm_classifier, ax)
ax.set_title('Polynomial SVM Decision Boundary')
plt.show()

运行结果


结语

SVM分类是一种经典的分类算法,也叫大间隔分类算法。它可以用来线性分类,也可以非线性分类(可以与PolynomialFeatures结合,当然还有其他方法,我们之后再说)

相关文章:

  • 【论文阅读】A Survey on Video Diffusion Models
  • Linux--网络概念
  • ZJU Beamer学习手册(二)
  • 全志XR806基于http的无线ota功能实验
  • 创新研报|新业务发展是CEO推动企业增长的必要选择 – Mckinsey研究
  • 音视频项目—基于FFmpeg和SDL的音视频播放器解析(十)
  • android开发连接网络
  • Leetcode—141.环形链表【简单】
  • csapp深入理解计算机系统 bomb lab(1)phase_1
  • Redis数据的持久化
  • SpringCloud Alibaba详解
  • NoSQL 与传统数据库的集成
  • WPF中如何在MVVM模式下关闭窗口
  • 大数据Doris(二十六):数据导入(Routine Load)介绍
  • 【大数据分布并行处理】单元测试(五)
  • 实现windows 窗体的自己画,网上摘抄的,学习了
  • 【Redis学习笔记】2018-06-28 redis命令源码学习1
  • 78. Subsets
  • css选择器
  • js正则,这点儿就够用了
  • mysql_config not found
  • PHP 7 修改了什么呢 -- 2
  • Python十分钟制作属于你自己的个性logo
  • React 快速上手 - 07 前端路由 react-router
  • Sublime text 3 3103 注册码
  • VirtualBox 安装过程中出现 Running VMs found 错误的解决过程
  • 创建一种深思熟虑的文化
  • 从PHP迁移至Golang - 基础篇
  • 对JS继承的一点思考
  • 翻译:Hystrix - How To Use
  • 目录与文件属性:编写ls
  • 深度解析利用ES6进行Promise封装总结
  • 使用 Docker 部署 Spring Boot项目
  • 一道面试题引发的“血案”
  • 用简单代码看卷积组块发展
  • ​Kaggle X光肺炎检测比赛第二名方案解析 | CVPR 2020 Workshop
  • ### Error querying database. Cause: com.mysql.jdbc.exceptions.jdbc4.CommunicationsException
  • $.ajax()方法详解
  • (20050108)又读《平凡的世界》
  • (C#)Windows Shell 外壳编程系列4 - 上下文菜单(iContextMenu)(二)嵌入菜单和执行命令...
  • (C语言)二分查找 超详细
  • (附源码)springboot家庭财务分析系统 毕业设计641323
  • (七)理解angular中的module和injector,即依赖注入
  • (学习日记)2024.01.09
  • (原創) 如何刪除Windows Live Writer留在本機的文章? (Web) (Windows Live Writer)
  • .[hudsonL@cock.li].mkp勒索病毒数据怎么处理|数据解密恢复
  • .NET CORE 3.1 集成JWT鉴权和授权2
  • .NET Core工程编译事件$(TargetDir)变量为空引发的思考
  • .NET 设计模式—适配器模式(Adapter Pattern)
  • .NET高级面试指南专题十一【 设计模式介绍,为什么要用设计模式】
  • .Net高阶异常处理第二篇~~ dump进阶之MiniDumpWriter
  • .NET项目中存在多个web.config文件时的加载顺序
  • .pyc文件还原.py文件_Python什么情况下会生成pyc文件?
  • @Documented注解的作用
  • @取消转义