当前位置: 首页 > news >正文

Python数据分析实战之:特征重要性分析

文章目录

  • 提醒
  • 代码
    • 数据处理
    • 分离 data 和 label
    • 训练
    • 训练结果 & 混淆矩阵
    • 各种 feature 的重要性

提醒

  • pandas 读取 .xlsx 文件需要安装 openpyxl(xlrd 2.0 起已不再支持 .xlsx;只有旧版 pandas 搭配 xlrd >= 1.1.0 / < 2.0 时才用 xlrd 读取)

代码

import sklearn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# 混淆矩阵
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_score
from sklearn.metrics import roc_curve, f1_score, precision_score, recall_score
from sklearn.svm import SVC

## Show every row, column and full cell contents when displaying DataFrames.
# None means "no limit" for each option. The old -1 value for
# display.max_colwidth was deprecated in pandas 1.0 and is rejected by
# current pandas versions.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)

数据处理

path = "./2019-2020年.xlsx"

# df = pd.read_csv(path, error_bad_lines=False)
df = pd.read_excel(path)

# OTT and TOAST subtypes have too many missing values; leave them out for now.
df = df.drop(['OTT', 'TOAST subtypes'], axis=1)


# The remaining columns have few missing values, so drop every row with any NaN.
df = df.dropna(axis=0, how='any')

# Rebuild a 0..n-1 index. The previous row labels are kept in a new 'index'
# column, which is dropped again before training.
df = df.reset_index()

# Some columns hold strings and must be int/float before training.
# Inspecting the object-dtype columns showed only this one bad value; set it to 0.
# NOTE(review): 74 is the row label after dropna + reset_index — confirm it is
# still the intended row if the source sheet changes.
# Use .loc: chained assignment (df[col][row] = ...) writes through a temporary
# and is a silent no-op under pandas Copy-on-Write.
df.loc[74, 'Coronary heart disease'] = 0
# The column presumably stays object dtype after the fix (it held a string),
# so convert it explicitly as the comment above intends.
df['Coronary heart disease'] = pd.to_numeric(df['Coronary heart disease'])

# The data is clean now
df.head(10)
indexSexMedication before thrombolytic therapyAgeAge.1Periventricular White MatterDeep White MatterThe degree of WMHSmokingDrinkingAtrialFibrillationHypertensionDiabetesHyperlipidemiaCoronary heart diseaseHeart failureStrokeTIAWBCNLNLRHBPLTPCVPTINRAPTTTTFibrinogenEmergency blood sugarFasting blood glucoseCreatinineHDLLDLHDL/LDL比值CholesterolTriglycerideHBLACHCYDNTsBPdBPBaseline NIHSS scoreHemorrhagic Transformation(HT)Early neurological deterioration (END)Prognosis&0 (mRS0-2:0;3-6:1)90dmRSPrognosis&1(mRS0-1:0;2-6:1)
050258.00.02121.01.00.01000000.08.45.442.352.3100001412663810.70.9234.514.93.525.685.1166.20.923.150.2900004.271.994.611.025.0177.090.0800151
160253.00.02121.01.00.01000000.011.66.54.041.608911145259459.90.8630.615.33.346.386.1374.01.464.540.3215866.112.395.511.337.0166.098.0400010
2101277.01.03330.00.00.00000010.07.836.630.5911.237288103225309.20.802814.94.821.767.11107.91.191.031.1553402.280.415.013.030.0150.090.0900000
3130265.00.01011.00.00.00000000.013.8410.32.334.4206011563274711.81.0238.312.42.698.224.9885.00.722.270.3171813.511.345.712.043.0150.0102.0900021
4261266.00.02330.00.00.01100010.04.7311.62.564.5312501422814012.51.0826.118.22.496.364.0295.00.822.940.2789124.141.675.922.017.0147.075.0300021
5270274.01.03330.00.00.01001000.06.185.070.549.3888891101523212.31.0626.716.13.3310.407.0780.31.112.710.4095944.200.976.719.022.0125.080.0500000
6301270.01.01111.01.00.00000000.08.56.191.613.8447201362584011.30.9731.414.14.295.984.4054.31.133.590.3147635.281.005.415.015.0147.089.0200000
7340158.00.01111.01.00.01100000.09.486.851.733.9595381522034710.10.8731.212.74.4310.468.3772.81.423.550.4000004.961.518.613.443.0160.0105.0400010
8361265.00.01110.00.01.00000010.015.611.523.063.7647061422564311.20.9631.411.64.637.324.7550.61.082.020.5346533.521.026.521.015.0110.076.01100131
9380282.01.03331.00.00.01000000.010.318.121.246.5483871252473711.30.9731.412.54.796.254.8872.21.652.560.6445314.210.775.618.725.0220.0104.0400010

