
Data Mining: The ID3 Decision Tree Algorithm (a C# Implementation)

A decision tree is a classic classifier whose working principle is a bit like a guessing game. For example, guessing an animal:

Q: Is this animal terrestrial?

A: Yes.

Q: Does this animal have gills?

A: No.

Asked in this order, the two questions are somewhat backwards, because terrestrial animals generally do not have gills (as far as I recall; corrections are welcome). So in this kind of game the order of the questions matters a great deal: each question should be chosen so that it gains as much information as possible.

Class-labeled training tuples from the AllElectronics customer database:

RID  age          income  student  credit_rating  Class: buys_computer
1    youth        high    no       fair           no
2    youth        high    no       excellent      no
3    middle_aged  high    no       fair           yes
4    senior       medium  no       fair           yes
5    senior       low     yes      fair           yes
6    senior       low     yes      excellent      no
7    middle_aged  low     yes      excellent      yes
8    youth        medium  no       fair           no
9    youth        low     yes      fair           yes
10   senior       medium  yes      fair           yes
11   youth        medium  yes      excellent      yes
12   middle_aged  medium  no       excellent      yes
13   middle_aged  high    yes      fair           yes
14   senior       medium  no       excellent      no

We take these class-labeled training tuples from the AllElectronics customer database as our training set and use them to train a decision tree model, in order to mine the decision pattern of whether a customer will buy a computer.

In the ID3 decision tree algorithm, the expected information (entropy) needed to classify a tuple in $D$ is

$$Info(D) = -\sum_{i=1}^{m} p_i \, log_2\, p_i$$

where $p_i$ is the proportion of tuples in $D$ belonging to class $i$. The expected information still required after partitioning $D$ on attribute $A$ (splitting it into subsets $D_1, \dots, D_v$) is:

$$Info_A(D) = \sum_{j=1}^v\frac{|D_j|}{|D|} \times Info(D_j)$$

The information gain is then:

$$Gain(A) = Info(D) - Info_A(D)$$

According to the formula, the class variable to be predicted contains 9 "yes" and 5 "no" tuples, so the expected information is:

$$Info(D)=-\frac{9}{14}log_2\frac{9}{14}-\frac{5}{14}log_2\frac{5}{14}=0.940$$

First, compute the expected information for the attribute age:

$$Info_{age}(D)=\frac{5}{14} \times (-\frac{2}{5}log_2\frac{2}{5} - \frac{3}{5}log_2\frac{3}{5})+\frac{4}{14} \times (-\frac{4}{4}log_2\frac{4}{4} - \frac{0}{4}log_2\frac{0}{4})+\frac{5}{14} \times (-\frac{3}{5}log_2\frac{3}{5} - \frac{2}{5}log_2\frac{2}{5})=0.694$$

Therefore, partitioning on age yields an information gain of:

$$Gain(age) = Info(D)-Info_{age}(D) = 0.940-0.694=0.246$$
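As a sanity check, the numbers above can be reproduced with a few lines of code. The following standalone snippet (a minimal sketch, independent of the classes shown later) computes Info(D), Info_age(D) and Gain(age) directly from the class counts in the table:

using System;

class GainCheck
{
    // Entropy of a yes/no split, with the convention 0 * log2(0) = 0.
    static double Entropy(int yes, int no)
    {
        double total = yes + no;
        double py = yes / total, pn = no / total;
        return -(py == 0 ? 0 : py * Math.Log(py, 2))
               - (pn == 0 ? 0 : pn * Math.Log(pn, 2));
    }

    static void Main()
    {
        double infoD = Entropy(9, 5);                     // ≈ 0.940
        // Per-value counts taken from the table: youth (2 yes, 3 no),
        // middle_aged (4 yes, 0 no), senior (3 yes, 2 no).
        double infoAge = 5.0 / 14 * Entropy(2, 3)
                       + 4.0 / 14 * Entropy(4, 0)
                       + 5.0 / 14 * Entropy(3, 2);        // ≈ 0.694
        Console.WriteLine("Info(D)     = {0:F3}", infoD);
        Console.WriteLine("Info_age(D) = {0:F3}", infoAge);
        Console.WriteLine("Gain(age)   = {0:F3}", infoD - infoAge);  // ≈ 0.246
    }
}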

Similarly, compute the information gain obtained by splitting on income, student, and credit_rating, and choose the attribute that yields the largest information gain; the current node is then split on the values of that attribute. Applying this procedure recursively generates the decision tree. For more details, see:

https://en.wikipedia.org/wiki/Decision_tree

The C# implementation is as follows:

using System;
using System.Collections.Generic;
using System.Linq;

namespace MachineLearning.DecisionTree
{
    public class DecisionTreeID3<T> where T : IEquatable<T>
    {
        T[,] Data;
        string[] Names;
        int Category;
        T[] CategoryLabels;
        DecisionTreeNode<T> Root;

        public DecisionTreeID3(T[,] data, string[] names, T[] categoryLabels)
        {
            Data = data;
            Names = names;
            Category = data.GetLength(1) - 1; // the class variable must be in the last column
            CategoryLabels = categoryLabels;
        }

        public void Learn()
        {
            int nRows = Data.GetLength(0);
            int nCols = Data.GetLength(1);
            int[] rows = new int[nRows];
            int[] cols = new int[nCols];
            for (int i = 0; i < nRows; i++) rows[i] = i;
            for (int i = 0; i < nCols; i++) cols[i] = i;
            Root = new DecisionTreeNode<T>(-1, default(T));
            Learn(rows, cols, Root);
            DisplayNode(Root);
        }

        public void DisplayNode(DecisionTreeNode<T> Node, int depth = 0)
        {
            if (Node.Label != -1)
                Console.WriteLine("{0} {1}: {2}", new string('-', depth * 3), Names[Node.Label], Node.Value);
            foreach (var item in Node.Children)
                DisplayNode(item, depth + 1);
        }

        private void Learn(int[] pnRows, int[] pnCols, DecisionTreeNode<T> Root)
        {
            var categoryValues = GetAttribute(Data, Category, pnRows);
            var categoryCount = categoryValues.Distinct().Count();
            if (categoryCount == 1)
            {
                // All remaining tuples belong to the same class: create a leaf.
                var node = new DecisionTreeNode<T>(Category, categoryValues.First());
                Root.Children.Add(node);
            }
            else
            {
                if (pnRows.Length == 0) return;
                else if (pnCols.Length == 1)
                {
                    // No attributes left: decide by majority vote over the remaining classes.
                    var Vote = categoryValues.GroupBy(i => i).OrderByDescending(i => i.Count()).First();
                    var node = new DecisionTreeNode<T>(Category, Vote.First());
                    Root.Children.Add(node);
                }
                else
                {
                    var maxCol = MaxEntropy(pnRows, pnCols);
                    var attributes = GetAttribute(Data, maxCol, pnRows).Distinct();
                    foreach (var attr in attributes)
                    {
                        int[] rows = pnRows.Where(irow => Data[irow, maxCol].Equals(attr)).ToArray();
                        int[] cols = pnCols.Where(i => i != maxCol).ToArray();
                        var node = new DecisionTreeNode<T>(maxCol, attr);
                        Root.Children.Add(node);
                        Learn(rows, cols, node); // recursively build the decision tree
                    }
                }
            }
        }

        // Info_A(D): expected information required after partitioning on the attribute in column attrCol.
        public double AttributeInfo(int attrCol, int[] pnRows)
        {
            var tuples = AttributeCount(attrCol, pnRows);
            var sum = (double)pnRows.Length;
            double Entropy = 0.0;
            foreach (var tuple in tuples)
            {
                int[] count = new int[CategoryLabels.Length];
                foreach (var irow in pnRows)
                    if (Data[irow, attrCol].Equals(tuple.Item1))
                    {
                        int index = Array.IndexOf(CategoryLabels, Data[irow, Category]);
                        count[index]++; // only a class variable in the last column is supported for now
                    }
                double k = 0.0;
                for (int i = 0; i < count.Length; i++)
                {
                    double frequency = count[i] / (double)tuple.Item2;
                    double t = -frequency * Log2(frequency);
                    k += t;
                }
                double freq = tuple.Item2 / sum;
                Entropy += freq * k;
            }
            return Entropy;
        }

        // Info(D): entropy of the class variable over the given rows.
        public double CategoryInfo(int[] pnRows)
        {
            var tuples = AttributeCount(Category, pnRows);
            var sum = (double)pnRows.Length;
            double Entropy = 0.0;
            foreach (var tuple in tuples)
            {
                double frequency = tuple.Item2 / sum;
                double t = -frequency * Log2(frequency);
                Entropy += t;
            }
            return Entropy;
        }

        private static IEnumerable<T> GetAttribute(T[,] data, int col, int[] pnRows)
        {
            foreach (var irow in pnRows)
                yield return data[irow, col];
        }

        private static double Log2(double x)
        {
            return x == 0.0 ? 0.0 : Math.Log(x, 2.0);
        }

        // Returns the column with the largest information gain Gain(A) = Info(D) - Info_A(D).
        public int MaxEntropy(int[] pnRows, int[] pnCols)
        {
            double cateEntropy = CategoryInfo(pnRows);
            int maxAttr = 0;
            double max = double.MinValue;
            foreach (var icol in pnCols)
                if (icol != Category)
                {
                    double Gain = cateEntropy - AttributeInfo(icol, pnRows);
                    if (max < Gain)
                    {
                        max = Gain;
                        maxAttr = icol;
                    }
                }
            return maxAttr;
        }

        // Counts the occurrences of each distinct value in the given column.
        public IEnumerable<Tuple<T, int>> AttributeCount(int col, int[] pnRows)
        {
            var tuples = from n in GetAttribute(Data, col, pnRows)
                         group n by n into i
                         select Tuple.Create(i.First(), i.Count());
            return tuples;
        }
    }
}

The decision tree node class:

using System.Collections.Generic;

namespace MachineLearning.DecisionTree
{
    public sealed class DecisionTreeNode<T>
    {
        // Label is the column index of the attribute this node tests
        // (-1 for the root, Category for a leaf); Value is the attribute or class value.
        public int Label { get; set; }
        public T Value { get; set; }
        public List<DecisionTreeNode<T>> Children { get; set; }

        public DecisionTreeNode(int label, T value)
        {
            Label = label;
            Value = value;
            Children = new List<DecisionTreeNode<T>>();
        }
    }
}
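The class above only learns the tree and prints it; it does not expose a prediction method. As a rough sketch (not part of the original post), a Classify method could be added to DecisionTreeID3<T> that walks the tree built by Learn: a child whose Label equals Category is a leaf carrying the class value, while any other child carries the attribute column it tests in Label and the matching attribute value in Value. The field names used below (Root, Category) match the implementation above; the method itself is hypothetical:

// Hypothetical addition to DecisionTreeID3<T>: classify one sample, given as an array of
// attribute values ordered like the training columns (the class column may be left empty).
// Learn() must have been called first. Returns default(T) if no branch matches.
public T Classify(T[] sample)
{
    var current = Root;
    while (current != null)
    {
        DecisionTreeNode<T> next = null;
        foreach (var child in current.Children)
        {
            if (child.Label == Category)
                return child.Value;                  // leaf: return the class value
            if (sample[child.Label].Equals(child.Value))
                next = child;                        // branch matching the sample's attribute value
        }
        current = next;
    }
    return default(T);
}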

 

It is called as follows:

using System;
using MachineLearning.DecisionTree;

namespace MachineLearning
{
    class Program
    {
        static void Main(string[] args)
        {
            var da = new string[,]
            {
                {"youth","high","no","fair","no"},
                {"youth","high","no","excellent","no"},
                {"middle_aged","high","no","fair","yes"},
                {"senior","medium","no","fair","yes"},
                {"senior","low","yes","fair","yes"},
                {"senior","low","yes","excellent","no"},
                {"middle_aged","low","yes","excellent","yes"},
                {"youth","medium","no","fair","no"},
                {"youth","low","yes","fair","yes"},
                {"senior","medium","yes","fair","yes"},
                {"youth","medium","yes","excellent","yes"},
                {"middle_aged","medium","no","excellent","yes"},
                {"middle_aged","high","yes","fair","yes"},
                {"senior","medium","no","excellent","no"}
            };
            var names = new string[] { "age", "income", "student", "credit_rating", "Class: buys_computer" };
            var tree = new DecisionTreeID3<string>(da, names, new string[] { "yes", "no" });
            tree.Learn();
            Console.ReadKey();
        }
    }
}

 

Running result:
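(The original post shows a screenshot of the console output here.) With the training data above, the learned tree should be the classic AllElectronics result: age at the root, the middle_aged branch predicting yes, the youth branch splitting further on student, and the senior branch splitting on credit_rating.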


 

Note: I am still learning this material and my knowledge is limited; corrections to any errors or omissions are welcome. Please credit the author when reposting.

Reposted from: https://www.cnblogs.com/HeYanjie/p/5787361.html
