随机森林和GBDT的学习_第1页
随机森林和GBDT的学习_第2页
随机森林和GBDT的学习_第3页
随机森林和GBDT的学习_第4页
随机森林和GBDT的学习_第5页
已阅读5页,还剩24页未读 继续免费阅读

下载本文档

版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领

文档简介

1、随机森林和GBDT的学习2015-03-31      0 个评论    来源:走在前往架构师的路上  收藏    我要投稿前言提到森林,就不得不联想到树,因为正是一棵棵的树构成了庞大的森林,而在本篇文章中的”树“,指的就是Decision Tree-决策树。随机森林就是一棵棵决策树的组合,也就是说随机森林=boosting+决策树,这样就好理解多了吧,再来说说GBDT,GBDT全称是Gradient Boosting Decision Tree

2、,就是梯度提升决策树,与随机森林的思想很像,但是比随机森林稍稍的难一点,当然效果相对于前者而言,也会好许多。由于本人才疏学浅,本文只会详细讲述Random Forest算法的部分,至于GBDT我会给出一小段篇幅做介绍引导,读者能够如果有兴趣的话,可以自行学习。随机森林算法决策树要想理解随机森林算法,就不得不提决策树,什么是决策树,如何构造决策树,简单的回答就是数据的分类以树形结构的方式所展现,每个子分支都代表着不同的分类情况,比如下面的这个图所示:当然决策树的每个节点分支不一定是三元的,可以有2个或者更多。分类的终止条件为,没有可以再拿来分类的属性条件或者说分到的数据的分类已经完全一致的情况。

3、决策树分类的标准和依据是什么呢,下面介绍主要的2种划分标准。1、信息增益。这是ID3算法系列所用的方法,C4.5算法在这上面做了少许的改进,用信息增益率来作为划分的标准,可以稍稍减小数据过于拟合的缺点。2、基尼指数。这是CART分类回归树所用的方法。也是类似于信息增益的一个定义,最终都是根据数据划分后的纯度来做比较,这个纯度,你也可以理解为熵的变化,当然我们所希望的情况就是分类后数据的纯度更纯,也就是说,前后划分分类之后的熵的差越大越好。不过CART算法比较好的一点是树构造好后,还有剪枝的操作,剪枝操作的种类就比较多了,我之前在实现CART算法时用的是代价复杂度的剪枝方法。这2种决策算法在我之

4、前的博文中已经有所提及,不理解的可以点击我的ID3系列算法介绍和我的CART分类回归树算法。Boosting原本不打算将Boosting单独拉出来讲的,后来想想还是有很多内容可谈的。Boosting本身不是一种算法,他更应该说是一种思想,首先对数据构造n个弱分类器,最后通过组合n个弱分类器对于某个数据的判断结果作为最终的分类结果,就变成了一个强分类器,效果自然要好过单一分类器的分类效果。他可以理解为是一种提升算法,举一个比较常见的Boosting思想的算法AdaBoost,他在训练每个弱分类器的时候,提高了对于之前分错数据的权重值,最终能够组成一批相互互补的分类器集合。详细可以查看我的AdaB

5、oost算法学习。OK,2个重要的概念都已经介绍完毕,终于可以介绍主角Random Forest的出现了,正如前言中所说Random Forest=Decision Trees + Boosting,这里的每个弱分类器就是一个决策树了,不过这里的决策树都是二叉树,就是只有2个孩子分支,自然我立刻想到的做法就是用CART算法来构建,因为人家算法就是二元分支的。随机算法,随机算法,当然重在随机2个字上面,下面是2个方面体现了随机性。对于数据样本的采集量,比如我数据由100条,我可以每次随机取出其中的20条,作为我构造决策树的源数据,采取又放回的方式,并不是第一次抽到的数据,第二次不能重复,第二随机

6、性体现在对于数据属性的随机采集,比如一行数据总共有10个特征属性,我每次随机采用其中的4个。正是由于对于数据的行压缩和列压缩,使得数据的随机性得以保证,就很难出现之前的数据过拟合的问题了,也就不需要在决策树最后进行剪枝操作了,这个是与一般的CART算法所不同的,尤其需要注意。下面是随机森林算法的构造过程:1、通过给定的原始数据,选出其中部分数据进行决策树的构造,数据选取是”有放回“的过程,我在这里用的是CART分类回归树。2、随机森林构造完成之后,给定一组测试数据,使得每个分类器对其结果分类进行评估,最后取评估结果的众数最为最终结果。算法非常的好理解,在Boosting算法和决策树之上做了一个