分离 data 和 label

# Label columns (inspect their distributions with the commented hist() calls)
label_1 = df.columns[-3]
label_2 = df.columns[-1]

# df[label_1].hist()

# df[label_2].hist()

# label1_data / label2_data are the targets for the first / second label;
# `data` below is the feature matrix used for training.

label1_data = df[label_1]
label2_data = df[label_2]


# Features = every column except:
#   - the two label columns,
#   - 'index', added by reset_index() and meaningless as a feature,
#   - '90dmRS', which appears to be the raw mRS score that both prognosis
#     labels are binarized from (their names encode mRS ranges); keeping it
#     would presumably leak the target — confirm with the data owner.
# A single non-inplace drop on df avoids calling inplace drop on a column
# subset of df, which triggers SettingWithCopyWarning.
data = df.drop(columns=[label_1, label_2, 'index', '90dmRS'])

data.head(10)
SexMedication before thrombolytic therapyAgeAge.1Periventricular White MatterDeep White MatterThe degree of WMHSmokingDrinkingAtrialFibrillationHypertensionDiabetesHyperlipidemiaCoronary heart diseaseHeart failureStrokeTIAWBCNLNLRHBPLTPCVPTINRAPTTTTFibrinogenEmergency blood sugarFasting blood glucoseCreatinineHDLLDLHDL/LDL比值CholesterolTriglycerideHBLACHCYDNTsBPdBPBaseline NIHSS scoreHemorrhagic Transformation(HT)Early neurological deterioration (END)
00258.00.02121.01.00.01000000.08.45.442.352.3100001412663810.70.9234.514.93.525.685.1166.20.923.150.2900004.271.994.611.025.0177.090.0800
10253.00.02121.01.00.01000000.011.66.54.041.608911145259459.90.8630.615.33.346.386.1374.01.464.540.3215866.112.395.511.337.0166.098.0400
21277.01.03330.00.00.00000010.07.836.630.5911.237288103225309.20.802814.94.821.767.11107.91.191.031.1553402.280.415.013.030.0150.090.0900
30265.00.01011.00.00.00000000.013.8410.32.334.4206011563274711.81.0238.312.42.698.224.9885.00.722.270.3171813.511.345.712.043.0150.0102.0900
41266.00.02330.00.00.01100010.04.7311.62.564.5312501422814012.51.0826.118.22.496.364.0295.00.822.940.2789124.141.675.922.017.0147.075.0300
50274.01.03330.00.00.01001000.06.185.070.549.3888891101523212.31.0626.716.13.3310.407.0780.31.112.710.4095944.200.976.719.022.0125.080.0500
61270.01.01111.01.00.00000000.08.56.191.613.8447201362584011.30.9731.414.14.295.984.4054.31.133.590.3147635.281.005.415.015.0147.089.0200
70158.00.01111.01.00.01100000.09.486.851.733.9595381522034710.10.8731.212.74.4310.468.3772.81.423.550.4000004.961.518.613.443.0160.0105.0400
81265.00.01110.00.01.00000010.015.611.523.063.7647061422564311.20.9631.411.64.637.324.7550.61.082.020.5346533.521.026.521.015.0110.076.01100
90282.01.03331.00.00.01000000.010.318.121.246.5483871252473711.30.9731.412.54.796.254.8872.21.652.560.6445314.210.775.618.725.0220.0104.0400

