当前位置: 首页 > news >正文

基于Python的机器学习系列(18):梯度提升分类(Gradient Boosting Classification)

简介

        梯度提升(Gradient Boosting)是一种集成学习方法,通过逐步添加新的预测器来改进模型。在回归问题中,我们使用梯度来最小化残差。在分类问题中,我们可以利用梯度提升来进行二分类或多分类任务。与回归不同,分类问题需要使用如softmax这样的概率模型来处理类别标签。

梯度提升分类的工作原理

        梯度提升分类的基本步骤与回归类似,但在分类任务中,我们使用概率模型来处理预测结果:

  1. 初始化模型:选择一个初始预测器,这里使用DummyClassifier来作为第一个模型。
  2. 计算梯度:计算每个样本的梯度,梯度是当前预测值与真实标签之间的差异。
  3. 训练新预测器:用计算得到的梯度作为目标,训练一个新的分类器。
  4. 更新模型:将新预测器的结果加到现有模型中。
  5. 重复步骤:重复上述步骤,逐步添加更多的预测器以改进模型的分类能力。

二分类示例

        在二分类任务中,梯度提升分类器的工作流程如下:

  1. 预测概率:通过softmax将预测值转换为概率。
  2. 更新模型:利用当前的梯度来训练下一个分类器。

代码示例

        下面的代码示例展示了如何实现一个梯度提升分类器,包括支持二分类和多分类任务:

from sklearn.tree import DecisionTreeRegressor
from sklearn.dummy import DummyRegressor, DummyClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_digits, load_breast_cancer
import numpy as npclass GradientBoosting:def __init__(self, S=5, learning_rate=1, max_depth=1, min_samples_split=2, regression=True, tol=1e-4):self.S = Sself.learning_rate = learning_rateself.max_depth = max_depthself.min_samples_split = min_samples_splitself.regression = regression# 初始化回归树tree_params = {'max_depth': self.max_depth, 'min_samples_split': self.min_samples_split}self.models = [DecisionTreeRegressor(**tree_params) for _ in range(S)]if regression:# 回归模型的初始模型self.models.insert(0, DummyRegressor(strategy='mean'))else:# 分类模型的初始模型self.models.insert(0, DummyClassifier(strategy='most_frequent'))def grad(self, y, h):return y - hdef fit(self, X, y):# 训练第一个模型self.models[0].fit(X, y)for i in range(self.S):# 预测yhat = self.predict(X, self.models[:i+1], with_argmax=False)# 计算梯度gradient = self.grad(y, yhat)# 训练下一个模型self.models[i+1].fit(X, gradient)def predict(self, X, models=None, with_argmax=True):if models is None:models = self.modelsh0 = models[0].predict(X)boosting = sum(self.learning_rate * model.predict(X) for model in models[1:])yhat = h0 + boostingif not self.regression:# 使用softmax转换为概率yhat = np.exp(yhat) / np.sum(np.exp(yhat), axis=1, keepdims=True)if with_argmax:yhat = np.argmax(yhat, axis=1)return yhat# 示例:使用乳腺癌数据集进行二分类
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 创建和训练梯度提升分类器
gb = GradientBoosting(S=50, learning_rate=0.1, regression=False)
gb.fit(X_train, y_train)# 预测并计算准确率
y_pred = gb.predict(X_test)
from sklearn.metrics import accuracy_score
print(f'Accuracy: {accuracy_score(y_test, y_pred)}')

总结

        梯度提升分类器通过逐步减少分类错误来提高模型的性能。这种方法在处理分类任务时,能够有效提高预测准确率。与回归任务类似,分类任务中的梯度提升也能通过逐步添加预测器来优化模型。通过调整学习率和模型参数,我们可以进一步提高模型的表现。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • RabbitMQ练习(Remote procedure call (RPC))
  • 筛法求欧拉函数
  • 问:说一下Java中数组的实例化方式有哪些?
  • Java-数据结构-包装类和认识泛型 !!!∑(゚Д゚ノ)ノ
  • Java Stream流式编程
  • 小程序自定义组件配合插槽和组件传值
  • 重生之我们在ES顶端相遇第11 章 - 深入自定义语言分词器
  • centos 系统yum 安装 mariadb
  • 书生大模型实战营基础(3)——LangGPT结构化提示词编写实践
  • C#基础(2)枚举
  • Linux系统安装MySQL8.0
  • ES6更新的内容中什么是proxy
  • 力扣8.29
  • React多功能管理平台项目开发全教程
  • C++ | Leetcode C++题解之第387题字符串中的第一个唯一字符
  • SegmentFault for Android 3.0 发布
  • [ 一起学React系列 -- 8 ] React中的文件上传
  • 【391天】每日项目总结系列128(2018.03.03)
  • 【从零开始安装kubernetes-1.7.3】2.flannel、docker以及Harbor的配置以及作用
  • 【跃迁之路】【477天】刻意练习系列236(2018.05.28)
  • classpath对获取配置文件的影响
  • Idea+maven+scala构建包并在spark on yarn 运行
  • JavaScript-Array类型
  • jquery cookie
  • PHP 7 修改了什么呢 -- 2
  • QQ浏览器x5内核的兼容性问题
  • Vue 2.3、2.4 知识点小结
  • vue--为什么data属性必须是一个函数
  • 关于for循环的简单归纳
  • 入门级的git使用指北
  • 设计模式走一遍---观察者模式
  • 使用docker-compose进行多节点部署
  • 深度学习之轻量级神经网络在TWS蓝牙音频处理器上的部署
  • MPAndroidChart 教程:Y轴 YAxis
  • 蚂蚁金服CTO程立:真正的技术革命才刚刚开始
  • ​​​【收录 Hello 算法】9.4 小结
  • ​DB-Engines 12月数据库排名: PostgreSQL有望获得「2020年度数据库」荣誉?
  • ​TypeScript都不会用,也敢说会前端?
  • !!java web学习笔记(一到五)
  • # 利刃出鞘_Tomcat 核心原理解析(二)
  • #Datawhale AI夏令营第4期#多模态大模型复盘
  • (6)添加vue-cookie
  • (Redis使用系列) Springboot 实现Redis 同数据源动态切换db 八
  • (非本人原创)我们工作到底是为了什么?​——HP大中华区总裁孙振耀退休感言(r4笔记第60天)...
  • (附源码)计算机毕业设计ssm基于B_S的汽车售后服务管理系统
  • (附源码)计算机毕业设计大学生兼职系统
  • (离散数学)逻辑连接词
  • (每日持续更新)信息系统项目管理(第四版)(高级项目管理)考试重点整理 第13章 项目资源管理(七)
  • (七)Knockout 创建自定义绑定
  • (五)activiti-modeler 编辑器初步优化
  • (学习日记)2024.01.09
  • (自用)learnOpenGL学习总结-高级OpenGL-抗锯齿
  • .java 9 找不到符号_java找不到符号
  • .MyFile@waifu.club.wis.mkp勒索病毒数据怎么处理|数据解密恢复
  • .NET 中创建支持集合初始化器的类型