当前位置: 首页 > news >正文

吴恩达机器学习 第二课 week4 决策树

目录

01 学习目标

02 实现工具

03 问题描述

04 构建决策树

05 总结


01 学习目标

     (1)理解“熵”、“交叉熵(信息增益)”的概念

     (2)掌握决策树的构建步骤与要点

02 实现工具

    (1)代码运行环境

              Python语言,Jupyter notebook平台

    (2)所需模块

              numpy,matplotlib,public_tests

03 问题描述

       假设你是犇犇蘑菇集团的总裁,你现在要亲自抽检10只蘑菇,看下里面有几只是毒蘑菇,程序猿出身的你打算采用决策树进行检测,Let's begin!

04 构建决策树

     (1)导入所需模块

import numpy as np
import matplotlib.pyplot as plt
from public_tests import *%matplotlib inline

      (public_tests是自定义模块,内部包括compute_entropy_test、split_dataset_test、compute_information_gain_test、get_best_split_test共4个函数,是一个 Jupyter Notebook 的魔法命令(Magic Command),用于在 Notebook 单元格中直接显示 Matplotlib 生成的图形) 

     (2)数据集

       抽检的蘑菇采用3个特征,分别是Brown Cap、Tapering Stalk Shape和Solitary,检测结果为是/否有毒。特征及结果采用独热编码(one-hot),如下表:

        其中,Brown Cap列的1表示“棕色帽”、0表示“红色帽”;Tapering Stalk Shape列的1表示“锥形茎”、0表示“扩口茎”;Solitary列的1表示“单生”、0表示“非单生”;Edible列的1表示“无毒”、0表示“有毒”。

       数据定义如下:

X_train = np.array([[1,1,1],[1,0,1],[1,0,0],[1,0,0],[1,1,1],[0,1,1],[0,0,0],[1,0,1],[0,1,0],[1,0,0]])
y_train = np.array([1,1,0,0,1,0,0,1,1,0])

      (3)决策树步骤

         决策树构建分4步:①选择根节点特征;②计算所有分裂情况的信息增益并选择具有最高信息增益的特征;③根据选择的特征拆分数据集,创建树的左右分支;④继续重复分割过程,直到满足停止条件。

        信息增益(又称“交叉熵”)表示由于分裂导致的熵的变化,熵用来衡量信息混乱程度,熵大则乱,信息增益的一般计算公式如下:

Info \;\; gain=H(P_{root})-[W_{left}H(P_{left})+W_{right}H(P_{right})]

H(P_i)=-Plog_2(P_i)-(1-P_i)log_2(1-P_i)

 其中,Info gain为信息增益,H(P)为概率P的熵,P=k/nk为目标出现次数,n为总数。

      (4)代码实现决策树

         ①定义熵函数

def compute_entropy(y):entropy = 0.k = 0n = len(y)if n == 0:entropy = 0else:for i in range(n): k += y[i]       p = k / nif p == 0 or p == 1:entropy = 0else:entropy = -p * np.log2(p) - (1 - p) * np.log2(1 - p)return entropy

       ②定义分裂函数

def split_dataset(X, node_indices, feature):left_indices = []right_indices = []for id in node_indices:if X[id, feature] == 1:left_indices.append(id)else:right_indices.append(id)return left_indices, right_indices

       ③定义信息增益函数

def compute_information_gain(X, y, node_indices, feature):left_indices, right_indices = split_dataset(X, node_indices, feature)X_node, y_node = X[node_indices], y[node_indices]X_left, y_left = X[left_indices], y[left_indices]X_right, y_right = X[right_indices], y[right_indices]num_left = len(X_left)num_right = len(X_right)num_sum = num_left + num_rightw_left = num_left / num_sumw_right = num_right / num_sumentropy_w = w_left * compute_entropy(y_left) + w_right * compute_entropy(y_right)                                    information_gain = compute_entropy(y_node) - entropy_w  return information_gain

       ④定义最优分裂函数

def get_best_split(X, y, node_indices):       num_features = X.shape[1]best_feature = -1max_info_gain = 0for feature in range(num_features):info_gain = compute_information_gain(X, y, node_indices, feature)if info_gain > max_info_gain:max_info_gain = info_gainbest_feature = featurereturn best_feature

       ⑤定义决策树函数

