A Study of the J48 (C4.5) Decision Tree Source Code
TODO: analyze the classification efficiency of J48.
I. Preparation.

II. Code flow walkthrough:
Model training starts in J48.java.
J48.buildClassifier(ins): taking the C4.5 decision tree model as an example:
modSelection = new C45ModelSelection(m_minNumObj, instances);
m_root = new C45PruneableClassifierTree(modSelection, !m_unpruned, m_CF,
m_subtreeRaising, !m_noCleanup);
m_root.buildClassifier(instances);
Expanding C45PruneableClassifierTree.buildClassifier(ins):
data.deleteWithMissingClass();
buildTree(data, m_subtreeRaising || !m_cleanup);
collapse();
if (m_pruneTheTree) {
    prune();
}
if (m_cleanup) {
    cleanup(new Instances(data, 0));
}
Expanding buildTree() of the base class ClassifierTree:
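Call modSelection.selectModel(ins) to choose the split model (m_localModel) for the current node.
localInstances = m_localModel.split(ins); // split the data with the selected model
m_sons[i] = getNewTree(localInstances[i]); // build the subtree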
Expanding C45ModelSelection.selectModel(ins):
// No split if there are too few instances or all of them belong to one class.
if (Utils.sm(checkDistribution.total(), 2 * m_minNoObj)
    || Utils.eq(checkDistribution.total(),
                checkDistribution.perClass(checkDistribution.maxClass())))
    return noSplitModel;
multiValue = !(attribute.isNumeric() || attribute.numValues() < (0.3 * m_allData.numInstances()));
currentModel = new C45Split[data.numAttributes()];
sumOfWeights = data.sumOfWeights();
// For each attribute.
for (i = 0; i < data.numAttributes(); i++) {
    // Apart from class attribute.
    if (i != (data).classIndex()) {
        // Get models for current attribute.
        currentModel[i] = new C45Split(i, m_minNoObj, sumOfWeights);
        currentModel[i].buildClassifier(data);
        // ... code omitted: accumulate averageInfoGain and count validModels
    } else
        currentModel[i] = null;
}
averageInfoGain = averageInfoGain / (double) validModels;
// Find "best" attribute to split on.
minResult = 0;
for (i = 0; i < data.numAttributes(); i++) {
    // Skip the class attribute, whose model is null.
    if (currentModel[i] == null)
        continue;
    // Use 1E-3 here to get a closer approximation to the original implementation.
    if ((currentModel[i].infoGain() >= (averageInfoGain - 1E-3))
        && Utils.gr(currentModel[i].gainRatio(), minResult)) {
        bestModel = currentModel[i];
        minResult = currentModel[i].gainRatio();
    }
}
// ... code omitted: load the complete data into bestModel.
// Set the split point analogue to C45 if attribute numeric.
if (m_allData != null)
    bestModel.setSplitPoint(m_allData);
return bestModel;
Expanding C45Split.buildClassifier():
Depending on whether the attribute is enumerated (nominal) or numeric, it calls handleEnumeratedAttribute(ins) or handleNumericAttribute(ins). Both methods are described later in this article.
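The selection rule in selectModel() can be tried out in isolation. The following is a minimal sketch, not the Weka code: BestSplitSketch and selectBest are hypothetical names, the candidates are plain arrays instead of C45Split objects, NaN marks an attribute without a usable split, and a plain > replaces Utils.gr.

```java
/** Simplified stand-in for the attribute selection step (hypothetical class, not part of Weka). */
class BestSplitSketch {

    /**
     * Returns the index of the candidate with the highest gain ratio among those whose
     * information gain is at least the average gain minus 1e-3, or -1 if none qualifies.
     * A NaN entry in infoGain marks a candidate that produced no valid split.
     */
    static int selectBest(double[] infoGain, double[] gainRatio) {
        double sum = 0;
        int valid = 0;
        for (double g : infoGain) {
            if (!Double.isNaN(g)) {
                sum += g;
                valid++;
            }
        }
        if (valid == 0) {
            return -1; // no useful split at all
        }
        double average = sum / valid;

        int best = -1;
        double bestRatio = 0;
        for (int i = 0; i < infoGain.length; i++) {
            if (Double.isNaN(infoGain[i])) {
                continue;
            }
            if (infoGain[i] >= average - 1e-3 && gainRatio[i] > bestRatio) {
                best = i;
                bestRatio = gainRatio[i];
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] gain = {0.10, 0.40, 0.50, Double.NaN};
        double[] ratio = {0.90, 0.30, 0.25, Double.NaN};
        System.out.println("chosen attribute index: " + selectBest(gain, ratio));
    }
}
```

In the example, attribute 0 has the largest gain ratio but is filtered out because its information gain is below average. This two-stage filter is what keeps attributes with a tiny split entropy from winning on gain ratio alone.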
collapse(): compares the error before and after collapsing the subtree into a single node; if the error does not increase, the subtree is collapsed. Otherwise the child nodes are processed recursively.
prune(): prunes the C4.5 decision tree, recursing into the child nodes first. Key code for the current node:
// Compute error if this Tree would be leaf
errorsLeaf = getEstimatedErrorsForDistribution(localModel().distribution());
// Compute error for the whole subtree
errorsTree = getEstimatedErrors();
// Decide if leaf is best choice.
if (Utils.smOrEq(errorsLeaf, errorsTree + 0.1)
    && Utils.smOrEq(errorsLeaf, errorsLargestBranch + 0.1)) {
    // ... turn the tree into a leaf (and return)
}
// Decide if largest branch is better choice than whole subtree.
if (Utils.smOrEq(errorsLargestBranch, errorsTree + 0.1)) {
    largestBranch = son(indexOfLargestBranch);
    m_sons = largestBranch.m_sons;
    m_localModel = largestBranch.localModel();
    m_isLeaf = largestBranch.m_isLeaf;
    newDistribution(m_train);
    prune();
}
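The three-way decision made in prune() can be isolated from the tree bookkeeping. Below is a minimal sketch under stated assumptions: PruneDecisionSketch, Action and decide are hypothetical names, the three error estimates are passed in as plain numbers, and plain <= replaces Utils.smOrEq.

```java
/** Sketch of the pruning decision for a single node (hypothetical class, not part of Weka). */
class PruneDecisionSketch {

    enum Action { MAKE_LEAF, RAISE_LARGEST_BRANCH, KEEP_SUBTREE }

    /**
     * Chooses what to do with a node given the estimated errors of the node as a leaf,
     * of the whole subtree, and of the subtree replaced by its largest branch.
     * The 0.1 tolerance mirrors the excerpt above.
     */
    static Action decide(double errorsLeaf, double errorsTree, double errorsLargestBranch) {
        if (errorsLeaf <= errorsTree + 0.1 && errorsLeaf <= errorsLargestBranch + 0.1) {
            return Action.MAKE_LEAF;
        }
        if (errorsLargestBranch <= errorsTree + 0.1) {
            return Action.RAISE_LARGEST_BRANCH;
        }
        return Action.KEEP_SUBTREE;
    }

    public static void main(String[] args) {
        System.out.println(decide(5.0, 6.0, 7.0));
        System.out.println(decide(8.0, 6.0, 5.5));
        // Passing Double.MAX_VALUE for the branch estimate models "subtree raising disabled".
        System.out.println(decide(8.0, 6.0, Double.MAX_VALUE));
    }
}
```

The leaf test comes first, so a leaf wins ties against both alternatives, which biases the procedure toward smaller trees.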
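J48.classifyInstance(data): classifies test data with the trained model.
m_root.classifyInstance(data);
This actually calls ClassifierTree.classifyInstance(ins):
It calls getProbs(classIndex, data, 1) to compute the probability of the test instance for each class, and returns the index of the class with the highest probability.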
getProbs() in turn calls ClassifierSplitModel.classProb(), which follows a similar strategy; the final probability comes from Distribution.prob(). The code of getProbs():
double prob = 0;
if (m_isLeaf) {
    return weight * localModel().classProb(classIndex, instance, -1);
} else {
    int treeIndex = localModel().whichSubset(instance);
    if (treeIndex == -1) { // missing split value: combine all branches by weight
        double[] weights = localModel().weights(instance);
        for (int i = 0; i < m_sons.length; i++) {
            if (!son(i).m_isEmpty) {
                prob += son(i).getProbs(classIndex, instance, weights[i] * weight);
            }
        }
        return prob;
    } else {
        if (son(treeIndex).m_isEmpty) { // empty child: use this node's distribution
            return weight * localModel().classProb(classIndex, instance, treeIndex);
        } else {
            return son(treeIndex).getProbs(classIndex, instance, weight);
        }
    }
}
The code of classProb():
if (theSubset > -1) {
    return m_distribution.prob(classIndex, theSubset);
} else {
    double[] weights = weights(instance);
    if (weights == null) {
        return m_distribution.prob(classIndex);
    } else {
        double prob = 0;
        for (int i = 0; i < weights.length; i++) {
            prob += weights[i] * m_distribution.prob(classIndex, i);
        }
        return prob;
    }
}
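To make the missing-value branch of classProb() concrete, here is a small sketch that works on a plain count table instead of a Distribution object. MissingValueProbSketch and its methods are hypothetical names; the branch weights are assumed to be each bag's share of the total weight, and an empty bag simply contributes 0.

```java
/** Sketch of class probability estimation with a missing split value (hypothetical class). */
class MissingValueProbSketch {

    static double sum(double[] xs) {
        double s = 0;
        for (double x : xs) {
            s += x;
        }
        return s;
    }

    static double ratio(double num, double den) {
        return den > 0 ? num / den : 0;
    }

    /**
     * counts[bag][cls] is the training weight of class cls that went into branch bag.
     * subset > -1 : the instance falls into exactly one branch.
     * subset == -1: the split value is missing, so the branches are mixed by their weight.
     */
    static double classProb(double[][] counts, int cls, int subset) {
        if (subset > -1) {
            return ratio(counts[subset][cls], sum(counts[subset]));
        }
        double total = 0;
        for (double[] bag : counts) {
            total += sum(bag);
        }
        double prob = 0;
        for (double[] bag : counts) {
            double branchWeight = ratio(sum(bag), total);
            prob += branchWeight * ratio(bag[cls], sum(bag));
        }
        return prob;
    }

    public static void main(String[] args) {
        double[][] counts = {{3, 1}, {1, 3}};
        System.out.println(classProb(counts, 0, 0));
        System.out.println(classProb(counts, 0, -1));
    }
}
```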
III. Analysis of the related classes.
J48.java:
m_root: the root node, plus a number of model training parameters.
buildClassifier(ins)
classifyInstance(ins)
ClassifierTree.java:
m_sons
m_toSelectModel
ClassifierSplitModel m_localModel
Instances m_train
Distribution m_dis
buildTree(ins)
classifyInstance(ins)
distributionForInstance(Instance, boolean): computes the probability distribution of the instance over the classes
getProbsLaplace(classIdx, ins, weight): computes the class probabilities of ins
getProbs(...): same as above
buildClassifier(ins)
buildTree(ins, boolean)
collapse(): collapses a tree to a node if the training error doesn't increase, by comparing the current error (summed over the leaves) with the error of the node's whole distribution
prune()
getEstimatedErrorsForDistribution(Distribution): estimates the errors of a leaf as the training errors plus an estimated correction: dis.numIncorrect() + Stats.addErrs(dis.total(), dis.numIncorrect(), m_CF).
doGrafting(ins)
sortInstances(ins, iindex[][], limits[][], subset)
findGraft(ins, iindex[][], limits[][], ClassifierTree parent, pLaplace, pLeafClass)
prune(), collapse()
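Stats.java:
Estimates the extra error for a given data set, using the normal distribution to approximate the binomial distribution. Main method:
addErrs(double N, double e, float CF): N is the total number of instances, e the observed number of errors, and CF the confidence value.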
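The idea behind addErrs() can be illustrated with the textbook upper confidence limit of a binomial proportion under the normal approximation, with a 0.5 continuity correction. This is a sketch under assumptions, not the Weka method: AddErrsSketch and extraErrors are hypothetical names, the z-score is passed in directly instead of being derived from CF, and the special cases for very small or very large error counts are left out (the sketch expects 1 <= e and e + 0.5 < n).

```java
/** Sketch of a pessimistic error estimate via the normal approximation (hypothetical class). */
class AddErrsSketch {

    /**
     * Returns how many errors to add on top of e observed errors out of n instances,
     * so that the total matches the upper confidence limit for the given z-score.
     */
    static double extraErrors(double n, double e, double z) {
        double f = (e + 0.5) / n; // observed error rate with continuity correction
        double upper = (f + z * z / (2 * n)
                + z * Math.sqrt(f / n - f * f / n + z * z / (4 * n * n)))
                / (1 + z * z / n);
        return upper * n - e;
    }

    public static void main(String[] args) {
        // 0.6745 is the one-sided z-score for a confidence value of 0.25.
        System.out.println(extraErrors(100, 10, 0.6745));
    }
}
```

A smaller confidence value means a larger z, a more pessimistic estimate, and therefore heavier pruning.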
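ModelSelection.java:
Abstract class for feature selection and splitting in decision trees.
Subclasses:
C45ModelSelection.java: feature splitting based on C45Split. Main attributes and methods:
selectModel(Instances)
BinC45ModelSelection.java: based on BinC45Split
NBTreeModelSelection.java: based on NBTreeSplit
ResidualModelSelection.java: based on ResidualSplit (residuals on the training data).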
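ClassifierSplitModel.java:
Abstract class for the branching decision at a tree node. Main attributes and methods:
m_numSubsets: the number of subsets the instances are split into.
Distribution m_distribution;
classProb(cidx, ins, bagIdx) and classProbLaplace(): based on m_distribution, return the probability from a single bag's distribution, or the weighted sum over several bags.
split(Instances): splits the data, based on the abstract methods weights(ins) and whichSubset(ins). Excerpt: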
weights = weights(instance);
subset = whichSubset(instance);
if (subset > -1)
    instances[subset].add(instance);
else
    for (j = 0; j < m_numSubsets; j++)
        if (Utils.gr(weights[j], 0)) {
            newWeight = weights[j] * instance.weight();
            instances[j].add(instance);
            instances[j].lastInstance().setWeight(newWeight);
        }
Subclasses:
NoSplit.java: does not split the data set. Key methods: weights(): returns null; leftSide(): returns ""; rightSide(): returns "";
C45Split.java: splits the data set according to the C4.5 algorithm
InfoGainSplitCrit infoGainCrit; // static
GainRatioSplitCrit gainRatioCrit; // static
m_splitPoint; // split point of a continuous attribute
m_attIndex; // index of the attribute to split on
m_minNoObj; // minimum number of instances per leaf (the no-split check uses 2 * m_minNoObj)
m_sumOfWeights; // sum of the instance weights
buildClassifier(ins): calls handleEnumeratedAttribute(ins) or handleNumericAttribute(ins) to split the data.
handleEnumeratedAttribute(ins): for a nominal attribute, builds m_distribution and computes the IG and GR values.
if (m_distribution.check(m_minNoObj)) {
    m_numSubsets = m_complexityIndex;
    m_infoGain = infoGainCrit.splitCritValue(m_distribution, m_sumOfWeights);
    m_gainRatio = gainRatioCrit.splitCritValue(m_distribution, m_sumOfWeights, m_infoGain);
}
handleNumericAttribute(ins): for a numeric attribute, computes IG, GR and m_splitPoint. Logic:
Populate m_distribution, with the instances sorted by the attribute value.
Determine minSplit = min(max(0.1 * m_distribution.total() / ins.numClasses(), m_minNoObj), 25).
Iterate over the sorted attribute values; whenever val(i-1) + 1e-5 < val(i), compute the IG of splitting at this position and update the best IG and the index of the split point.
m_infoGain = m_infoGain - (Utils.log2(m_index) / m_sumOfWeights); // correct the IG
m_splitPoint = (trainInstances.instance(splitIndex + 1).value(m_attIndex) + trainInstances.instance(splitIndex).value(m_attIndex)) / 2;
Compute GR via gainRatioCrit.splitCritValue(Distribution, double, double numerator).
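The steps of handleNumericAttribute() listed above can be sketched on plain arrays. This is a simplified illustration, not the Weka code: NumericSplitSketch and its methods are hypothetical names, every instance has weight 1, there are no missing values, and the log2(m_index) correction of the information gain is omitted.

```java
/** Sketch of the numeric split search on sorted values (hypothetical class, not part of Weka). */
class NumericSplitSketch {

    /** Minimum number of instances required on each side of the split. */
    static double minSplit(double total, int numClasses, double minNoObj) {
        return Math.min(Math.max(0.1 * total / numClasses, minNoObj), 25);
    }

    /** Entropy in bits of a class count vector. */
    static double entropy(int[] counts) {
        int n = 0;
        for (int c : counts) {
            n += c;
        }
        double h = 0;
        for (int c : counts) {
            if (c > 0) {
                double p = (double) c / n;
                h -= p * Math.log(p) / Math.log(2);
            }
        }
        return h;
    }

    /**
     * values must be sorted ascending; labels[i] is the class of values[i].
     * Returns the midpoint of the best boundary, or NaN if no boundary is admissible.
     */
    static double bestSplitPoint(double[] values, int[] labels, int numClasses, double minSplit) {
        int n = values.length;
        int[] left = new int[numClasses];
        int[] right = new int[numClasses];
        for (int label : labels) {
            right[label]++;
        }
        double parentEntropy = entropy(right);
        double bestGain = 0;
        int bestIndex = -1;
        for (int i = 1; i < n; i++) {
            // Move instance i-1 from the right side to the left side.
            left[labels[i - 1]]++;
            right[labels[i - 1]]--;
            boolean distinct = values[i - 1] + 1e-5 < values[i];
            if (distinct && i >= minSplit && n - i >= minSplit) {
                double gain = parentEntropy
                        - ((double) i / n) * entropy(left)
                        - ((double) (n - i) / n) * entropy(right);
                if (gain > bestGain) {
                    bestGain = gain;
                    bestIndex = i - 1;
                }
            }
        }
        if (bestIndex < 0) {
            return Double.NaN;
        }
        return (values[bestIndex] + values[bestIndex + 1]) / 2;
    }

    public static void main(String[] args) {
        double[] values = {1, 2, 3, 10, 11, 12};
        int[] labels = {0, 0, 0, 1, 1, 1};
        System.out.println(bestSplitPoint(values, labels, 2, minSplit(values.length, 2, 2)));
    }
}
```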
weights(ins): handles missing attribute values.
if (instance.isMissing(m_attIndex)) {
    weights = new double[m_numSubsets];
    for (i = 0; i < m_numSubsets; i++)
        weights[i] = m_distribution.perBag(i) / m_distribution.total();
    return weights;
} else {
    return null;
}
whichSubset(ins): returns -1 for a missing value, the value index for an enumerated attribute, and otherwise 0 or 1 depending on m_splitPoint.
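Taken together, weights() and split() send a training instance with a missing split value down every branch, carrying a fraction of its weight. A minimal sketch of that arithmetic follows; FractionalWeightSketch and distribute are hypothetical names, and the bag totals are assumed to sum to a positive value.

```java
/** Sketch of fractional instance weights for a missing split value (hypothetical class). */
class FractionalWeightSketch {

    /**
     * perBag[i] is the total training weight already assigned to branch i.
     * Returns the weight with which the instance enters each branch.
     */
    static double[] distribute(double instanceWeight, double[] perBag) {
        double total = 0;
        for (double b : perBag) {
            total += b;
        }
        double[] result = new double[perBag.length];
        for (int i = 0; i < perBag.length; i++) {
            result[i] = instanceWeight * perBag[i] / total;
        }
        return result;
    }

    public static void main(String[] args) {
        double[] parts = distribute(1.0, new double[]{6, 2, 2});
        System.out.println(java.util.Arrays.toString(parts));
    }
}
```

The fractions always add up to the original instance weight, so the total weight in the tree is preserved from level to level.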
BinC45Split.java: same as above, except that m_numSubsets = 2;
GraftSplit.java: Class implementing a split for nodes added to a tree during grafting. Skipped.
NBTreeNoSplit.java: the no-split object of the Naive Bayes Tree. Main attributes and methods:
NaiveBayesUpdateable m_nb;
Discretize m_disc;
m_errors;
buildClassifier(ins):
crossValidate(m_nb, ins, Random(1)); // computes the 5-fold CV error of the model
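NBTreeSplit.java: implements an NBTree split on an attribute. Main attributes and methods:
m_splitPoint;
C45Split m_c45S;
NBTreeNoSplit m_globalNB;
handleEnumeratedAttribute(ins): trains m_c45S, computes m_errors by cross-validation, and updates m_numSubsets.
For each instance, decide its branch in the tree: use m_c45S if it can decide, otherwise assign the instance to every branch.
For each branch holding more than 5 instances, the error is computed by 5-fold CV of a NaiveBayesUpdateable classifier; otherwise the weights are simply summed up as the branch's error.
handleNumericAttribute(ins): same as above, for numeric attributes.
ResidualSplit.java: Helper class for logistic model trees, implementing the splitting criterion based on residuals of the LogitBoost algorithm. See the wiki for details.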
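EntropyBasedSplitCrit.java:
Entropy-based splitting criterion for decision trees. Main methods: (At first sight the probability P(i) in front of the log seems to be missing. It is folded into logFunc(x), which in the Weka source computes x * log2(x) rather than log2(x), so the methods below return entropies scaled by the total weight.)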
public final double oldEnt(Distribution bags) {
    // Computes entropy of distribution before splitting.
    double returnValue = 0;
    int j;
    for (j = 0; j < bags.numClasses(); j++)
        returnValue = returnValue + logFunc(bags.perClass(j)); // logFunc(n) = n * log(n) / log(2)
    return logFunc(bags.total()) - returnValue;
}
/**
 * Computes entropy of distribution after splitting.
 */
public final double newEnt(Distribution bags) {
    double returnValue = 0;
    int i, j;
    for (i = 0; i < bags.numBags(); i++) {
        for (j = 0; j < bags.numClasses(); j++)
            returnValue = returnValue + logFunc(bags.perClassPerBag(i, j));
        returnValue = returnValue - logFunc(bags.perBag(i));
    }
    return -returnValue;
}
/**
 * Computes entropy after splitting without considering the class values.
 */
public final double splitEnt(Distribution bags) {
    double returnValue = 0;
    int i;
    for (i = 0; i < bags.numBags(); i++)
        returnValue = returnValue + logFunc(bags.perBag(i));
    return logFunc(bags.total()) - returnValue;
}
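The question of where the probabilities went can be checked numerically. Assuming logFunc(x) = x * log2(x) (with 0 for very small x), oldEnt() works out to the total weight N times the usual entropy H, because N*log2(N) - sum(n_j*log2(n_j)) = -sum(n_j*log2(n_j/N)) = N*H. The sketch below compares both forms; ScaledEntropySketch and its method bodies are a re-implementation for illustration, not the Weka source.

```java
/** Sketch showing that oldEnt() equals N times the Shannon entropy (hypothetical class). */
class ScaledEntropySketch {

    /** Assumed definition of logFunc: x * log2(x), with 0 for very small x. */
    static double logFunc(double x) {
        return x < 1e-6 ? 0 : x * Math.log(x) / Math.log(2);
    }

    /** Same structure as oldEnt() above, on a plain array of class weights. */
    static double oldEnt(double[] perClass) {
        double total = 0;
        double sum = 0;
        for (double c : perClass) {
            total += c;
            sum += logFunc(c);
        }
        return logFunc(total) - sum;
    }

    /** Textbook entropy in bits: -sum p_j * log2(p_j). */
    static double entropy(double[] perClass) {
        double total = 0;
        for (double c : perClass) {
            total += c;
        }
        double h = 0;
        for (double c : perClass) {
            if (c > 0) {
                double p = c / total;
                h -= p * Math.log(p) / Math.log(2);
            }
        }
        return h;
    }

    public static void main(String[] args) {
        double[] perClass = {9, 5};
        System.out.println(oldEnt(perClass));
        System.out.println(14 * entropy(perClass));
    }
}
```

Working with weighted counts instead of probabilities avoids one division per term; the division by the total is applied once at the end, in the splitCritValue() methods.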
Subclasses:
InfoGainSplitCrit.java: computes the information gain for a given distribution. Main methods:
splitCritValue(Distribution): computes the IG directly and returns bags.total() / (oldEnt(bags) - newEnt(bags)); the reciprocal is taken because the splitting criterion's value is minimized.
splitCritValue(Distribution bags, double totalNoInst): the IG as used by C4.5
double numerator;
double noUnknown;
double unknownRate;
int i;
noUnknown = totalNoInst - bags.total(); // total weight of the instances with a missing value
unknownRate = noUnknown / totalNoInst;
numerator = (oldEnt(bags) - newEnt(bags));
numerator = (1 - unknownRate) * numerator; // = bags.total() / totalNoInst * numerator
// Splits with no gain are useless.
if (Utils.eq(numerator, 0))
    return 0;
return numerator / bags.total(); // = (oldEnt - newEnt) / totalNoInst
GainRatioSplitCrit.java: computes the gain ratio for a given distribution. Main methods:
splitCritValue(dis): computes GR, using splitEnt() for the denominator
splitEnt(Distribution bags, double totalnoInst): computes the split entropy of the attribute. Key code:
noUnknown = totalnoInst - bags.total();
if (Utils.gr(bags.total(), 0)) {
    for (i = 0; i < bags.numBags(); i++)
        returnValue = returnValue - logFunc(bags.perBag(i));
    returnValue = returnValue - logFunc(noUnknown);
    returnValue = returnValue + logFunc(totalnoInst);
}
return returnValue;
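The splitEnt() excerpt treats the instances with a missing value as one extra bag when computing the split entropy. Here is a small sketch of that computation and of the resulting gain ratio. GainRatioSketch is a hypothetical class; logFunc is assumed to be x * log2(x), and the division of the split entropy by totalNoInst before forming the ratio is my reading of the three-argument splitCritValue(), so treat it as an assumption.

```java
/** Sketch of split entropy with an "unknown" bag and the gain ratio (hypothetical class). */
class GainRatioSketch {

    static double logFunc(double x) {
        return x < 1e-6 ? 0 : x * Math.log(x) / Math.log(2);
    }

    /** perBag holds the known weight per branch; totalNoInst also counts missing values. */
    static double splitEnt(double[] perBag, double totalNoInst) {
        double known = 0;
        for (double b : perBag) {
            known += b;
        }
        double noUnknown = totalNoInst - known;
        double result = 0;
        if (known > 0) {
            for (double b : perBag) {
                result -= logFunc(b);
            }
            result -= logFunc(noUnknown); // missing values act as an additional bag
            result += logFunc(totalNoInst);
        }
        return result;
    }

    /** Gain ratio = information gain / (split entropy per instance); 0 if undefined. */
    static double gainRatio(double infoGain, double[] perBag, double totalNoInst) {
        double denominator = splitEnt(perBag, totalNoInst);
        if (Math.abs(denominator) < 1e-6) {
            return 0;
        }
        return infoGain / (denominator / totalNoInst);
    }

    public static void main(String[] args) {
        System.out.println(splitEnt(new double[]{4, 4}, 10) / 10);
        System.out.println(gainRatio(0.5, new double[]{4, 4}, 8));
    }
}
```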
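Distribution.java:
Stores the weight distribution of the instances along these dimensions: per bag (attr-num), per class (class-num), per class and bag, and total.
add(bagIdx)
prob(ins); laplaceProb(int): the probability before and after smoothing, the latter being (Pc + 1) / (total + Nc)
shift(int, int, Instance): moves an instance to a new bag
subtract(Distribution): subtracts the per-class weights of one distribution from the other.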
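The difference between the raw and the Laplace-smoothed estimate is easiest to see on a class that never occurred in the training data. A minimal sketch on a plain array of class weights (LaplaceProbSketch is a hypothetical class; Pc is the weight of the class and Nc the number of classes):

```java
/** Sketch of raw versus Laplace-smoothed class probabilities (hypothetical class). */
class LaplaceProbSketch {

    static double total(double[] perClass) {
        double t = 0;
        for (double c : perClass) {
            t += c;
        }
        return t;
    }

    /** Relative frequency; 0 if the distribution is empty. */
    static double prob(double[] perClass, int cls) {
        double t = total(perClass);
        return t > 0 ? perClass[cls] / t : 0;
    }

    /** Laplace estimate: (Pc + 1) / (total + Nc). */
    static double laplaceProb(double[] perClass, int cls) {
        return (perClass[cls] + 1) / (total(perClass) + perClass.length);
    }

    public static void main(String[] args) {
        double[] perClass = {3, 0};
        System.out.println(prob(perClass, 1));
        System.out.println(laplaceProb(perClass, 1));
    }
}
```

With smoothing, an unseen class gets a small non-zero probability, and the estimate for an empty distribution falls back to 1 / Nc instead of being undefined.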
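This article walks through the source code of the J48 (C4.5) decision tree algorithm. Starting from the buildClassifier method in J48.java, it covers data splitting, subtree construction, and the feature selection and splitting strategies such as C45Split and BinC45Split. It also introduces the roles and methods of the related classes ClassifierTree, ModelSelection and ClassifierSplitModel, including key steps such as error estimation and probability estimation.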