训练

def train(model, dataset, labelset, random_state=None):
    """Fit a classifier on a stratified 80/20 split and report how well it does.

    Prints the hold-out accuracy and the cross-validation accuracies, then shows
    the confusion matrix of the hold-out predictions as a heatmap.

    Args:
        model: unfitted scikit-learn classifier. It is left fitted on the 80%
            training split, so later code can read attributes such as ``coef_``.
        dataset: feature DataFrame.
        labelset: label Series aligned with ``dataset``; also used to stratify
            the split.
        random_state: optional seed for the train/test split, for reproducible
            results. The default ``None`` keeps the previous random split.

    Returns:
        tuple: (hold-out accuracy, array of cross-validation accuracies).
    """
    x_train, x_test, y_train, y_test = train_test_split(
        dataset.values,
        labelset.values,
        test_size=0.2,  # train_size defaults to the remaining 0.8
        shuffle=True,
        stratify=labelset,
        random_state=random_state,
    )
    model.fit(x_train, y_train)
    score = model.score(x_test, y_test)
    # cross_val_score fits clones of `model`, so the fit above is not disturbed.
    accs = cross_val_score(model, dataset.values, labelset.values, verbose=0)
    print(f'validation acc is: {score}')
    print(f'cross validation accs are: {accs}')

    y_pre = model.predict(x_test)
    metri = confusion_matrix(y_test, y_pre)
    # fmt='d' prints integer counts; seaborn's default '.2g' shows counts
    # >= 100 in scientific notation (e.g. "1e+02").
    sns.heatmap(metri, annot=True, fmt='d')
    plt.show()
    return score, accs
    

训练结果 & 混淆矩阵

# Linear-kernel SVM with class-balanced weights, trained on the first label.
svc1 = SVC(kernel='linear', class_weight='balanced')
train(svc1, data, labelset=label1_data)
validation acc is: 0.8863636363636364
cross validation accs are: [0.90909091 0.69767442 0.76744186 0.88372093 0.88372093]

(图:SVM 在 label1 上的混淆矩阵热力图)

# Same linear SVM configuration, trained on the second label.
svc2 = SVC(kernel='linear', class_weight='balanced')
train(svc2, data, labelset=label2_data)
validation acc is: 0.8863636363636364
cross validation accs are: [0.81818182 0.6744186  0.72093023 0.76744186 0.76744186]

(图:SVM 在 label2 上的混淆矩阵热力图)

# Class-balanced logistic regression on the first label; a high max_iter
# gives the solver room to converge on these unscaled features.
lr1 = LogisticRegression(max_iter=10000, class_weight='balanced')
train(lr1, data, labelset=label1_data)
validation acc is: 0.9090909090909091
cross validation accs are: [0.81818182 0.6744186  0.76744186 0.88372093 0.90697674]

(图:逻辑回归在 label1 上的混淆矩阵热力图)

# Same logistic-regression configuration, trained on the second label.
lr2 = LogisticRegression(max_iter=10000, class_weight='balanced')
train(lr2, data, labelset=label2_data)
validation acc is: 0.8863636363636364
cross validation accs are: [0.79545455 0.69767442 0.74418605 0.6744186  0.81395349]

(图:逻辑回归在 label2 上的混淆矩阵热力图)

各种 feature 的重要性

