Source code download:
http://weka.wikispaces.com/Subversion
Use WEKA in your Java Code:
http://www.cs.umb.edu/~ding/history/310_fall_2010/homework/UseWEKA_In_Java_Code.pdf
A study of the J48 (C4.5) decision tree source code
TODO: analyze J48's classification efficiency.
1. Preparation.

2. Code flow walkthrough:
Model training starts in J48.java.
J48.buildClassifier(ins), taking the C4.5 decision tree model as the example:
- modSelection = new C45ModelSelection(m_minNumObj, instances);
- m_root = new C45PruneableClassifierTree(modSelection, !m_unpruned, m_CF,
- m_subtreeRaising, !m_noCleanup);
- m_root.buildClassifier(instances);
Expanding C45PruneableClassifierTree.buildClassifier(ins):
It calls the base-class ClassifierTree.buildTree(), which is expanded after this listing:
- data.deleteWithMissingClass();
- buildTree(data, m_subtreeRaising || !m_cleanup);
- collapse();
- if (m_pruneTheTree) {
- prune();
- }
- if (m_cleanup) {
- cleanup(new Instances(data, 0));
- }
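ClassifierTree.buildTree() performs the following steps:
m_localModel = modSelection.selectModel(ins); // choose the split for this node
m_localModel.split(ins); // partition the data
m_sons[i] = getNewTree(localInstances[i]); // build the subtrees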
Expanding C45ModelSelection.selectModel(ins):
Each candidate split is built by C45Split.buildClassifier(), which is described after this listing:
- if (Utils.sm(checkDistribution.total(), 2 * m_minNoObj)
- || Utils.eq(checkDistribution.total(), checkDistribution.perClass(checkDistribution.maxClass())))
- return noSplitModel;
- multiValue = !(attribute.isNumeric() || attribute.numValues() < (0.3 * m_allData.numInstances()));
- currentModel = new C45Split[data.numAttributes()];
- sumOfWeights = data.sumOfWeights();
- // For each attribute.
- for (i = 0; i < data.numAttributes(); i++) {
- // Apart from class attribute.
- if (i != (data).classIndex()) {
- // Get models for current attribute.
- currentModel[i] = new C45Split(i, m_minNoObj, sumOfWeights);
- currentModel[i].buildClassifier(data);
- // ... (omitted) accumulate the sum for averageInfoGain and count validModels
- } else
- currentModel[i] = null;
- }
- averageInfoGain = averageInfoGain / (double) validModels;
- // Find "best" attribute to split on.
- minResult = 0;
- for (i = 0; i < data.numAttributes(); i++) {
- // Skip the class attribute (its entry is null) and attributes without a valid split.
- if ((i != (data).classIndex()) && (currentModel[i].checkModel())) {
- // Use 1E-3 here to get a closer approximation to the original implementation.
- if ((currentModel[i].infoGain() >= (averageInfoGain - 1E-3))
- && Utils.gr(currentModel[i].gainRatio(), minResult)) {
- bestModel = currentModel[i];
- minResult = currentModel[i].gainRatio();
- }
- }
- }
- // ... (omitted) load the complete data into bestModel, then set the split point analogous to C4.5 if the attribute is numeric.
- if (m_allData != null)
- bestModel.setSplitPoint(m_allData);
- return bestModel;
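C45Split.buildClassifier() calls handleEnumeratedAttribute(ins) or handleNumericAttribute(ins), depending on whether the attribute is nominal (enumerated) or numeric. Both methods are explained later in this article.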
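The selection rule in selectModel() can be illustrated without Weka. The following is a minimal sketch, assuming the info gain and gain ratio of every candidate attribute have already been computed; the class name BestSplitSketch, the method selectBest and the use of NaN to mark an attribute without a valid split are hypothetical choices for this example. It also averages over all valid candidates and compares with a plain `>`, whereas the real code restricts the average in some cases and uses Utils.gr with a small tolerance.

```java
/** Hypothetical stand-in for the selection loop; not part of Weka's API. */
class BestSplitSketch {
    /**
     * Returns the index of the candidate with the highest gain ratio among
     * those whose info gain is at least (average info gain - 1e-3),
     * or -1 if no candidate qualifies.
     * A NaN info gain marks an attribute without a valid split (assumption).
     */
    static int selectBest(double[] infoGain, double[] gainRatio) {
        double sum = 0;
        int valid = 0;
        for (double g : infoGain) {
            if (!Double.isNaN(g)) {
                sum += g;
                valid++;
            }
        }
        if (valid == 0) {
            return -1;
        }
        double average = sum / valid;

        int best = -1;
        double bestRatio = 0;
        for (int i = 0; i < infoGain.length; i++) {
            if (Double.isNaN(infoGain[i])) {
                continue; // no valid split on this attribute
            }
            if (infoGain[i] >= average - 1e-3 && gainRatio[i] > bestRatio) {
                best = i;
                bestRatio = gainRatio[i];
            }
        }
        return best;
    }
}
```

With info gains {0.1, 0.5, 0.6} and gain ratios {0.9, 0.4, 0.3}, the first attribute has the highest gain ratio but is rejected because its info gain is below average, so the sketch picks index 1.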
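collapse(): compares the error before and after collapsing the subtree into a single node. If the error does not increase, the subtree is collapsed; otherwise the child nodes are processed recursively.
prune(): prunes the C4.5 decision tree, recursing into the child nodes first. The key code for the current node: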
- // Compute error if this Tree would be leaf
- errorsLeaf = getEstimatedErrorsForDistribution(localModel()
- .distribution());
- // Compute error for the whole subtree
- errorsTree = getEstimatedErrors();
- // Decide if leaf is best choice.
- if (Utils.smOrEq(errorsLeaf, errorsTree + 0.1)
- && Utils.smOrEq(errorsLeaf, errorsLargestBranch + 0.1)) {
- // ... turn this node into a leaf
- }
- // Decide if largest branch is better choice than whole subtree.
- if (Utils.smOrEq(errorsLargestBranch, errorsTree + 0.1)) {
- largestBranch = son(indexOfLargestBranch);
- m_sons = largestBranch.m_sons;
- m_localModel = largestBranch.localModel();
- m_isLeaf = largestBranch.m_isLeaf;
- newDistribution(m_train);
- prune();
- }
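The three-way choice made at each node during pruning can be isolated as a small decision function. This is a sketch under the assumption that the three error estimates are already available; PruneDecisionSketch, Action and decide are hypothetical names, and plain `<=` replaces Utils.smOrEq.

```java
/** Hypothetical illustration of the pruning decision; not Weka code. */
class PruneDecisionSketch {
    enum Action { MAKE_LEAF, RAISE_LARGEST_BRANCH, KEEP_SUBTREE }

    /**
     * Chooses what to do with a node given the estimated errors of
     * (a) the node as a leaf, (b) the whole subtree, and
     * (c) the subtree's largest branch if it were raised to this position.
     */
    static Action decide(double errorsLeaf, double errorsTree,
                         double errorsLargestBranch) {
        // A leaf wins if it is no worse (within 0.1) than both alternatives.
        if (errorsLeaf <= errorsTree + 0.1
                && errorsLeaf <= errorsLargestBranch + 0.1) {
            return Action.MAKE_LEAF;
        }
        // Otherwise raise the largest branch if it is no worse than the subtree.
        if (errorsLargestBranch <= errorsTree + 0.1) {
            return Action.RAISE_LARGEST_BRANCH;
        }
        return Action.KEEP_SUBTREE;
    }
}
```

Passing Double.MAX_VALUE as errorsLargestBranch makes the second option unreachable, which is one way to model a run with subtree raising switched off.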
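J48.classifyInstance(data): uses the model to classify a test instance.
m_root.classifyInstance(data);
This actually calls ClassifierTree.classifyInstance(ins):
It calls getProbs(classIndex, data, 1) to compute the probability of the test instance under each class and returns the index of the class with the highest probability.
getProbs() in turn calls ClassifierSplitModel.classProb(), which follows a similar strategy; the final probability comes from Distribution.prob().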
- if (m_isLeaf) {
- return weight * localModel().classProb(classIndex, instance, -1);
- } else {
- int treeIndex = localModel().whichSubset(instance);
- if (treeIndex == -1) { // split attribute value is missing: spread the instance over all branches by weight
- double[] weights = localModel().weights(instance);
- for (int i = 0; i < m_sons.length; i++) {
- if (!son(i).m_isEmpty) {
- prob += son(i).getProbs(classIndex, instance,
- weights[i] * weight);
- }
- }
- return prob;
- } else {
- if (son(treeIndex).m_isEmpty) { // empty son: fall back on this node's distribution for that subset
- return weight * localModel().classProb(classIndex, instance, treeIndex);
- } else {
- return son(treeIndex).getProbs(classIndex, instance, weight);
- }
- }
- }
The code of classProb():
- if (theSubset > -1) {
- return m_distribution.prob(classIndex, theSubset);
- } else {
- double[] weights = weights(instance);
- if (weights == null) {
- return m_distribution.prob(classIndex);
- } else {
- double prob = 0;
- for (int i = 0; i < weights.length; i++) {
- prob += weights[i] * m_distribution.prob(classIndex, i);
- }
- return prob;
- }
- }
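To see how the recursion treats a missing value, here is a self-contained sketch of the same idea on a tiny binary tree with numeric splits. Everything in it (TreeProbSketch, Node, its fields, and the use of NaN to mark a missing value) is an assumption made for illustration, and the empty-son case from the listing above is left out.

```java
/** Hypothetical miniature of getProbs(); names and layout are not Weka's. */
class TreeProbSketch {
    /** Minimal binary tree node. */
    static class Node {
        double[] classProbs;    // class probabilities, set on leaves only
        int attIndex;           // attribute tested at an internal node
        double splitPoint;      // numeric threshold at an internal node
        Node[] sons;            // null on leaves
        double[] branchWeights; // share of training weight sent to each son

        static Node leaf(double... classProbs) {
            Node n = new Node();
            n.classProbs = classProbs;
            return n;
        }

        static Node split(int attIndex, double splitPoint,
                          double[] branchWeights, Node left, Node right) {
            Node n = new Node();
            n.attIndex = attIndex;
            n.splitPoint = splitPoint;
            n.branchWeights = branchWeights;
            n.sons = new Node[] { left, right };
            return n;
        }
    }

    /** Probability of classIndex for the instance, scaled by weight. */
    static double getProbs(Node node, int classIndex,
                           double[] instance, double weight) {
        if (node.sons == null) {
            return weight * node.classProbs[classIndex];
        }
        double value = instance[node.attIndex];
        if (Double.isNaN(value)) {
            // Missing value: follow every branch with a fractional weight.
            double prob = 0;
            for (int i = 0; i < node.sons.length; i++) {
                prob += getProbs(node.sons[i], classIndex, instance,
                        node.branchWeights[i] * weight);
            }
            return prob;
        }
        int son = value <= node.splitPoint ? 0 : 1;
        return getProbs(node.sons[son], classIndex, instance, weight);
    }
}
```

For a root that sent 75% of the training weight left, an instance with a missing split value receives 0.75 of the left leaf's probability plus 0.25 of the right leaf's.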
3. Analysis of the related classes.
J48.java:
m_root: the root node, along with several training parameters.
buildClassifier(ins)
classifyInstance(ins)
ClassifierTree.java:
m_sons
m_toSelectModel
ClassifierSplitModel m_localModel
Instances m_train
Distribution m_dis
buildTree(ins)
classifyInstance(ins)
distributionForInstance(Instance, boolean): computes the probability of the instance under each class (its class distribution)
getProbsLaplace(classIdx, ins, weight): computes the class probabilities of ins
getProbs(...): same as above
buildClassifier(ins)
buildTree(ins, boolean)
collapse(): collapses a tree to a node if the training error doesn't increase, by comparing the current error (summed over the leaves) with the error of the node's whole distribution
prune()
getEstimatedErrorsForDistribution(Distribution): estimates the errors of a leaf as the training errors plus an estimated extra term: dis.numIncorrect() + Stats.addErrs(dis.total(), dis.numIncorrect(), m_CF).
doGrafting(ins)
sortInstances(ins, iindex[][], limits[][], subset):
findGraft(ins, iindex[][], limits[][], ClassifierTree parent, pLaplace, pLeafClass)
prune().collapse()
Stats.java:
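Estimates the extra error for a given dataset, using the normal distribution to approximate the binomial distribution. Main method:
addErrs(double N, double e, float CF): returns the estimated extra errors to add to e observed errors out of N instances at confidence factor CF.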
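The estimate can be sketched as follows. This is a reconstruction from my reading of the method and should be checked against the actual Stats.java before relying on it. To stay free of any statistics library, the normal quantile z for 1 - CF is passed in as an extra parameter, which is an assumption of this sketch (for CF = 0.25, z is about 0.6745); AddErrsSketch is a hypothetical class name.

```java
/** Hypothetical re-implementation of the idea behind Stats.addErrs. */
class AddErrsSketch {
    /**
     * Extra errors to add to e observed errors out of n instances.
     * cf is the confidence factor, z the normal quantile for 1 - cf.
     */
    static double addErrs(double n, double e, double cf, double z) {
        if (e < 1) {
            // The normal approximation breaks down at the low end:
            // use the exact value for e == 0 and interpolate up to e == 1.
            double base = n * (1 - Math.pow(cf, 1 / n));
            if (e == 0) {
                return base;
            }
            return base + e * (addErrs(n, 1, cf, z) - base);
        }
        if (e + 0.5 >= n) {
            // High end: never return a negative correction.
            return Math.max(n - e, 0);
        }
        // Upper confidence limit of the error rate, with continuity correction.
        double f = (e + 0.5) / n;
        double r = (f + z * z / (2 * n)
                + z * Math.sqrt(f / n - f * f / n + z * z / (4 * n * n)))
                / (1 + z * z / n);
        return r * n - e;
    }
}
```

A leaf with 10 instances and no training errors still gets an estimated 1.29 errors under this formula, which is what lets pruning prefer a larger leaf over several small pure ones.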
ModelSelection.java:
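Abstract class for selecting the feature and the split in a decision tree.
Subclasses:
C45ModelSelection.java: feature splitting based on C45Split. Main method:
selectModel(Instances)
BinC45ModelSelection.java: based on BinC45Split
NBTreeModelSelection.java: based on NBTreeSplit
ResidualModelSelection.java: based on ResidualSplit, which uses the residuals on the training data.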
ClassifierSplitModel.java:
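Abstract class for the branching decision at a node of the decision tree. Main fields and methods:
m_numSubsets: the number of subsets the instances are split into.
Distribution m_dis;
classProb(cidx, ins, bagIdx) and classProbLaplace(): return the probability based on m_dis, either from a single bag or as a weighted sum over several bags.
split(Instances): splits the data, based on the abstract methods weights(ins) and whichSubset(ins). Part of the source: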
- weights = weights(instance);
- subset = whichSubset(instance);
- if (subset > -1)
- instances[subset].add(instance);
- else
- for (j = 0; j < m_numSubsets; j++)
- if (Utils.gr(weights[j], 0)) {
- newWeight = weights[j] * instance.weight();
- instances[j].add(instance);
- instances[j].lastInstance().setWeight(newWeight);
- }
NoSplit.java: leaves the dataset unsplit. Key methods: weights() returns null; leftSide() returns ""; rightSide() returns "".
C45Split.java: splits the dataset according to the C4.5 algorithm.
InfoGainSplitCrit infoGainCrit; // static
GainRatioSplitCrit gainRatioCrit; // static
m_splitPoint; // split point for a numeric attribute
m_attIndex; // index of the attribute to split on
m_minNoObj; // minimum number of instances required in a subset
m_sumOfWeights; // sum of the instance weights
buildClassifier(ins): calls handleEnumeratedAttribute(ins) or handleNumericAttribute(ins) to split the data.
handleEnumeratedAttribute(ins): for a nominal attribute, builds m_distribution and computes the information gain (IG) and gain ratio (GR).
- if (m_distribution.check(m_minNoObj)) {
- m_numSubsets = m_complexityIndex;
- m_infoGain = infoGainCrit.splitCritValue(m_distribution,
- m_sumOfWeights);
- m_gainRatio = gainRatioCrit.splitCritValue(m_distribution,
- m_sumOfWeights, m_infoGain);
- }
handleNumericAttribute(ins): for a numeric attribute, computes IG, GR and m_splitPoint. The logic:
Fill m_distribution and sort the instances by the attribute value.
Set minSplit = min(max(0.1 * m_distribution.total() / ins.numClasses(), m_minNoObj), 25).
Scan the sorted instances; whenever val(i-1) + 1e-5 < val(i), compute the IG of splitting at that position and update the best IG and the split index.
m_infoGain = m_infoGain - (Utils.log2(m_index) / m_sumOfWeights); // correct the IG for the number of candidate split points
m_splitPoint = (trainInstances.instance(splitIndex + 1).value(m_attIndex) + trainInstances.instance(splitIndex).value(m_attIndex)) / 2;
Compute GR with gainRatioCrit.splitCritValue(Distribution, double, double numerator).
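The scan over a sorted numeric attribute can be sketched in plain Java. This is a simplified illustration, not the Weka implementation: it assumes unit instance weights and no missing values, omits the log2 correction of the IG, and takes minSplit as a parameter. NumericSplitSketch and all of its method names are hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical sketch of the numeric split search; not Weka code. */
class NumericSplitSketch {
    /** n * log2(n), with 0 for n close to 0. */
    static double xlog2x(double n) {
        return n < 1e-6 ? 0 : n * Math.log(n) / Math.log(2);
    }

    /** Entropy of the class counts scaled by their total (N * H). */
    static double scaledEntropy(double[] counts) {
        double total = 0, sum = 0;
        for (double c : counts) {
            total += c;
            sum += xlog2x(c);
        }
        return xlog2x(total) - sum;
    }

    /** Returns {splitPoint, infoGainPerInstance}, or null if no valid split. */
    static double[] bestSplit(double[] values, int[] labels,
                              int numClasses, int minSplit) {
        Integer[] order = new Integer[values.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> values[i]));

        double[] left = new double[numClasses];
        double[] right = new double[numClasses];
        for (int label : labels) {
            right[label]++;
        }
        double before = scaledEntropy(right);

        double bestGain = 0;
        int splitIndex = -1;
        for (int next = 1; next < order.length; next++) {
            // Move the previous instance from the right bag to the left bag.
            int moved = labels[order[next - 1]];
            left[moved]++;
            right[moved]--;
            boolean distinct =
                values[order[next - 1]] + 1e-5 < values[order[next]];
            if (distinct && next >= minSplit
                    && order.length - next >= minSplit) {
                double gain = before
                    - (scaledEntropy(left) + scaledEntropy(right));
                if (gain > bestGain) {
                    bestGain = gain;
                    splitIndex = next - 1;
                }
            }
        }
        if (splitIndex < 0) {
            return null;
        }
        // Midpoint between the two neighbouring values, as in m_splitPoint.
        double point = (values[order[splitIndex]]
            + values[order[splitIndex + 1]]) / 2;
        return new double[] { point, bestGain / values.length };
    }
}
```

For the values {1, 2, 3, 10, 11, 12} with the first three in class 0 and the rest in class 1, the scan settles on the gap between 3 and 10 and reports a split point of 6.5 with a gain of 1 bit per instance.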
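whichSubset(ins): returns -1 for a missing value, the value index for a nominal attribute, and otherwise 0 or 1 depending on m_splitPoint.
weights(ins): handles missing attribute values by returning, for each subset, its share of the training weight: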
- if (instance.isMissing(m_attIndex)) {
- weights = new double[m_numSubsets];
- for (i = 0; i < m_numSubsets; i++)
- weights[i] = m_distribution.perBag(i) / m_distribution.total();
- return weights;
- } else {
- return null;
- }
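A small sketch shows how these weights turn one instance with a missing value into fractional copies, one per subset. The per-bag totals are passed in as a plain array; MissingWeightsSketch and its method names are hypothetical.

```java
/** Hypothetical illustration of fractional instances; not Weka code. */
class MissingWeightsSketch {
    /** Share of the training weight that fell into each subset. */
    static double[] weights(double[] perBag) {
        double total = 0;
        for (double w : perBag) {
            total += w;
        }
        double[] weights = new double[perBag.length];
        for (int i = 0; i < perBag.length; i++) {
            weights[i] = perBag[i] / total;
        }
        return weights;
    }

    /** Weight of the copy of an instance that is added to each subset. */
    static double[] fractionalCopies(double instanceWeight, double[] perBag) {
        double[] copies = weights(perBag);
        for (int i = 0; i < copies.length; i++) {
            copies[i] *= instanceWeight;
        }
        return copies;
    }
}
```

If the subsets hold weights 6 and 2, an instance of weight 2 with a missing split value is added to the first subset with weight 1.5 and to the second with weight 0.5.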
BinC45Split.java: same as above, except that m_numSubsets = 2.
GraftSplit.java: class implementing a split for nodes added to a tree during grafting. Skipped here.
NBTreeNoSplit.java: the no-split object of the Naive Bayes tree. Main fields and methods:
NaiveBayesUpdateable m_nb;
Discretize m_disc;
m_errors;
buildClassifier(ins):
crossValidate(m_nb, ins, Random(1)); // 5-fold cross-validation error of the model
NBTreeSplit.java: implements an NBTree split on an attribute. Main fields and methods:
m_splitPoint;
C45Split m_c45S;
NBTreeNoSplit m_globalNB;
handleEnumeratedAttribute(ins): trains m_c45S, computes m_errors by cross-validation, and updates m_numSubsets.
For each instance, decide its branch in the tree: use m_c45S if it can decide, otherwise assign the instance to every branch.
For each branch with more than 5 instances, compute the error by 5-fold cross-validation of a NaiveBayesUpdateable classifier; otherwise add up the instance weights as the branch error.
handleNumericAttribute(ins): same as above, for numeric attributes.
ResidualSplit.java: helper class for logistic model trees that implements the splitting criterion based on the residuals of the LogitBoost algorithm. See the wiki for details.
EntropyBasedSplitCrit.java:
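Entropy-based splitting criterion for decision trees. Main methods (logFunc(n) returns n * log2(n), so the probability P(i) that one would expect in front of the log is already folded in, and each method returns an entropy scaled by the total weight):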
- public final double oldEnt(Distribution bags) { // Computes entropy of distribution before splitting.
- double returnValue = 0;
- int j;
- for (j = 0; j < bags.numClasses(); j++)
- returnValue = returnValue + logFunc(bags.perClass(j)); // logFunc(n) = n * log(n) / log(2)
- return logFunc(bags.total()) - returnValue;
- }
- /**
- * Computes entropy of distribution after splitting.
- */
- public final double newEnt(Distribution bags) {
- double returnValue = 0;
- int i, j;
- for (i = 0; i < bags.numBags(); i++) {
- for (j = 0; j < bags.numClasses(); j++)
- returnValue = returnValue + logFunc(bags.perClassPerBag(i, j));
- returnValue = returnValue - logFunc(bags.perBag(i));
- }
- return -returnValue;
- }
- /**
- * Computes entropy after splitting without considering the class values.
- */
- public final double splitEnt(Distribution bags) {
- double returnValue = 0;
- int i;
- for (i = 0; i < bags.numBags(); i++)
- returnValue = returnValue + logFunc(bags.perBag(i));
- return logFunc(bags.total()) - returnValue;
- }
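The relation between these count-based formulas and the textbook entropy can be checked numerically. With N the total weight and n_j the weight of class j, N*log2(N) - sum_j n_j*log2(n_j) equals N times the Shannon entropy -sum_j (n_j/N)*log2(n_j/N). The sketch below assumes logFunc(n) = n * log2(n) with 0 for n near 0; EntropySketch and shannon are hypothetical names.

```java
/** Hypothetical check that oldEnt() is N times the Shannon entropy. */
class EntropySketch {
    /** n * log2(n), returning 0 for n close to 0. */
    static double logFunc(double num) {
        return num < 1e-6 ? 0 : num * Math.log(num) / Math.log(2);
    }

    /** Count-based entropy before splitting, as in oldEnt(). */
    static double oldEnt(double[] perClass) {
        double total = 0, sum = 0;
        for (double n : perClass) {
            total += n;
            sum += logFunc(n);
        }
        return logFunc(total) - sum;
    }

    /** Textbook Shannon entropy in bits, from the class probabilities. */
    static double shannon(double[] perClass) {
        double total = 0;
        for (double n : perClass) {
            total += n;
        }
        double h = 0;
        for (double n : perClass) {
            if (n > 0) {
                double p = n / total;
                h -= p * Math.log(p) / Math.log(2);
            }
        }
        return h;
    }
}
```

For class counts {3, 1}, oldEnt gives about 3.245, which is 4 times the Shannon entropy of about 0.811 bits. Working with counts avoids one division per class, and the common factor N cancels when gains are later normalised.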
Subclasses:
InfoGainSplitCrit.java: computes the information gain for a given distribution. Main methods:
splitCritValue(Distribution): computes IG directly and returns bags.total() / (oldEnt(bags) - newEnt(bags)); the reciprocal is taken because the splitting criterion's value is minimized.
splitCritValue(Distribution bags, double totalNoInst): the IG as used by C4.5:
- double numerator;
- double noUnknown;
- double unknownRate;
- int i;
- noUnknown = totalNoInst-bags.total(); // summed weight of the instances with a missing value
- unknownRate = noUnknown/totalNoInst;
- numerator = (oldEnt(bags)-newEnt(bags));
- numerator = (1-unknownRate)*numerator; // bags.total()/totalNoInst * numerator
- // Splits with no gain are useless.
- if (Utils.eq(numerator,0))
- return 0;
- return numerator/bags.total(); // equals (oldEnt - newEnt) / totalNoInst
GainRatioSplitCrit.java: computes the gain ratio for a given distribution. Main methods:
splitCritValue(dis): computes GR, using splitEnt() for the denominator.
splitEnt(Distribution bags, double totalnoInst): computes the split entropy of the attribute. Key code:
- noUnknown = totalnoInst - bags.total();
- if (Utils.gr(bags.total(), 0)) {
- for (i = 0; i < bags.numBags(); i++)
- returnValue = returnValue - logFunc(bags.perBag(i));
- returnValue = returnValue - logFunc(noUnknown);
- returnValue = returnValue + logFunc(totalnoInst);
- }
- return returnValue;
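Putting the pieces together, the gain ratio of a split can be computed from a bag-by-class table of weights. The sketch below assumes no missing values, so the unknown-weight terms vanish, and normalises both the gain and the split entropy by the total weight; GainRatioSketch and gainRatio are hypothetical names.

```java
/** Hypothetical gain ratio computation on a count matrix; not Weka code. */
class GainRatioSketch {
    /** n * log2(n), returning 0 for n close to 0. */
    static double logFunc(double num) {
        return num < 1e-6 ? 0 : num * Math.log(num) / Math.log(2);
    }

    /**
     * bags[i][j] is the weight of class j in subset i.
     * Returns info gain divided by split entropy, or 0 if either is 0.
     */
    static double gainRatio(double[][] bags) {
        int numClasses = bags[0].length;
        double[] perClass = new double[numClasses];
        double[] perBag = new double[bags.length];
        double total = 0;
        for (int i = 0; i < bags.length; i++) {
            for (int j = 0; j < numClasses; j++) {
                perBag[i] += bags[i][j];
                perClass[j] += bags[i][j];
                total += bags[i][j];
            }
        }

        double oldEnt = logFunc(total);
        for (double n : perClass) {
            oldEnt -= logFunc(n);
        }
        double newEnt = 0;
        double splitEnt = logFunc(total);
        for (int i = 0; i < bags.length; i++) {
            for (int j = 0; j < numClasses; j++) {
                newEnt -= logFunc(bags[i][j]);
            }
            newEnt += logFunc(perBag[i]);
            splitEnt -= logFunc(perBag[i]);
        }

        double infoGain = (oldEnt - newEnt) / total;
        if (Math.abs(infoGain) < 1e-9 || Math.abs(splitEnt) < 1e-9) {
            return 0; // useless split, or everything in one subset
        }
        return infoGain / (splitEnt / total);
    }
}
```

A perfect two-way split of four instances, {{2, 0}, {0, 2}}, has a gain of 1 bit and a split entropy of 1 bit, so its gain ratio is 1; a split that leaves both classes evenly mixed, {{1, 1}, {1, 1}}, scores 0.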
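Distribution.java:
Stores the weight distribution of the instances along four dimensions: per bag (one bag per attribute value), per class, per class and bag, and the total.
add(bagIdx)
prob(int); laplaceProb(int): the class probability without and with Laplace smoothing, the latter being (n_c + 1) / (total + numClasses)
shift(int, int, Instance): moves an instance to a new bag
subtract(Distribution): subtracts the per-class weights of the given distribution from this one.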
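The difference between the raw and the Laplace-smoothed probability is easy to see on a small example. This sketch works on a plain array of per-class weights; LaplaceSketch is a hypothetical class, not the Distribution API.

```java
/** Hypothetical illustration of prob() versus laplaceProb(). */
class LaplaceSketch {
    static double total(double[] perClass) {
        double sum = 0;
        for (double n : perClass) {
            sum += n;
        }
        return sum;
    }

    /** Relative frequency of the class, or 0 for an empty distribution. */
    static double prob(double[] perClass, int classIndex) {
        double total = total(perClass);
        return total == 0 ? 0 : perClass[classIndex] / total;
    }

    /** Laplace estimate: (n_c + 1) / (total + numClasses). */
    static double laplaceProb(double[] perClass, int classIndex) {
        return (perClass[classIndex] + 1)
            / (total(perClass) + perClass.length);
    }
}
```

For class weights {3, 1} the raw probability of the first class is 0.75, while the smoothed estimate is 4/6. For an empty leaf the smoothed estimate falls back to a uniform 1/numClasses instead of 0.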