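Hand-Rolling an ID3 Decision Tree in C++

Background

        Picking up where we left off: with KNN done, Xiao H kept going with machine learning and this time landed on decision trees. Building a tree to carry out a classification task is a wonderfully intuitive idea.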

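Concepts

        A decision tree is a supervised learning algorithm commonly used for classification. The ID3 algorithm picks the best feature to split on by computing information gain, and ends up with a tree in which each internal node represents a feature (attribute) and each leaf node represents a class. Information gain is a concept from information theory that measures how much the entropy of a dataset drops when it is split on a given feature.

        The core of ID3 is information gain: compute each feature's "contribution" to classifying the dataset, choose the feature with the largest contribution as the split for the current node, and repeat until every subset contains a single class or no further split is possible.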

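Key Concepts

  1. Entropy: a measure of how "mixed" a dataset is. The higher the entropy, the more mixed the data (the less clear-cut the classification).

  2. Conditional entropy: when dataset D is split on feature A, the weighted average entropy of the resulting subsets, with each subset weighted by its share of the samples.

  3. Information gain: the difference between the entropy of the original dataset and the conditional entropy after splitting on feature A. It measures how much A contributes to classification: the larger the gain, the more the split on A reduces the disorder, and the better suited A is as the split for the current node.

(The formulas for these three quantities are easy to look up.)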
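
For reference, the standard textbook definitions can be written as follows. The notation is an assumption of this sketch rather than the author's: D is the dataset, K is the number of classes, p_k is the fraction of samples in class k, and D_j is the subset of D in which feature A takes its j-th value (n values in total).

```latex
H(D) = -\sum_{k=1}^{K} p_k \log_2 p_k

H(D \mid A) = \sum_{j=1}^{n} \frac{|D_j|}{|D|} \, H(D_j)

\mathrm{Gain}(D, A) = H(D) - H(D \mid A)
```

With a base-2 logarithm the entropy is measured in bits, which matches the `std::log2` call used in the implementation.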

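ID3 Algorithm Steps

  1. Initialization: use all training samples as the dataset of the root node.
  2. Check the stopping conditions
    • If every sample in the current dataset belongs to the same class, mark the node as a leaf and return that class.
    • If no features are left to split on, mark the node as a leaf and return the most frequent class among the samples (majority vote).
  3. Choose the best feature
    • Compute the entropy H(D) of the current dataset.
    • For each unused feature A, compute its information gain Gain(D,A).
    • Choose the feature A with the largest information gain as the split feature for the current node.
  4. Split the dataset
    • Split the dataset into subsets D_j, one for each value of feature A.
    • Create a child node for each subset and apply steps 2-4 recursively until a stopping condition is met.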

Data Preparation

        Suppose we have a dataset with the columns Age, Income, Marital Status, and Label, and a handful of rows.
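
The snippets that follow refer to `Example` and `Dataset`. The two definitions below are the ones from the complete listing at the end of the article; `make_sample_rows` is a hypothetical helper added here only to show what a couple of rows look like.

```cpp
#include <string>
#include <vector>

// One row of the dataset: three features plus the class label.
struct Example {
    int age;
    std::string income;
    std::string marital_status;
    std::string label;
};

// A dataset is simply a list of rows.
using Dataset = std::vector<Example>;

// Hypothetical helper: builds a two-row dataset for illustration.
Dataset make_sample_rows() {
    return {
        {30, "High", "Single", "Class A"},
        {35, "Low", "Married", "Class B"},
    };
}
```

Age is stored as an integer here, so it has to be mapped to a discrete value before ID3 can split on it; the full listing does this inside `get_feature_value`.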

Data Structure

        We wrap the whole decision tree in a DecisionTree class.

The DecisionTree Class

class DecisionTree{
public:
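    // Node type: internal (split) node or leaf
    enum class NodeType{
        INTERNAL,
        LEAF
    };
    
    // Tree node
    struct Node{
        NodeType type;
        std::string feature;                                   // split feature (internal nodes)
        std::map<std::string, std::unique_ptr<Node>> children; // feature value -> subtree
        std::optional<std::string> label;                      // predicted class (leaf nodes)
    };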

    void fit(const Dataset& data, const std::vector<std::string>& features);
    std::string predict(const Example& example) const;

private:
    std::unique_ptr<Node> build_tree(const Dataset& data, const std::vector<std::string>& features);
    double entropy(const Dataset& data) const;
    double information_gain(const Dataset& data, const std::string& feature) const;
    std::string choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const;
    std::string get_feature_value(const Example& example, const std::string& feature) const;
    
    std::unique_ptr<Node> root_;
    std::string predict_helper(const Node* node, const Example& example) const;
};

Method Implementations

Computing Entropy

// Compute the entropy (in bits) of the labels in a dataset
double DecisionTree::entropy(const Dataset& data) const {
    if (data.empty()) return 0.0;

    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }

    double total = static_cast<double>(data.size());
    double entropy_value = 0.0;

    for (const auto& [label, count] : label_counts) {
        double p = static_cast<double>(count) / total;
        if (p > 0) { // avoid log(0)
            entropy_value -= p * std::log2(p);
        }
    }
    return entropy_value;
}
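
As a quick sanity check, the same formula can be applied to raw label counts. The sketch below uses a hypothetical free function `entropy_from_counts`, which is not part of the article's class. For counts of 3, 2 and 1, which is the label distribution of the six training rows in the full listing, it gives roughly 1.459 bits.

```cpp
#include <cmath>
#include <vector>

// Hypothetical helper: Shannon entropy (in bits) of a label distribution
// given as raw counts per class.
double entropy_from_counts(const std::vector<int>& counts) {
    double total = 0.0;
    for (int c : counts) total += c;
    if (total == 0.0) return 0.0;

    double h = 0.0;
    for (int c : counts) {
        if (c == 0) continue;   // a class with no samples contributes nothing
        double p = c / total;
        h -= p * std::log2(p);
    }
    return h;
}
```

A single-class distribution such as `{4}` yields 0, and an even two-way split such as `{2, 2}` yields exactly 1 bit.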

Computing Information Gain

// Compute the information gain of splitting the dataset on a feature
double DecisionTree::information_gain(const Dataset& data, const std::string& feature) const {
    double initial_entropy = entropy(data);
    double weighted_entropy = 0.0;

    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, feature);
        split_data[feature_value].push_back(ex);
    }

    for (const auto& [value, subset] : split_data) {
        double weight = static_cast<double>(subset.size()) / data.size();
        weighted_entropy += weight * entropy(subset);
    }

    return initial_entropy - weighted_entropy;
}
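
To make the weighting concrete, here is a minimal standalone sketch that computes the gain for one feature from (feature value, label) pairs. The names `Pair`, `entropy_of`, `gain_of` and `marital_rows` are hypothetical. Fed the marital status and label columns of the six training rows from the full listing, it gives about 1.126: the Married and Divorced subsets are pure, so only the two Single rows contribute any remaining entropy.

```cpp
#include <cmath>
#include <map>
#include <string>
#include <utility>
#include <vector>

using Pair = std::pair<std::string, std::string>; // (feature value, label)

// Entropy of the labels in a list of (feature value, label) pairs.
double entropy_of(const std::vector<Pair>& rows) {
    if (rows.empty()) return 0.0;
    std::map<std::string, int> counts;
    for (const auto& r : rows) counts[r.second]++;
    double h = 0.0;
    for (const auto& kv : counts) {
        double p = static_cast<double>(kv.second) / rows.size();
        h -= p * std::log2(p);
    }
    return h;
}

// Gain = H(D) minus the size-weighted entropy of each value's subset.
double gain_of(const std::vector<Pair>& rows) {
    std::map<std::string, std::vector<Pair>> groups;
    for (const auto& r : rows) groups[r.first].push_back(r);
    double conditional = 0.0;
    for (const auto& kv : groups) {
        double weight = static_cast<double>(kv.second.size()) / rows.size();
        conditional += weight * entropy_of(kv.second);
    }
    return entropy_of(rows) - conditional;
}

// Marital status and label columns of the six training rows.
std::vector<Pair> marital_rows() {
    return {{"Single", "Class A"}, {"Married", "Class B"},
            {"Divorced", "Class A"}, {"Single", "Class C"},
            {"Married", "Class B"}, {"Divorced", "Class A"}};
}
```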

Choosing the Best Feature

// Choose the feature with the largest information gain
std::string DecisionTree::choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const {
    double max_gain = -1.0;
    std::string best_feature;

    std::cout << "  Information gains:" << std::endl;
    for (const auto& feature : features) {
        double gain = information_gain(data, feature);
        std::cout << "    " << feature << ": " << gain << std::endl;
        if (gain > max_gain) {
            max_gain = gain;
            best_feature = feature;
        }
    }
    std::cout << "  Best feature: " << best_feature << " (gain: " << max_gain << ")" << std::endl;
    return best_feature;
}
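
Because the comparison is a strict `>`, a tie goes to whichever feature comes first in `features`. The sketch below isolates just that selection rule, using a hypothetical `pick_best` over precomputed gains instead of the article's member function.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: returns the name with the largest gain. On a tie the
// earliest entry wins because the comparison is strict.
std::string pick_best(const std::vector<std::pair<std::string, double>>& gains) {
    double max_gain = -1.0;   // below any valid gain, so the first feature always qualifies
    std::string best;
    for (const auto& g : gains) {
        if (g.second > max_gain) {
            max_gain = g.second;
            best = g.first;
        }
    }
    return best;
}
```

With a gain of 1.0 for both `income` and `age`, which is what happens in the Single branch of the sample data, `pick_best` returns `income` when it is listed first.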

Building the Decision Tree

// Recursively build the decision tree
std::unique_ptr<DecisionTree::Node> DecisionTree::build_tree(const Dataset& data, const std::vector<std::string>& features) {
    if (data.empty()) {
        return nullptr;
    }

    // Count the labels
    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }

    // If all samples belong to the same class, create a leaf node
    if (label_counts.size() == 1) {
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = label_counts.begin()->first;
        return node;
    }

    // If no features are left, use the most common label (majority vote)
    if (features.empty()) {
        auto most_common_label = std::max_element(label_counts.begin(), label_counts.end(),
            [](const auto& a, const auto& b) {
                return a.second < b.second;
            });
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = most_common_label->first;
        return node;
    }

    // Choose the best feature
    std::string best_feature = choose_best_feature(data, features);
    auto node = std::make_unique<Node>();
    node->type = NodeType::INTERNAL;
    node->feature = best_feature;

    // Debug output: show the chosen feature
    std::cout << "Selected best feature: " << best_feature << " for " << data.size() << " samples" << std::endl;

    // Split the data by the best feature
    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, best_feature);
        split_data[feature_value].push_back(ex);
    }

    // Build the list of remaining features
    std::vector<std::string> remaining_features;
    for (const auto& feature : features) {
        if (feature != best_feature) {
            remaining_features.push_back(feature);
        }
    }

    // Recursively build one subtree per feature value
    for (const auto& [value, subset] : split_data) {
        node->children[value] = build_tree(subset, remaining_features);
    }

    return node;
}
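
The debug prints only show which feature was chosen at each step. To inspect the finished tree, something like the following could be used. It is a sketch, not part of the article's class: `render_tree` is a hypothetical free function, and `Node` here is a simplified copy of `DecisionTree::Node` that uses an `is_leaf` flag in place of the `NodeType` enum.

```cpp
#include <cstddef>
#include <map>
#include <memory>
#include <string>

// Simplified stand-in for DecisionTree::Node (an assumption of this sketch).
struct Node {
    bool is_leaf = false;
    std::string feature;                                   // internal nodes
    std::string label;                                     // leaf nodes
    std::map<std::string, std::unique_ptr<Node>> children; // keyed by feature value
};

// Hypothetical helper: renders the tree as indented text, one line per branch
// and one line per leaf.
std::string render_tree(const Node& node, int depth = 0) {
    std::string indent(static_cast<std::size_t>(depth) * 2, ' ');
    if (node.is_leaf) {
        return indent + "-> " + node.label + "\n";
    }
    std::string out;
    for (const auto& kv : node.children) {
        out += indent + node.feature + " = " + kv.first + "\n";
        if (kv.second) out += render_tree(*kv.second, depth + 1);
    }
    return out;
}
```

Since `children` is a `std::map`, branches come out sorted by feature value, so the rendering is stable from run to run.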

Prediction

// Predict the class of an example
std::string DecisionTree::predict(const Example& example) const {
    if (!root_) {
        throw std::runtime_error("Decision tree has not been trained");
    }
    return predict_helper(root_.get(), example);
}

// Recursive helper: walk down the tree following the example's feature values
std::string DecisionTree::predict_helper(const Node* node, const Example& example) const {
    if (node->type == NodeType::LEAF) {
        return node->label.value();
    }

    const std::string& feature = node->feature;
    std::string feature_value = get_feature_value(example, feature);

    auto it = node->children.find(feature_value);
    if (it == node->children.end()) {
        throw std::runtime_error("Feature value '" + feature_value + "' not found in tree for feature '" + feature + "'");
    }

    return predict_helper(it->second.get(), example);
}
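
Throwing on an unseen feature value is the strictest option. One common alternative, sketched here as an assumption rather than the author's design, is to store the majority label of the training samples at each internal node and fall back to it. `Node`, `majority_label` and `predict_with_fallback` are hypothetical names, and the example is described by a map from feature name to value instead of the article's `Example` struct.

```cpp
#include <map>
#include <memory>
#include <string>

// Simplified node for this sketch only.
struct Node {
    bool is_leaf = false;
    std::string feature;
    std::string label;           // set on leaves
    std::string majority_label;  // hypothetical: most common label seen at this node
    std::map<std::string, std::unique_ptr<Node>> children;
};

// Walks the tree. If the example's value has no matching branch, or the
// feature is missing from the example, the node's majority label is returned
// instead of throwing.
std::string predict_with_fallback(const Node& node,
                                  const std::map<std::string, std::string>& example) {
    if (node.is_leaf) return node.label;
    auto value = example.find(node.feature);
    if (value == example.end()) return node.majority_label;
    auto child = node.children.find(value->second);
    if (child == node.children.end() || !child->second) return node.majority_label;
    return predict_with_fallback(*child->second, example);
}
```

The trade-off is that a fallback hides inputs the tree never saw during training, whereas the exception in `predict_helper` makes them visible immediately.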

Code

        The core functionality is covered above. Below is the complete code, including the helper functions and a simple test in main:

#include <iostream>
#include <string>
#include <vector>
#include <map>
#include <optional>
#include <unordered_map>
#include <algorithm>
#include <cmath>
#include <memory>
#include <cassert>
#include <stdexcept> // std::runtime_error

struct Example{
    int age;
    std::string income;
    std::string marital_status;
    std::string label;
};

using Dataset = std::vector<Example>;

class DecisionTree{
public:
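    // Node type: internal (split) node or leaf
    enum class NodeType{
        INTERNAL,
        LEAF
    };
    
    // Tree node
    struct Node{
        NodeType type;
        std::string feature;                                   // split feature (internal nodes)
        std::map<std::string, std::unique_ptr<Node>> children; // feature value -> subtree
        std::optional<std::string> label;                      // predicted class (leaf nodes)
    };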

    void fit(const Dataset& data, const std::vector<std::string>& features);
    std::string predict(const Example& example) const;

private:
    std::unique_ptr<Node> build_tree(const Dataset& data, const std::vector<std::string>& features);
    double entropy(const Dataset& data) const;
    double information_gain(const Dataset& data, const std::string& feature) const;
    std::string choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const;
    std::string get_feature_value(const Example& example, const std::string& feature) const;
    
    std::unique_ptr<Node> root_;
    std::string predict_helper(const Node* node, const Example& example) const;
};

// Compute the entropy (in bits) of the labels in a dataset
double DecisionTree::entropy(const Dataset& data) const{
    if(data.empty()) return 0.0;
    
    std::unordered_map<std::string,int> label_counts;
    for(const auto& ex : data){
        label_counts[ex.label]++;
    }
    
    double total = static_cast<double>(data.size());
    double entropy_value = 0.0;

    for(const auto& [label, count] : label_counts){
        double p = static_cast<double>(count) / total;
        if(p > 0) { // avoid log(0)
            entropy_value -= p * std::log2(p);
        }
    }
    return entropy_value;
}

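// Helper: return the value of the named feature for an example.
// Age is discretized into three buckets: young (<30), middle (30-49), old (>=50).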
std::string DecisionTree::get_feature_value(const Example& example, const std::string& feature) const{
    if(feature == "income"){
        return example.income;
    }
    else if(feature == "marital_status"){
        return example.marital_status;
    }
    else if(feature == "age"){
        if(example.age < 30){
            return "young";
        }
        else if(example.age >= 30 && example.age < 50){
            return "middle";
        }
        else{
            return "old";
        }
    }
    else{
        throw std::runtime_error("Unknown feature: " + feature);
    }
}

// Compute the information gain of splitting the dataset on a feature
double DecisionTree::information_gain(const Dataset& data, const std::string& feature) const{
    double initial_entropy = entropy(data);
    double weighted_entropy = 0.0;

    std::map<std::string, Dataset> split_data;
    for(const auto& ex : data){
        std::string feature_value = get_feature_value(ex, feature);
        split_data[feature_value].push_back(ex);
    }

    for(const auto& [value, subset] : split_data){
        double weight = static_cast<double>(subset.size()) / data.size();
        weighted_entropy += weight * entropy(subset);
    }

    return initial_entropy - weighted_entropy;
}

// Choose the feature with the largest information gain
std::string DecisionTree::choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const{
    double max_gain = -1.0;
    std::string best_feature;

    std::cout << "  Information gains:" << std::endl;
    for(const auto& feature : features){
        double gain = information_gain(data, feature);
        std::cout << "    " << feature << ": " << gain << std::endl;
        if(gain > max_gain){
            max_gain = gain;
            best_feature = feature;
        }
    }
    std::cout << "  Best feature: " << best_feature << " (gain: " << max_gain << ")" << std::endl;
    return best_feature;
}

std::unique_ptr<DecisionTree::Node> DecisionTree::build_tree(const Dataset& data, const std::vector<std::string>& features){
    if(data.empty()){
        return nullptr;
    }
    
    // Count the labels
    std::unordered_map<std::string, int> label_counts;
    for(const auto& ex : data){
        label_counts[ex.label]++;
    }
    
    // If all samples belong to the same class, create a leaf node
    if(label_counts.size() == 1){
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = label_counts.begin()->first;
        return node;
    }

    // If no features are left, use the most common label (majority vote)
    if(features.empty()){
        auto most_common_label = std::max_element(label_counts.begin(), label_counts.end(),
            [](const auto& a, const auto& b){
                return a.second < b.second;
            });
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = most_common_label->first;
        return node;
    }

    // Choose the best feature
    std::string best_feature = choose_best_feature(data, features);
    auto node = std::make_unique<Node>();
    node->type = NodeType::INTERNAL;
    node->feature = best_feature;
    
    // Debug output: show the chosen feature
    std::cout << "Selected best feature: " << best_feature << " for " << data.size() << " samples" << std::endl;

    // Split the data by the best feature
    std::map<std::string, Dataset> split_data;
    for(const auto& ex : data){
        std::string feature_value = get_feature_value(ex, best_feature);
        split_data[feature_value].push_back(ex);
    }
    
    // Build the list of remaining features
    std::vector<std::string> remaining_features;
    for(const auto& feature : features){
        if(feature != best_feature){
            remaining_features.push_back(feature);
        }
    }
    
    // Recursively build one subtree per feature value
    for(const auto& [value, subset] : split_data){
        node->children[value] = build_tree(subset, remaining_features);
    }
    
    return node;
}

void DecisionTree::fit(const Dataset& data, const std::vector<std::string>& features){
    root_ = build_tree(data, features);
}

std::string DecisionTree::predict(const Example& example) const{
    if(!root_){
        throw std::runtime_error("Decision tree has not been trained");
    }
    return predict_helper(root_.get(), example);
}

std::string DecisionTree::predict_helper(const Node* node, const Example& example) const{
    if(node->type == NodeType::LEAF){
        return node->label.value();
    }

    const std::string& feature = node->feature;
    std::string feature_value = get_feature_value(example, feature);
    
    auto it = node->children.find(feature_value);
    if (it == node->children.end()) {
        throw std::runtime_error("Feature value '" + feature_value + "' not found in tree for feature '" + feature + "'");
    }
 
    return predict_helper(it->second.get(), example);
}






int main() {
    try {
        Dataset data = {
            {30, "High", "Single", "Class A"},
            {35, "Low", "Married", "Class B"},
            {40, "Medium", "Divorced", "Class A"},
            {25, "Low", "Single", "Class C"},
            {50, "High", "Married", "Class B"},
            {45, "Low", "Divorced", "Class A"}
        };
     
        std::vector<std::string> features = {"income", "marital_status", "age"};
     
        // Train the decision tree
        std::cout << "Training decision tree..." << std::endl;
        DecisionTree dt;
        dt.fit(data, features);
        std::cout << "Training completed!" << std::endl;
     
        // Predict a new sample
        Example new_example = {30, "High", "Single", ""};
        std::cout << "Predicting for example: Age=" << new_example.age 
                  << ", Income=" << new_example.income 
                  << ", Marital Status=" << new_example.marital_status << std::endl;
        
        std::string prediction = dt.predict(new_example);
        std::cout << "Prediction: " << prediction << std::endl;
        
        // Test a few more samples
        std::vector<Example> test_examples = {
            {25, "Low", "Single", ""},
            {45, "High", "Married", ""},
            {35, "Medium", "Divorced", ""}
        };
        
        std::cout << "\nTesting additional examples:" << std::endl;
        for(const auto& example : test_examples){
            std::string pred = dt.predict(example);
            std::cout << "Age=" << example.age << ", Income=" << example.income 
                      << ", Marital=" << example.marital_status << " -> " << pred << std::endl;
        }
        
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    
    return 0;
}

  

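Closing Remarks

        That wraps up this chapter. Xiao H is about to kick off a happy weekend. If you spot any problems in the code, feel free to contact the author.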