def make_coef_dictNdf(data_columns, coef):
    """Pair each feature name with its fitted linear coefficient.

    Args:
        data_columns: feature names, in the column order the model was fit on.
        coef: a fitted model's ``coef_``; must hold exactly one coefficient per
            feature once singleton axes are removed (e.g. shape
            ``(1, n_features)`` for a binary classifier).

    Returns:
        tuple: ``(dict, DataFrame)`` — a name -> coefficient mapping, and the
        same values as a DataFrame indexed by feature name with a single
        ``'influence'`` column.

    Raises:
        ValueError: if ``coef`` does not reduce to one value per column.
    """
    # squeeze() drops the leading class axis; atleast_1d keeps a one-feature
    # model from collapsing to a 0-d array, which zip() cannot iterate.
    values = np.atleast_1d(np.asarray(coef).squeeze())
    names = list(data_columns)
    # Without this check zip() would silently truncate, or pair names with whole
    # rows for a multi-class coef_, and produce a wrong table with no error.
    if values.ndim != 1 or len(values) != len(names):
        raise ValueError(
            f'expected {len(names)} coefficients, got coef of shape {np.shape(coef)}'
        )
    name_influence_dic = dict(zip(names, values))
    name_influence_df = pd.DataFrame.from_dict(
        name_influence_dic, orient='index', columns=['influence']
    )
    return name_influence_dic, name_influence_df
## Export the per-feature coefficient tables to Excel, one sheet per label.
def write(filename, name_df_dic):
    """Write each DataFrame in ``name_df_dic`` to its own sheet of ``filename``.

    Args:
        filename: path of the Excel workbook to create.
        name_df_dic: mapping of sheet name -> DataFrame to write on that sheet.
    """
    # The context manager saves and closes the workbook on exit.
    # ExcelWriter.save() was deprecated in pandas 1.5 and removed in 2.0.
    with pd.ExcelWriter(filename) as writer:
        for sheet_name, frame in name_df_dic.items():
            frame.to_excel(writer, sheet_name=sheet_name)

# writer = pd.ExcelWriter("逻辑回归.xlsx")
# label1_df.to_excel(writer, sheet_name="label1")
# label2_df.to_excel(writer, sheet_name="label2")
# writer.save()
# writer.close()

def plot(figsize, name_influence_df_lst, img_label_lst, title):
    """Draw a grouped bar chart comparing per-feature coefficients.

    Args:
        figsize: matplotlib figure size ``(width, height)`` in inches.
        name_influence_df_lst: DataFrames (as built by ``make_coef_dictNdf``),
            each indexed by feature name with an ``'influence'`` column.
        img_label_lst: legend label for each DataFrame, in the same order.
        title: chart title.
    """
    plt.figure(figsize=figsize)
    n_series = len(name_influence_df_lst)
    # Feature order and tick labels follow the first table; the other tables
    # are aligned to it by feature name (assumes unique feature names).
    names = name_influence_df_lst[0].index if n_series else pd.Index([])
    positions = np.arange(len(names))
    # Give each series its own slot inside a 0.8-wide group. Drawing every
    # series at the same x positions would let later bars hide earlier ones.
    width = 0.8 / max(n_series, 1)
    for i, frame in enumerate(name_influence_df_lst):
        offset = (i - (n_series - 1) / 2) * width
        plt.bar(
            positions + offset,
            frame['influence'].reindex(names),
            width=width,
            label=img_label_lst[i],
        )
    plt.title(title)
    plt.legend()
    plt.xticks(positions, names, rotation=90)
# plt.figure(figsize=(20,10))
# plt.bar(x=label1_df.index, height=label1_df['influence'],label='label1')
# plt.bar(x=label1_df.index, height=label2_df['influence'], label='label2')
# plt.legend()
# plt.xticks(rotation=90) # 旋转90度
# Map every fitted model's coefficients onto the training feature names.
lr1_dic, lr1_df = make_coef_dictNdf(data_columns=data.columns, coef=lr1.coef_)
lr2_dic, lr2_df = make_coef_dictNdf(data_columns=data.columns, coef=lr2.coef_)
svc1_dic, svc1_df = make_coef_dictNdf(data_columns=data.columns, coef=svc1.coef_)
svc2_dic, svc2_df = make_coef_dictNdf(data_columns=data.columns, coef=svc2.coef_)

# One workbook per model family, one sheet per label.
write(filename="逻辑回归.xlsx", name_df_dic={'label1': lr1_df, 'label2': lr2_df})
write(filename="SVM.xlsx", name_df_dic={'label1': svc1_df, 'label2': svc2_df})

