当前位置: 首页 > news >正文

xgboost回归损失函数自定义【一】

2019独角兽企业重金招聘Python工程师标准>>> hot3.png

写在前面:

每当提到损失函数,很多人都有个误解,以为用在GridSearchCV(网格搜索交叉验证“Cross Validation”)里边的scoring就是损失函数,其实并不是。我们使用构造函数构造XGBRegressor的时候,里边的objective参数才是真正的损失函数(loss function)。xgb使用sklearn api的时候需要用到的损失函数,其返回值是一阶导和二阶导,而GridSearchCV使用的scoring函数,返回的是一个float类型的数值评分(或叫准确率、或叫偏差值)。

You should be careful with the notation.

There are 2 levels of optimization here:

  1. The loss function optimized when the XGBRegressor is fitted to the data.
  2. The scoring function that is optimized during the grid search.

I prefer calling the second scoring function instead of loss function, since loss function usually refers to a term that is subject to optimization during the model fitting process itself.

Scikit-Learn: Custom Loss Function for GridSearchCV

因此,下文对于objective,统一叫“目标函数”;而对scoring,统一叫“评价函数”。

 

 

========== 原文分割线 ===================

许多特定的任务需要定制目标函数,来达到更优的效果。这里以xgboost的回归预测为例,介绍一下objective函数的定制过程。一个简单的例子如下:

def customObj1(real, predict):
    grad = predict - real
    hess = np.power(np.abs(grad), 0.5)
    return grad, hess

网上有许多教程定义的objective函数中的第一个参数是preds,第二个是dtrain,而本文由于使用xgboost的sklearn API,因此定制的objective函数需要与sklearn的格式相符。调用目标函数的过程如下:

model = xgb.XGBRegressor(objective=customObj1,
                         booster="gblinear")

下面是不同迭代次数的动画演示:

我们发现,不同的目标函数对模型的收敛速度影响较大,但最终收敛目标大致相同,如下图:

完整代码如下:

# coding=utf-8
import pandas as pd
import numpy as np
import xgboost as xgb
import matplotlib.pyplot as plt

plt.rcParams.update({'figure.autolayout': True})

df = pd.DataFrame({'x': [-2.1, -0.9,  0,  1,  2, 2.5,  3,  4],
                   'y': [ -10,    0, -5, 10, 20,  10, 30, 40]})
X_train = df.drop('y', axis=1)
Y_train = df['y']
X_pred = [-4, -3, -2, -1, 0, 0.4, 0.6, 1, 1.4, 1.6, 2, 3, 4, 5, 6, 7, 8]


def process_list(list_in):
    result = map(lambda x: "%8.2f" % round(float(x), 2), list_in)
    return list(result)


def customObj3(real, predict):
    grad = predict - real
    hess = np.power(np.abs(grad), 0.1)
    # print 'predict', process_list(predict.tolist()), type(predict)
    # print ' real  ', process_list(real.tolist()), type(real)
    # print ' grad  ', process_list(grad.tolist()), type(grad)
    # print ' hess  ', process_list(hess.tolist()), type(hess), '\n'
    return grad, hess


def customObj1(real, predict):
    grad = predict - real
    hess = np.power(np.abs(grad), 0.5)

    return grad, hess


for n_estimators in range(5, 600, 5):
    booster_str = "gblinear"
    model = xgb.XGBRegressor(objective=customObj1,
                             booster=booster_str,
                             n_estimators=n_estimators)
    model2 = xgb.XGBRegressor(objective="reg:linear",
                              booster=booster_str,
                              n_estimators=n_estimators)
    model3 = xgb.XGBRegressor(objective=customObj3,
                              booster=booster_str,
                              n_estimators=n_estimators)
    model.fit(X=X_train, y=Y_train)
    model2.fit(X=X_train, y=Y_train)
    model3.fit(X=X_train, y=Y_train)

    y_pred = model.predict(data=pd.DataFrame({'x': X_pred}))
    y_pred2 = model2.predict(data=pd.DataFrame({'x': X_pred}))
    y_pred3 = model3.predict(data=pd.DataFrame({'x': X_pred}))

    plt.figure(figsize=(6, 5))
    plt.axes().set(title='n_estimators='+str(n_estimators))

    plt.plot(df['x'], df['y'], marker='o', linestyle=":", label="Real Y")
    plt.plot(X_pred, y_pred, label="predict - real; |grad|**0.5")
    plt.plot(X_pred, y_pred3, label="predict - real; |grad|**0.1")
    plt.plot(X_pred, y_pred2, label="reg:linear")

    plt.xlim(-4.5, 8.5)
    plt.ylim(-25, 55)

    plt.legend()
    # plt.show()
    plt.savefig("output/n_estimators_"+str(n_estimators)+".jpg")
    plt.close()
    print(n_estimators)

 