7、集成,下面给出算法的实现,很多资料上只有大篇幅的理论,我还是希望能带给大家一点实在的东西。随机算法的实现输入数据(之前决策树算法时用过的)input.txt: ?123456789101112131415Rid Age Income Student CreditRating BuysComputer1 Youth High No Fair No2 Youth High No Excellent No3 MiddleAged High No Fair Yes4 Senior Medium No Fair Yes5 Senior Low Yes Fair Yes6 Senior Low

8、Yes Excellent No7 MiddleAged Low Yes Excellent Yes8 Youth Medium No Fair No9 Youth Low Yes Fair Yes10 Senior Medium Yes Fair Yes11 Youth Medium Yes Excellent Yes12 MiddleAged Medium No Excellent Yes13 MiddleAged High Yes Fair Yes14 Senior Medium No Excellent No 树节点类TreeNode.java: ?12345678

9、910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485package DataMining_RandomForest; import java.util.ArrayList; /* * 回归分类树节点 *  * author lyq *  */public class Tree

10、Node     / 节点属性名字    private String attrName;    / 节点索引标号    private int nodeIndex;    /包含的叶子节点数    private int leafNum;    / 节点误差率    priva

11、te double alpha;    / 父亲分类属性值    private String parentAttrValue;    / 孩子节点    private TreeNode childAttrNode;    / 数据记录索引    private ArrayList<String> dataIndex;  &

12、#160;  public String getAttrName()         return attrName;         public void setAttrName(String attrName)         this.attrName = attrName;    

13、;     public int getNodeIndex()         return nodeIndex;         public void setNodeIndex(int nodeIndex)         this.nodeIndex = nodeIndex; 

14、60;       public double getAlpha()         return alpha;         public void setAlpha(double alpha)         this.alpha = alpha;  &

15、#160;      public String getParentAttrValue()         return parentAttrValue;         public void setParentAttrValue(String parentAttrValue)        &#

16、160;this.parentAttrValue = parentAttrValue;         public TreeNode getChildAttrNode()         return childAttrNode;         public void setChildAttrNode(TreeNode childAt

17、trNode)         this.childAttrNode = childAttrNode;         public ArrayList<String> getDataIndex()         return dataIndex;       &#

18、160; public void setDataIndex(ArrayList<String> dataIndex)         this.dataIndex = dataIndex;         public int getLeafNum()         return leafNum;  &#

19、160;      public void setLeafNum(int leafNum)         this.leafNum = leafNum;                   决策树类DecisionTree.java:  ?1234

20、567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371

21、38139140141142143144145146147148149150151152153154155156157158159160161162163164165package DataMining_RandomForest; import java.util.ArrayList;import java.util.HashMap;import java.util.Map; /* * 决策树 *  * author lyq *  */public class DecisionTree    &

22、#160;/ 树的根节点    TreeNode rootNode;    / 数据的属性列名称    String featureNames;    / 这棵树所包含的数据    ArrayList<String> datas;    / 决策树构造的的工具类    CARTTool tool; 

23、;    public DecisionTree(ArrayList<String> datas)         this.datas = datas;        this.featureNames = datas.get(0);         tool = new CARTTool(da

24、tas);        / 通过CART工具类进行决策树的构建,并返回树的根节点        rootNode = tool.startBuildingTree();         /*     * 根据给定的数据特征描述进行类别的判断     

25、;*      * param features     * return     */    public String decideClassType(String features)         String classType = ""     &

26、#160;  / 查询属性组        String queryFeatures;        / 在本决策树中对应的查询的属性值描述        ArrayList<String> featureStrs;         fe

27、atureStrs = new ArrayList<>();        queryFeatures = features.split(",");         String array;        for (String name : featureNames)    

28、0;        for (String featureValue : queryFeatures)                 array = featureValue.split("=");           

