当前位置: 首页 > news >正文

解决加载模型预测数据时报错的问题

2019独角兽企业重金招聘Python工程师标准>>> hot3.png

# -*- coding: utf-8 -*-
"""
Spyder Editor

This is a temporary script file.
"""

from sklearn.externals import joblib
import xgboost as xgb
loaded_model = joblib.load("xgb_best_model_0731_joblib.model")



import pandas as pd
from sklearn.preprocessing import OneHotEncoder, LabelEncoder


class MultiOneHotEncoder:
    def __init__(self, df, column_name_list):
        self.df = df
        self.column_name_list = column_name_list

    def multi_column_encoder(self):
        Enc_ohe, Enc_label = OneHotEncoder(), LabelEncoder()
        for column_name in self.column_name_list:
            self.df["Dummies"] = Enc_label.fit_transform(self.df[column_name])
            self.df_dummies = pd.DataFrame(Enc_ohe.fit_transform(self.df[["Dummies"]]).todense(),
                                           columns=Enc_label.classes_)
            self.df_dummies.rename(columns=lambda x: column_name + "_" + x, inplace=True)
            self.df = pd.concat([self.df, self.df_dummies], axis=1)
        self.df.drop(["Dummies"], axis=1, inplace=True)
        self.df.drop(self.column_name_list, axis=1, inplace=True)
        return self.df
    
    
test = pd.read_csv("test_info_0903_001", sep=',', header=None, converters={0:str}, na_values=['Null','null','NULL'])
column_list = ["user_id","position_status","gender","highest_degree","age","seniority","marital_status",
                   "latest_position","next-to-last_position","last-but-two_position","leave_hours",
                   "recent-30days_leave_times","recent-60days_leave_times",
                   "recent-30days_email_outbreaks","recent-30days_single_email_outbreaks",
                   "recent-60days_email_outbreaks","recent-60days_single_email_outbreaks"]

test.columns = column_list

drop_column_list = ['latest_position', 'next-to-last_position', 'last-but-two_position']
test.drop(drop_column_list, axis=1, inplace=True)

test = test[(test['age']>=18) & (test['age']<80) & (~test['highest_degree'].isin(['初中及以下']))]


values = {'highest_degree':'missing'}
test = test.fillna(value=values)

test['position_status'] = test['position_status'].apply(lambda x: 1 if x=='离职' else 0)
test = test.reset_index(drop=True)

string_column = test.loc[:, test.dtypes == 'object'].columns
column_name_list = list(string_column)
remove_column_list = ['user_id']
for var in remove_column_list:
    column_name_list.remove(var)
    
    
data = MultiOneHotEncoder(test, column_name_list).multi_column_encoder()

X_pred = data.drop(['user_id', 'position_status'], axis=1)

"""
xgb_x = xgb.DMatrix(X_pred)
y_pred = loaded_model.predict(xgb_x, ntree_limit=loaded_model.best_ntree_limit)
result = pd.DataFrame({'Actual': data.position_status, 'Prob': y_pred})
""" 
############
column_list = loaded_model.feature_names  ##输出模型特征
new_X_pred = X_pred.reindex(columns=column_list, fill_value=0)  ###有点类似多退少补的概念
##One-hot之后不论是缺少字段,还是多出字段,用此种方式均可以正确输出预测值,不会报错

new_xgb_x = xgb.DMatrix(new_X_pred)
new_y_pred = loaded_model.predict(new_xgb_x, ntree_limit=loaded_model.best_ntree_limit)  ##加载best_ntree_limit

result = pd.DataFrame({'Actual': data.position_status, 'Prob': new_y_pred})
result['user_id'] = data['user_id'] ##

result.to_csv("just_test_pred_info_0903_001.csv", index=False)


转载于:https://my.oschina.net/kyo4321/blog/1941587

相关文章:

  • java 颠倒整数
  • 【火炉炼AI】机器学习022-使用均值漂移聚类算法构建模型
  • Python从菜鸟到高手(5):数字
  • python中的None
  • eclipse 执行自带的maven命令无效
  • 【转载三】Grafana系列教程–Grafana的配置及运行
  • mysql 通过备份和binlog恢复数据
  • java类加载时机与过程
  • 设计模式走一遍---观察者模式
  • 我发起了一个 .Net 平台上的 产生式编程 开源项目 GP.Net
  • windows远程连接报:身份错误,函数不支持的解决办法
  • Docker 笔记(2):Dockerfile
  • promise原理就是这么简单
  • EXE文件执行过程中发生了什么?
  • MathExam小学一二年级计算题生成器V1.0
  • 【399天】跃迁之路——程序员高效学习方法论探索系列(实验阶段156-2018.03.11)...
  • Android开发 - 掌握ConstraintLayout(四)创建基本约束
  • CSS中外联样式表代表的含义
  • Django 博客开发教程 16 - 统计文章阅读量
  • Docker: 容器互访的三种方式
  • Java程序员幽默爆笑锦集
  • npx命令介绍
  • overflow: hidden IE7无效
  • SQLServer之创建数据库快照
  • TCP拥塞控制
  • WePY 在小程序性能调优上做出的探究
  • 阿里云前端周刊 - 第 26 期
  • 规范化安全开发 KOA 手脚架
  • 诡异!React stopPropagation失灵
  • 基于MaxCompute打造轻盈的人人车移动端数据平台
  • 数据库写操作弃用“SELECT ... FOR UPDATE”解决方案
  • 怎么把视频里的音乐提取出来
  • 数据库巡检项
  • 完善智慧办公建设,小熊U租获京东数千万元A+轮融资 ...
  • ​LeetCode解法汇总1276. 不浪费原料的汉堡制作方案
  • #git 撤消对文件的更改
  • ( 用例图)定义了系统的功能需求,它是从系统的外部看系统功能,并不描述系统内部对功能的具体实现
  • (2)(2.10) LTM telemetry
  • (C语言)深入理解指针2之野指针与传值与传址与assert断言
  • (pojstep1.1.2)2654(直叙式模拟)
  • (Redis使用系列) Springboot 使用redis实现接口幂等性拦截 十一
  • (二)JAVA使用POI操作excel
  • (二)换源+apt-get基础配置+搜狗拼音
  • (附源码)ssm跨平台教学系统 毕业设计 280843
  • (三) prometheus + grafana + alertmanager 配置Redis监控
  • (三十五)大数据实战——Superset可视化平台搭建
  • (十一)手动添加用户和文件的特殊权限
  • (四)Controller接口控制器详解(三)
  • (一)基于IDEA的JAVA基础10
  • (转)拼包函数及网络封包的异常处理(含代码)
  • (最简单,详细,直接上手)uniapp/vue中英文多语言切换
  • *ST京蓝入股力合节能 着力绿色智慧城市服务
  • .NET core 自定义过滤器 Filter 实现webapi RestFul 统一接口数据返回格式
  • @require_PUTNameError: name ‘require_PUT‘ is not defined 解决方法
  • [ vulhub漏洞复现篇 ] AppWeb认证绕过漏洞(CVE-2018-8715)