转载于:https://my.oschina.net/u/2996334/blog/3006786

相关文章:

  • Java null最佳实践
  • 36氪首发|「优仕美地医疗」获亿元级B轮融资,要打造日间手术机构的连锁服务网络...
  • 阿里云联合8家芯片商推“全平台通信模组”,加速物联网生态建设
  • MySQL设置主从复制
  • 赶紧收藏!新鲜出炉的重庆轨道交通图 首末班时间和线路都在里面
  • 厉害!重庆参加马拉松赛人数7年翻10倍,今年区县马拉松赛事将大增
  • python教程(一)·命令行基本操作
  • TCP三次握手四次挥手
  • C++类中的特殊成员函数
  • ES搜索引擎集群模式搭建【Kibana可视化】
  • spring cloud gateway 源码解析(4)跨域问题处理
  • 有赞电商云应用框架设计
  • JS专题之继承
  • 阿里云服务器怎么升级配置?升级有哪些限制?
  • UniDAC使用教程(五):数据加密
  • 10个确保微服务与容器安全的最佳实践
  • CSS中外联样式表代表的含义
  • dva中组件的懒加载
  • ES6 学习笔记(一)let,const和解构赋值
  • ES学习笔记(12)--Symbol
  • Hexo+码云+git快速搭建免费的静态Blog
  • idea + plantuml 画流程图
  • Markdown 语法简单说明
  • python_bomb----数据类型总结
  • Redis 懒删除(lazy free)简史
  • Solarized Scheme
  • Travix是如何部署应用程序到Kubernetes上的
  • vuex 学习笔记 01
  • 从0到1:PostCSS 插件开发最佳实践
  • 聊聊directory traversal attack
  • 软件开发学习的5大技巧,你知道吗?
  • 时间复杂度与空间复杂度分析
  • 用jQuery怎么做到前后端分离
  • 责任链模式的两种实现
  • 走向全栈之MongoDB的使用
  • 看到一个关于网页设计的文章分享过来!大家看看!
  • 你对linux中grep命令知道多少?
  • UI设计初学者应该如何入门?
  • ​批处理文件中的errorlevel用法
  • #1015 : KMP算法
  • (173)FPGA约束:单周期时序分析或默认时序分析
  • (C语言版)链表(三)——实现双向链表创建、删除、插入、释放内存等简单操作...
  • (附源码)springboot优课在线教学系统 毕业设计 081251
  • (附源码)ssm智慧社区管理系统 毕业设计 101635
  • (一)spring cloud微服务分布式云架构 - Spring Cloud简介
  • (转)Linux NTP配置详解 (Network Time Protocol)
  • (转)Scala的“=”符号简介
  • (转)程序员疫苗:代码注入
  • .mkp勒索病毒解密方法|勒索病毒解决|勒索病毒恢复|数据库修复
  • .net Signalr 使用笔记
  • .Net Winform开发笔记(一)
  • .NET 程序如何获取图片的宽高(框架自带多种方法的不同性能)
  • .NET 应用启用与禁用自动生成绑定重定向 (bindingRedirect),解决不同版本 dll 的依赖问题
  • .NET:自动将请求参数绑定到ASPX、ASHX和MVC(菜鸟必看)
  • .Net各种迷惑命名解释