29、0;    / 将对应的属性值加入到列表中                if (array0.equals(name)                     featureStrs.add(arr

30、ay);                                             / 开始从根据节点往下递归搜索  

31、;      classType = recusiveSearchClassType(rootNode, featureStrs);         return classType;         /*     * 递归搜索树,查询属性的分类类别     

32、*      * param node     *            当前搜索到的节点     * param remainFeatures     *           

33、剩余未判断的属性     * return     */    private String recusiveSearchClassType(TreeNode node,            ArrayList<String> remainFeatures)      

34、60;  String classType = null;         / 如果节点包含了数据的id索引,说明已经分类到底了        if (node.getDataIndex() != null && node.getDataIndex().size() > 0)          &

35、#160;  classType = judgeClassType(node.getDataIndex();             return classType;                 / 取出剩余属性中的一个匹配属性作为当前的判断属性名称  

36、      String currentFeature = null;        for (String featureValue : remainFeatures)             if (node.getAttrName().equals(featureValue0)     

37、60;           currentFeature = featureValue;                break;                

38、60;            for (TreeNode childNode : node.getChildAttrNode()             / 寻找子节点中属于此属性值的分支            if (childNode.

39、getParentAttrValue().equals(currentFeature1)                 remainFeatures.remove(currentFeature);                classType = recusiveSearc

40、hClassType(childNode, remainFeatures);                 / 如果找到了分类结果,则直接挑出循环                break;      

41、60;     else                /进行第二种情况的判断加上!符号的情况                String value = childNode.getParentAttrValue(); 

42、                                if(value.charAt(0) = '!')            

43、60;       /去掉第一个!字符                    value = value.substring(1, value.length();              

44、                           if(!value.equals(currentFeature1)                 

45、60;      remainFeatures.remove(currentFeature);                        classType = recusiveSearchClassType(childNode, remainFeatures);  

46、0;                      break;                          

47、0;                                      return classType;         

48、;/*     * 根据得到的数据行分类进行类别的决策     *      * param dataIndex     *            根据分类的数据索引号     * return    

49、; */    public String judgeClassType(ArrayList<String> dataIndex)         / 结果类型值        String resultClassType = ""        String classTyp

50、e = ""        int count = 0;        int temp = 0;        Map<String, Integer> type2Num = new HashMap<String, Integer>();      &#

51、160;  for (String index : dataIndex)             temp = Integer.parseInt(index);            / 取最后一列的决策类别数据          

52、0; classType = datas.get(temp)featureNames.length - 1;             if (type2Num.containsKey(classType)                 / 如果类别已经存在,则使其计数加1  &

53、#160;             count = type2Num.get(classType);                count+;             else 

54、0;               count = 1;                         type2Num.put(classType, count);   

55、;              / 选出其中类别支持计数最多的一个类别值        count = -1;        for (Map.Entry entry : type2Num.entrySet()        

56、60;    if (int) entry.getValue() > count)                 count = (int) entry.getValue();                resultClassT

57、ype = (String) entry.getKey();                             return resultClassType;    随机森林算法工具类RandomForestTool.java:  ?12345

58、678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713

59、8139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223package DataMining_RandomForest; import

60、 java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.io.IOException;import java.util.ArrayList;import java.util.HashMap;import java.util.Map;import java.util.Random; /* * 随机森林算法工具类 *  * author lyq *  */public class RandomForestTool  

61、60;  / 测试数据文件地址    private String filePath;    / 决策树的样本占总数的占比率    private double sampleNumRatio;    / 样本数据的采集特征数量占总特征的比例    private double featureNumRatio;    / 决策树的采样样本数

62、    private int sampleNum;    / 样本数据的采集采样特征数    private int featureNum;    / 随机森林中的决策树的数目,等于总的数据数/用于构造每棵树的数据的数量    private int treeNum;    / 随机数产生器    private R

63、andom random;    / 样本数据列属性名称行    private String featureNames;    / 原始的总的数据    private ArrayList<String> totalDatas;    / 决策树森林    private ArrayList<DecisionTree> decisi

64、onForest;     public RandomForestTool(String filePath, double sampleNumRatio,            double featureNumRatio)         this.filePath = filePath;     

65、;   this.sampleNumRatio = sampleNumRatio;        this.featureNumRatio = featureNumRatio;         readDataFile();         /*     *

66、从文件中读取数据     */    private void readDataFile()         File file = new File(filePath);        ArrayList<String> dataArray = new ArrayList<String>();  

67、60;      try             BufferedReader in = new BufferedReader(new FileReader(file);            String str;       

68、     String tempArray;            while (str = in.readLine() != null)                 tempArray = str.split(" ");  

69、              dataArray.add(tempArray);                        in.close();      

70、   catch (IOException e)             e.getStackTrace();                 totalDatas = dataArray;        fe

71、atureNames = totalDatas.get(0);        sampleNum = (int) (totalDatas.size() - 1) * sampleNumRatio);        /算属性数量的时候需要去掉id属性和决策属性,用条件属性计算        featureNum = (int) (featureNames.le

72、ngth -2) * featureNumRatio);        / 算数量的时候需要去掉首行属性名称行        treeNum = (totalDatas.size() - 1) / sampleNum;         /*     * 产生决策树   

73、  */    private DecisionTree produceDecisionTree()         int temp = 0;        DecisionTree tree;        String tempData;    

74、0;   /采样数据的随机行号组        ArrayList<Integer> sampleRandomNum;        /采样属性特征的随机列号组        ArrayList<Integer> featureRandomNum;     

75、;   ArrayList<String> datas;                 sampleRandomNum = new ArrayList<>();        featureRandomNum = new ArrayList<>();  

76、0;     datas = new ArrayList<>();                 for(int i=0; i<sampleNum;)            temp = random.nextInt(totalDatas.

77、size();                         /如果是行首属性名称行,则跳过            if(temp = 0)       &#

78、160;        continue;                                     if(!sampleRandomN

79、um.contains(temp)                sampleRandomNum.add(temp);                i+;          &#

80、160;                          for(int i=0; i<featureNum;)            temp = random.nextInt(featureNames.len

81、gth);                         /如果是第一列的数据id号或者是决策属性列,则跳过            if(temp = 0 | temp = featureNames.length-1) 

82、;               continue;                                 &#

83、160;   if(!featureRandomNum.contains(temp)                featureRandomNum.add(temp);                i+;   

84、60;                         String singleRecord;        String headCulumn = null;        / 获取随机

85、数据行        for(int dataIndex: sampleRandomNum)            singleRecord = totalDatas.get(dataIndex);                 

86、;        /每行的列数=所选的特征数+id号            tempData = new StringfeatureNum+2;            headCulumn = new StringfeatureNum+2;   &#

87、160;                     for(int i=0,k=1; i<featureRandomNum.size(); i+,k+)                temp = featureRandomN

88、um.get(i);                                 headCulumnk = featureNamestemp;          &

89、#160;     tempDatak = singleRecordtemp;                                     /加上id列的信息 

90、;           headCulumn0 = featureNames0;            /加上决策分类列的信息            headCulumnfeatureNum+1 = featureNamesfeatureNames.

91、length-1;            tempDatafeatureNum+1 = singleRecordfeatureNames.length-1;                         /加入此行数据 

92、;           datas.add(tempData);                         /加入行首列出现名称        datas

93、.add(0, headCulumn);        /对筛选出的数据重新做id分配        temp = 0;        for(String array: datas)            /从第2行开始赋值 &

94、#160;          if(temp > 0)                array0 = temp + ""                

95、                     temp+;                         tree = new Decisio

96、nTree(datas);                 return tree;         /*     * 构造随机森林     */    public void constructRa

97、ndomTree()         DecisionTree tree;        random = new Random();        decisionForest = new ArrayList<>();         System.out

98、.println("下面是随机森林中的决策树:");        / 构造决策树加入森林中        for (int i = 0; i < treeNum; i+)             System.out.println("n决策树" + (i+1);

99、0;           tree = produceDecisionTree();            decisionForest.add(tree);                 /*&#

100、160;    * 根据给定的属性条件进行类别的决策     *      * param features     *            给定的已知的属性描述     * return     

101、*/    public String judgeClassType(String features)         / 结果类型值        String resultClassType = ""        String classType = "" 

102、60;      int count = 0;        Map<String, Integer> type2Num = new HashMap<String, Integer>();         for (DecisionTree tree : decisionForest)             classType = tree.decideClassType(features);            if (type2Num.containsKey(classType)        

温馨提示

  • 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
  • 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
  • 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
  • 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
  • 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
  • 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
  • 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。

评论

0/150

提交评论