# Compare logistic-regression coefficients for the two labels.
plot(
    figsize=(20, 10),
    name_influence_df_lst=[lr1_df, lr2_df],
    img_label_lst=['label1', 'label2'],
    title='lr',
)

(图:逻辑回归在两种标签下各特征系数的柱状图)

# Compare linear-SVM coefficients for the two labels.
plot(
    figsize=(20, 10),
    name_influence_df_lst=[svc1_df, svc2_df],
    img_label_lst=['label1', 'label2'],
    title='svm',
)

(图:SVM 在两种标签下各特征系数的柱状图)

相关文章:

  • 40.讲初识动态规划:如何巧妙解决“双十一”购物时的凑单问题
  • 信息学奥赛中的STL(标准模板库)--2022.09.30
  • 量子力学摘记3
  • C++11详解
  • vue3+TS实现简易组件库
  • 【深度学习100例】—— Python+OpenCV+MediaPipe实时人流检测 | 第3例
  • Mysql和ES数据同步方案汇总
  • Java / Tensorflow - API 调用 pb 模型使用 GPU 推理
  • 【CSS】精灵图 背景图 阴影 过渡
  • 【设计模式】【第五章】【开具增值税发票】【建造者模式 + 原型模式】
  • 【关于Linux中权限管理】
  • Opencv项目实战:11 使用Opencv高亮显示文本检测
  • 零基础转行,入职军工类测试方向,月薪10K | 既然选择了,就要全力以赴
  • python字典与集合还有数据类型转换
  • CH559L单片机ADC多通道采样数据串口打印案例
  • @angular/forms 源码解析之双向绑定
  • 【140天】尚学堂高淇Java300集视频精华笔记(86-87)
  • 8年软件测试工程师感悟——写给还在迷茫中的朋友
  • Apache Zeppelin在Apache Trafodion上的可视化
  • GDB 调试 Mysql 实战(三)优先队列排序算法中的行记录长度统计是怎么来的(上)...
  • JAVA 学习IO流
  • Javascript 原型链
  • leetcode46 Permutation 排列组合
  • mysql外键的使用
  • Node.js 新计划:使用 V8 snapshot 将启动速度提升 8 倍
  • vue脚手架vue-cli
  • 从 Android Sample ApiDemos 中学习 android.animation API 的用法
  • 开放才能进步!Angular和Wijmo一起走过的日子
  • 如何优雅地使用 Sublime Text
  • 算法---两个栈实现一个队列
  • 小而合理的前端理论:rscss和rsjs
  • 一起来学SpringBoot | 第十篇:使用Spring Cache集成Redis
  • 国内开源镜像站点
  • 进程与线程(三)——进程/线程间通信
  • ​3ds Max插件CG MAGIC图形板块为您提升线条效率!
  • # 学号 2017-2018-20172309 《程序设计与数据结构》实验三报告
  • $$$$GB2312-80区位编码表$$$$
  • (12)Hive调优——count distinct去重优化
  • (欧拉)openEuler系统添加网卡文件配置流程、(欧拉)openEuler系统手动配置ipv6地址流程、(欧拉)openEuler系统网络管理说明
  • (全部习题答案)研究生英语读写教程基础级教师用书PDF|| 研究生英语读写教程提高级教师用书PDF
  • (三)mysql_MYSQL(三)
  • (三)模仿学习-Action数据的模仿
  • (转)nsfocus-绿盟科技笔试题目
  • (转)scrum常见工具列表
  • (转)详解PHP处理密码的几种方式
  • (转)总结使用Unity 3D优化游戏运行性能的经验
  • ./和../以及/和~之间的区别
  • .class文件转换.java_从一个class文件深入理解Java字节码结构
  • .NET Framework .NET Core与 .NET 的区别
  • .NET国产化改造探索(一)、VMware安装银河麒麟
  • .NET学习教程二——.net基础定义+VS常用设置
  • .pop ----remove 删除
  • @31省区市高考时间表来了,祝考试成功
  • @for /l %i in (1,1,10) do md %i 批处理自动建立目录
  • @SuppressLint(NewApi)和@TargetApi()的区别