tree = []def build_tree_recursive(X, y, node_indices, branch_name, max_depth, current_depth): # 停止分裂的条件if current_depth == max_depth:formatting = " "*current_depth + "-"*current_depthprint(formatting, "%s leaf node with indices" % branch_name, node_indices)returnbest_feature = get_best_split(X, y, node_indices) tree.append((current_depth, branch_name, best_feature, node_indices))formatting = "-"*current_depthprint("%s Depth %d, %s: Split on feature: %d" % (formatting, current_depth, branch_name, best_feature))# 在最优特征处分裂left_indices, right_indices = split_dataset(X, node_indices, best_feature)# 继续分裂build_tree_recursive(X, y, left_indices, "Left", max_depth, current_depth+1)build_tree_recursive(X, y, right_indices, "Right", max_depth, current_depth+1)

      (分支名称branch_name:'Root', 'Left', 'Right';formatting = "-"*current_depth用于生成与“current_depth”数量相等的“-”,用于缩进)

       ⑥开始构建决策树

build_tree_recursive(X_train, y_train, root_indices, "Root", max_depth=2, current_depth=0)

       运行以上代码,结果如下:

    

       决策树的分类结果如下图所示(自己用PPT绘的):

05 总结

      (1)决策树的构建包括:计算熵、信息增益、寻找最优分裂方式3个核心要点。

      (2) 决策树要解决的是多特征分类识别问题。

相关文章:

  • 如何配置node.js环境
  • 软件设计师笔记-系统开发和运行知识(一)
  • 总结 CSS 选择器的常见用法
  • 硬盘数据恢复软件,推荐5种适合你的方法来恢复硬盘数据
  • 医学记录 --- 腋下异味
  • 手持弹幕LED滚动字幕屏夜店表白手灯接机微信抖音小程序开源版开发
  • 20-OWASP top10--XXS跨站脚本攻击
  • websocket 安全通信
  • 计算机组成入门知识
  • Memcached缓存系统详解
  • android 在线程中更新界面
  • Typora + Hexo 图片路径问题(Typedown)
  • Flink Sql Redis Connector
  • 数据结构之B数
  • 在JPA项目启动时新增MySQL字段
  • [分享]iOS开发 - 实现UITableView Plain SectionView和table不停留一起滑动
  • 【跃迁之路】【699天】程序员高效学习方法论探索系列(实验阶段456-2019.1.19)...
  • 2019年如何成为全栈工程师?
  • CAP 一致性协议及应用解析
  • Druid 在有赞的实践
  • flask接收请求并推入栈
  • Fundebug计费标准解释:事件数是如何定义的?
  • leetcode98. Validate Binary Search Tree
  • PHP CLI应用的调试原理
  • Spring Cloud(3) - 服务治理: Spring Cloud Eureka
  • Sublime Text 2/3 绑定Eclipse快捷键
  • WordPress 获取当前文章下的所有附件/获取指定ID文章的附件(图片、文件、视频)...
  • 海量大数据大屏分析展示一步到位:DataWorks数据服务+MaxCompute Lightning对接DataV最佳实践...
  • 如何抓住下一波零售风口?看RPA玩转零售自动化
  • 通过来模仿稀土掘金个人页面的布局来学习使用CoordinatorLayout
  • AI又要和人类“对打”,Deepmind宣布《星战Ⅱ》即将开始 ...
  • ​Benvista PhotoZoom Pro 9.0.4新功能介绍
  • ​什么是bug?bug的源头在哪里?
  • ​学习笔记——动态路由——IS-IS中间系统到中间系统(报文/TLV)​
  • ​云纳万物 · 数皆有言|2021 七牛云战略发布会启幕,邀您赴约
  • #ifdef 的技巧用法
  • #免费 苹果M系芯片Macbook电脑MacOS使用Bash脚本写入(读写)NTFS硬盘教程
  • (3)医疗图像处理:MRI磁共振成像-快速采集--(杨正汉)
  • (pytorch进阶之路)CLIP模型 实现图像多模态检索任务
  • (void) (_x == _y)的作用
  • (二)linux使用docker容器运行mysql
  • (个人笔记质量不佳)SQL 左连接、右连接、内连接的区别
  • (十八)三元表达式和列表解析
  • (算法)求1到1亿间的质数或素数
  • (一)Neo4j下载安装以及初次使用
  • (转)关于多人操作数据的处理策略
  • (转载)微软数据挖掘算法:Microsoft 时序算法(5)
  • ***详解账号泄露:全球约1亿用户已泄露
  • .NET实现之(自动更新)
  • .net项目IIS、VS 附加进程调试
  • .考试倒计时43天!来提分啦!
  • //解决validator验证插件多个name相同只验证第一的问题
  • /3GB和/USERVA开关
  • @font-face 用字体画图标
  • @Not - Empty-Null-Blank