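Background
Picking up where the last post left off: after finishing KNN, Xiao H kept studying machine learning, and this time the topic is decision trees. Building a tree to carry out classification is a very intuitive idea.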
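Concepts
A decision tree is a supervised learning algorithm commonly used for classification. The ID3 algorithm picks the best feature to split on by computing information gain, and the result is a tree in which each internal node represents a feature (attribute) and each leaf node represents a class. Information gain is a concept from information theory that measures how much the uncertainty about the class drops when the dataset is split on a given feature.
The core of ID3 is information gain: for each feature it measures that feature's "contribution" to classifying the dataset, chooses the feature with the largest contribution as the split criterion for the current node, and repeats this until all samples at a node belong to the same class or no features remain to split on.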
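Key Concepts
- Entropy: a measure of how "mixed" a dataset is. The higher the entropy, the more mixed the data (and the less clear-cut the classification).
- Conditional entropy: when dataset D is split by feature A, the average entropy of the resulting subsets, each weighted by its share of the samples, is the conditional entropy.
- Information gain: the difference between the entropy of the original dataset and the conditional entropy after splitting by feature A. It measures how much A contributes to classification. The larger the gain, the more splitting on A reduces the "disorder" of the data, and the better A is as the split criterion for the current node.
(The formulas for these quantities are easy to look up.)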
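For reference, these are the standard textbook definitions. They are written for a dataset $D$ whose samples fall into $K$ classes $C_1, \dots, C_K$, and a discrete feature $A$ whose $V$ values partition $D$ into subsets $D_1, \dots, D_V$:

```latex
H(D) = -\sum_{k=1}^{K} \frac{|C_k|}{|D|} \log_2 \frac{|C_k|}{|D|}

H(D \mid A) = \sum_{v=1}^{V} \frac{|D_v|}{|D|} \, H(D_v)

\mathrm{Gain}(D, A) = H(D) - H(D \mid A)
```

These match what entropy and information_gain compute in the code below. The first is a base-2 sum over label frequencies. The third is the entropy of D minus the size-weighted average entropy of the subsets.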
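ID3 Algorithm Steps
- Initialization: use all training samples as the dataset of the root node.
- Check the termination conditions:
  - If all samples in the current dataset belong to the same class, mark the node as a leaf and return that class.
  - If no features remain to split on, mark the node as a leaf and return the most frequent class among its samples (majority vote).
- Select the best feature:
  - Compute the entropy H(D) of the current dataset.
  - For each unused feature A, compute its information gain Gain(D,A).
  - Choose the feature A with the largest information gain as the split feature for the current node.
- Split the dataset:
  - Split the dataset into subsets by the values of feature A (one subset per value).
  - Create a child node for each subset, and recursively repeat the termination check, feature selection, and splitting steps until a termination condition is met.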
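Data Preparation
Suppose we have a dataset with the columns Age, Income, Marital Status, and Label, plus a handful of rows. ID3 as described above works on discrete feature values. For that reason, the code buckets the numeric age into young (under 30), middle (30 to 49), and old (50 and above). Income and marital status are already categorical.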
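To make the formulas concrete, here is a minimal standalone sketch that is not part of the original code. It computes the root-level information gain of each feature on the six training rows used in main() of the full program. The names Row, age_bucket, label_entropy, gain, and sample_rows are hypothetical, and age_bucket copies the bucketing used by get_feature_value.

```cpp
#include <cassert>
#include <cmath>
#include <map>
#include <string>
#include <vector>

// Hypothetical row type with the same layout as Example in the full code
struct Row {
    int age;
    std::string income;
    std::string marital_status;
    std::string label;
};

// Same age buckets as get_feature_value in the full code
std::string age_bucket(int age) {
    if (age < 30) return "young";
    if (age < 50) return "middle";
    return "old";
}

// Base-2 entropy of the labels in rows (0 for an empty set)
double label_entropy(const std::vector<Row>& rows) {
    std::map<std::string, int> counts;
    for (const auto& r : rows) counts[r.label]++;
    double h = 0.0;
    for (const auto& [label, n] : counts) {
        double p = static_cast<double>(n) / rows.size();
        h -= p * std::log2(p);
    }
    return h;
}

// Information gain of splitting rows by the value key(row)
template <typename KeyFn>
double gain(const std::vector<Row>& rows, KeyFn key) {
    std::map<std::string, std::vector<Row>> parts;
    for (const auto& r : rows) parts[key(r)].push_back(r);
    double conditional = 0.0;  // H(D | A): size-weighted subset entropies
    for (const auto& [value, part] : parts) {
        conditional += static_cast<double>(part.size()) / rows.size() * label_entropy(part);
    }
    return label_entropy(rows) - conditional;
}

// The six training rows from main() in the full code
std::vector<Row> sample_rows() {
    return {
        {30, "High", "Single", "Class A"},
        {35, "Low", "Married", "Class B"},
        {40, "Medium", "Divorced", "Class A"},
        {25, "Low", "Single", "Class C"},
        {50, "High", "Married", "Class B"},
        {45, "Low", "Divorced", "Class A"},
    };
}
```

Calling gain(sample_rows(), ...) with keys for income, marital_status, and age_bucket(age) gives roughly 0.333, 1.126, and 0.918. So marital_status has the largest gain and becomes the root split, which matches the "Best feature" line the full program prints for the root.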
Data Structures
We wrap the entire decision tree in a DecisionTree class. The Example and Dataset types it uses are defined in the complete code at the end of this post.
The DecisionTree Class
```cpp
class DecisionTree {
public:
    // Node type
    enum class NodeType {
        INTERNAL,
        LEAF
    };

    // Tree node
    struct Node {
        NodeType type;
        std::string feature;                                    // split feature (internal nodes)
        std::map<std::string, std::unique_ptr<Node>> children;  // feature value -> subtree
        std::optional<std::string> label;                       // class label (leaf nodes)
    };

    void fit(const Dataset& data, const std::vector<std::string>& features);
    std::string predict(const Example& example) const;

private:
    std::unique_ptr<Node> build_tree(const Dataset& data, const std::vector<std::string>& features);
    double entropy(const Dataset& data) const;
    double information_gain(const Dataset& data, const std::string& feature) const;
    std::string choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const;
    std::string get_feature_value(const Example& example, const std::string& feature) const;
    std::unique_ptr<Node> root_;
    std::string predict_helper(const Node* node, const Example& example) const;
};
```
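To see how type, feature, children, and label fit together, the sketch below hand-builds and prints the tree that the full program builds from its six training rows. The root splits on marital_status. Under Single there is an income split, because income and age tie with a gain of 1 there and income comes first in the feature list. This sketch redeclares the node type outside the class so it compiles on its own. The helpers make_leaf, make_internal, print_tree, and example_tree are hypothetical and not part of the original code.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <utility>

// Mirrors DecisionTree::NodeType / DecisionTree::Node from the class above
enum class NodeType { INTERNAL, LEAF };

struct Node {
    NodeType type;
    std::string feature;
    std::map<std::string, std::unique_ptr<Node>> children;
    std::optional<std::string> label;
};

std::unique_ptr<Node> make_leaf(const std::string& label) {
    auto node = std::make_unique<Node>();
    node->type = NodeType::LEAF;
    node->label = label;
    return node;
}

std::unique_ptr<Node> make_internal(const std::string& feature) {
    auto node = std::make_unique<Node>();
    node->type = NodeType::INTERNAL;
    node->feature = feature;
    return node;
}

// Prints one line per edge, indenting two spaces per level;
// leaf children are shown as "-> label"
void print_tree(const Node& node, std::ostream& out, int depth = 0) {
    const std::string indent(depth * 2, ' ');
    if (node.type == NodeType::LEAF) {  // only reached when the root itself is a leaf
        out << indent << "-> " << node.label.value() << '\n';
        return;
    }
    for (const auto& [value, child] : node.children) {  // std::map: values in sorted order
        out << indent << node.feature << " = " << value;
        if (child->type == NodeType::LEAF) {
            out << " -> " << child->label.value() << '\n';
        } else {
            out << '\n';
            print_tree(*child, out, depth + 1);
        }
    }
}

// Hand-built copy of the tree learned from the sample data in main()
std::unique_ptr<Node> example_tree() {
    auto by_income = make_internal("income");
    by_income->children["High"] = make_leaf("Class A");
    by_income->children["Low"] = make_leaf("Class C");

    auto root = make_internal("marital_status");
    root->children["Divorced"] = make_leaf("Class A");
    root->children["Married"] = make_leaf("Class B");
    root->children["Single"] = std::move(by_income);
    return root;
}
```

Printing example_tree() to a std::ostringstream produces five lines. The Divorced and Married branches end in leaves right away, while the Single branch continues with an indented income split into High -> Class A and Low -> Class C.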
Method Implementations
Computing Entropy
```cpp
// Compute the entropy of the label distribution in data
double DecisionTree::entropy(const Dataset& data) const {
    if (data.empty()) return 0.0;

    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }

    double total = static_cast<double>(data.size());
    double entropy_value = 0.0;
    for (const auto& [label, count] : label_counts) {
        double p = static_cast<double>(count) / total;
        if (p > 0) {  // guard against log(0)
            entropy_value -= p * std::log2(p);
        }
    }
    return entropy_value;
}
```
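A quick way to sanity-check the entropy code is to compare it against cases with known answers. A pure set has entropy 0, two equally likely classes give 1 bit, and four equally likely classes give 2 bits. The standalone entropy_of below is a hypothetical name. It repeats the same computation over a plain list of labels so it can run without the class.

```cpp
#include <cassert>
#include <cmath>
#include <string>
#include <unordered_map>
#include <vector>

// Same computation as DecisionTree::entropy, over a plain list of labels
double entropy_of(const std::vector<std::string>& labels) {
    if (labels.empty()) return 0.0;
    std::unordered_map<std::string, int> counts;
    for (const auto& l : labels) counts[l]++;
    double h = 0.0;
    for (const auto& [label, n] : counts) {
        double p = static_cast<double>(n) / labels.size();
        h -= p * std::log2(p);
    }
    return h;
}
```

In the sample data, for instance, the "middle" age bucket holds three Class A rows and one Class B row. Its entropy is 2 - 0.75·log2(3), or about 0.811 bits.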
Computing Information Gain
```cpp
// Compute the information gain of splitting data on feature
double DecisionTree::information_gain(const Dataset& data, const std::string& feature) const {
    double initial_entropy = entropy(data);
    double weighted_entropy = 0.0;

    // Group the samples by their value of this feature
    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, feature);
        split_data[feature_value].push_back(ex);
    }

    // Conditional entropy: subset entropies weighted by subset size
    for (const auto& [value, subset] : split_data) {
        double weight = static_cast<double>(subset.size()) / data.size();
        weighted_entropy += weight * entropy(subset);
    }
    return initial_entropy - weighted_entropy;
}
```
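Information gain has a property worth seeing concretely. If a feature puts every sample in its own subset, the conditional entropy drops to 0, so the feature's gain equals H(D), the largest value possible. All six ages in the sample data are distinct. Splitting on the raw age would therefore score about 1.459 and beat marital_status (about 1.126), even though such a split only memorizes the rows. This is one reason bucketing a numeric feature such as age matters, as get_feature_value does in the full code. The sketch below compares the two. The names kRows, entropy_of, and gain_by_age are hypothetical and not part of the original code.

```cpp
#include <cassert>
#include <cmath>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// (age, label) pairs taken from the six training rows in main()
const std::vector<std::pair<int, std::string>> kRows = {
    {30, "Class A"}, {35, "Class B"}, {40, "Class A"},
    {25, "Class C"}, {50, "Class B"}, {45, "Class A"}};

// Base-2 entropy of a list of labels
double entropy_of(const std::vector<std::string>& labels) {
    if (labels.empty()) return 0.0;
    std::unordered_map<std::string, int> counts;
    for (const auto& l : labels) counts[l]++;
    double h = 0.0;
    for (const auto& [label, n] : counts) {
        double p = static_cast<double>(n) / labels.size();
        h -= p * std::log2(p);
    }
    return h;
}

// Information gain of grouping kRows by key(age)
template <typename KeyFn>
double gain_by_age(KeyFn key) {
    std::vector<std::string> all;
    std::map<std::string, std::vector<std::string>> parts;
    for (const auto& [age, label] : kRows) {
        all.push_back(label);
        parts[key(age)].push_back(label);
    }
    double conditional = 0.0;
    for (const auto& [value, labels] : parts) {
        conditional += static_cast<double>(labels.size()) / all.size() * entropy_of(labels);
    }
    return entropy_of(all) - conditional;
}
```

Passing a key such as std::to_string(age) gives every row its own subset. A young/middle/old key reproduces the bucketed gain of about 0.918.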
Choosing the Best Feature
```cpp
// Choose the feature with the highest information gain
// (on a tie, the strict '>' keeps the feature that appears first in features)
std::string DecisionTree::choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const {
    double max_gain = -1.0;
    std::string best_feature;

    std::cout << "  Information gains:" << std::endl;
    for (const auto& feature : features) {
        double gain = information_gain(data, feature);
        std::cout << "    " << feature << ": " << gain << std::endl;
        if (gain > max_gain) {
            max_gain = gain;
            best_feature = feature;
        }
    }
    std::cout << "  Best feature: " << best_feature << " (gain: " << max_gain << ")" << std::endl;
    return best_feature;
}
```
Building the Tree
```cpp
// Recursively build the decision tree
std::unique_ptr<DecisionTree::Node> DecisionTree::build_tree(const Dataset& data, const std::vector<std::string>& features) {
    if (data.empty()) {
        return nullptr;
    }

    // Count the labels
    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }

    // If all samples belong to the same class, create a leaf node
    if (label_counts.size() == 1) {
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = label_counts.begin()->first;
        return node;
    }

    // If no features remain, use the most common label (majority vote)
    if (features.empty()) {
        auto most_common_label = std::max_element(label_counts.begin(), label_counts.end(),
            [](const auto& a, const auto& b) {
                return a.second < b.second;
            });
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = most_common_label->first;
        return node;
    }

    // Choose the best feature
    std::string best_feature = choose_best_feature(data, features);
    auto node = std::make_unique<Node>();
    node->type = NodeType::INTERNAL;
    node->feature = best_feature;

    // Debug output: show the selected feature
    std::cout << "Selected best feature: " << best_feature << " for " << data.size() << " samples" << std::endl;

    // Split the data by the best feature
    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, best_feature);
        split_data[feature_value].push_back(ex);
    }

    // Build the list of remaining features
    std::vector<std::string> remaining_features;
    for (const auto& feature : features) {
        if (feature != best_feature) {
            remaining_features.push_back(feature);
        }
    }

    // Recursively build the subtrees
    for (const auto& [value, subset] : split_data) {
        node->children[value] = build_tree(subset, remaining_features);
    }
    return node;
}
```
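In the no-features-left branch, std::max_element returns the first maximum in the container's iteration order. For an std::unordered_map, that order is unspecified, so a tie between labels can resolve differently across standard library implementations. If reproducible trees matter, one option is a sketch like majority_label below, a hypothetical helper and not part of the original code. It counts with std::map, so ties always go to the lexicographically smallest label.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Majority vote with a deterministic tie-break: std::map iterates keys in
// sorted order, and the strict '>' keeps the first maximum it sees, so ties
// go to the lexicographically smallest label.
std::string majority_label(const std::vector<std::string>& labels) {
    std::map<std::string, int> counts;
    for (const auto& l : labels) counts[l]++;
    std::string best;
    int best_count = 0;
    for (const auto& [label, n] : counts) {
        if (n > best_count) {
            best = label;
            best_count = n;
        }
    }
    return best;  // empty string if labels is empty
}
```

Inside build_tree, the same idea would mean counting into a std::map instead of label_counts' std::unordered_map before picking the leaf label.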
Prediction
```cpp
// Predict the class of an example
std::string DecisionTree::predict(const Example& example) const {
    if (!root_) {
        throw std::runtime_error("Decision tree has not been trained");
    }
    return predict_helper(root_.get(), example);
}

// Recursive helper: follow the branch matching the example's feature value
std::string DecisionTree::predict_helper(const Node* node, const Example& example) const {
    if (node->type == NodeType::LEAF) {
        return node->label.value();
    }
    const std::string& feature = node->feature;
    std::string feature_value = get_feature_value(example, feature);
    auto it = node->children.find(feature_value);
    if (it == node->children.end()) {
        throw std::runtime_error("Feature value '" + feature_value + "' not found in tree for feature '" + feature + "'");
    }
    return predict_helper(it->second.get(), example);
}
```
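As written, predict_helper throws whenever a test sample has a feature value that never reached that node during training. With the sample data, for instance, an example with Income=Medium and Marital Status=Single would follow the Single branch, whose income split only knows High and Low. A common remedy is to fall back to the majority class of the training samples at the node where the lookup fails. The sketch below assumes a simplified node that stores such a label on every node, in a hypothetical `majority` field. Sample and predict_with_fallback are also hypothetical. To adopt this, build_tree in the original class would need to fill the field in from label_counts.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Simplified node: every node (internal or leaf) stores the majority label of
// the training samples that reached it. `majority` is a hypothetical field;
// the original Node only stores `label` on leaves.
struct Node {
    std::string feature;                                    // empty for leaves
    std::map<std::string, std::unique_ptr<Node>> children;  // feature value -> subtree
    std::string majority;                                   // majority label at this node
};

// Hypothetical sample representation: feature name -> discretized value
using Sample = std::map<std::string, std::string>;

std::string predict_with_fallback(const Node* node, const Sample& sample) {
    while (!node->children.empty()) {
        auto value = sample.find(node->feature);
        if (value == sample.end()) return node->majority;          // feature missing
        auto child = node->children.find(value->second);
        if (child == node->children.end()) return node->majority;  // value unseen in training
        node = child->second.get();
    }
    return node->majority;  // reached a leaf
}
```

The trade-off is that an unseen value now yields a best guess instead of an error. Whether that is preferable depends on whether callers need to know the tree had no matching branch.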
Code
Those are the core pieces. Below is the complete code, including the helper functions and a simple test in main:
```cpp
#include <algorithm>
#include <cmath>
#include <iostream>
#include <map>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

struct Example {
    int age;
    std::string income;
    std::string marital_status;
    std::string label;
};

using Dataset = std::vector<Example>;

class DecisionTree {
public:
    // Node type
    enum class NodeType {
        INTERNAL,
        LEAF
    };

    // Tree node
    struct Node {
        NodeType type;
        std::string feature;                                    // split feature (internal nodes)
        std::map<std::string, std::unique_ptr<Node>> children;  // feature value -> subtree
        std::optional<std::string> label;                       // class label (leaf nodes)
    };

    void fit(const Dataset& data, const std::vector<std::string>& features);
    std::string predict(const Example& example) const;

private:
    std::unique_ptr<Node> build_tree(const Dataset& data, const std::vector<std::string>& features);
    double entropy(const Dataset& data) const;
    double information_gain(const Dataset& data, const std::string& feature) const;
    std::string choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const;
    std::string get_feature_value(const Example& example, const std::string& feature) const;
    std::unique_ptr<Node> root_;
    std::string predict_helper(const Node* node, const Example& example) const;
};

// Compute the entropy of the label distribution in data
double DecisionTree::entropy(const Dataset& data) const {
    if (data.empty()) return 0.0;
    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }
    double total = static_cast<double>(data.size());
    double entropy_value = 0.0;
    for (const auto& [label, count] : label_counts) {
        double p = static_cast<double>(count) / total;
        if (p > 0) {  // guard against log(0)
            entropy_value -= p * std::log2(p);
        }
    }
    return entropy_value;
}

// Helper: return the (discrete) value of a feature for an example
std::string DecisionTree::get_feature_value(const Example& example, const std::string& feature) const {
    if (feature == "income") {
        return example.income;
    } else if (feature == "marital_status") {
        return example.marital_status;
    } else if (feature == "age") {
        // Bucket the numeric age into discrete categories
        if (example.age < 30) {
            return "young";
        } else if (example.age < 50) {
            return "middle";
        } else {
            return "old";
        }
    } else {
        throw std::runtime_error("Unknown feature: " + feature);
    }
}

// Compute the information gain of splitting data on feature
double DecisionTree::information_gain(const Dataset& data, const std::string& feature) const {
    double initial_entropy = entropy(data);
    double weighted_entropy = 0.0;
    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, feature);
        split_data[feature_value].push_back(ex);
    }
    for (const auto& [value, subset] : split_data) {
        double weight = static_cast<double>(subset.size()) / data.size();
        weighted_entropy += weight * entropy(subset);
    }
    return initial_entropy - weighted_entropy;
}

// Choose the feature with the highest information gain
std::string DecisionTree::choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const {
    double max_gain = -1.0;
    std::string best_feature;
    std::cout << "  Information gains:" << std::endl;
    for (const auto& feature : features) {
        double gain = information_gain(data, feature);
        std::cout << "    " << feature << ": " << gain << std::endl;
        if (gain > max_gain) {
            max_gain = gain;
            best_feature = feature;
        }
    }
    std::cout << "  Best feature: " << best_feature << " (gain: " << max_gain << ")" << std::endl;
    return best_feature;
}

// Recursively build the decision tree
std::unique_ptr<DecisionTree::Node> DecisionTree::build_tree(const Dataset& data, const std::vector<std::string>& features) {
    if (data.empty()) {
        return nullptr;
    }
    // Count the labels
    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }
    // If all samples belong to the same class, create a leaf node
    if (label_counts.size() == 1) {
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = label_counts.begin()->first;
        return node;
    }
    // If no features remain, use the most common label (majority vote)
    if (features.empty()) {
        auto most_common_label = std::max_element(label_counts.begin(), label_counts.end(),
            [](const auto& a, const auto& b) {
                return a.second < b.second;
            });
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = most_common_label->first;
        return node;
    }
    // Choose the best feature
    std::string best_feature = choose_best_feature(data, features);
    auto node = std::make_unique<Node>();
    node->type = NodeType::INTERNAL;
    node->feature = best_feature;
    // Debug output: show the selected feature
    std::cout << "Selected best feature: " << best_feature << " for " << data.size() << " samples" << std::endl;
    // Split the data by the best feature
    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, best_feature);
        split_data[feature_value].push_back(ex);
    }
    // Build the list of remaining features
    std::vector<std::string> remaining_features;
    for (const auto& feature : features) {
        if (feature != best_feature) {
            remaining_features.push_back(feature);
        }
    }
    // Recursively build the subtrees
    for (const auto& [value, subset] : split_data) {
        node->children[value] = build_tree(subset, remaining_features);
    }
    return node;
}

void DecisionTree::fit(const Dataset& data, const std::vector<std::string>& features) {
    root_ = build_tree(data, features);
}

std::string DecisionTree::predict(const Example& example) const {
    if (!root_) {
        throw std::runtime_error("Decision tree has not been trained");
    }
    return predict_helper(root_.get(), example);
}

std::string DecisionTree::predict_helper(const Node* node, const Example& example) const {
    if (node->type == NodeType::LEAF) {
        return node->label.value();
    }
    const std::string& feature = node->feature;
    std::string feature_value = get_feature_value(example, feature);
    auto it = node->children.find(feature_value);
    if (it == node->children.end()) {
        throw std::runtime_error("Feature value '" + feature_value + "' not found in tree for feature '" + feature + "'");
    }
    return predict_helper(it->second.get(), example);
}

int main() {
    try {
        Dataset data = {
            {30, "High", "Single", "Class A"},
            {35, "Low", "Married", "Class B"},
            {40, "Medium", "Divorced", "Class A"},
            {25, "Low", "Single", "Class C"},
            {50, "High", "Married", "Class B"},
            {45, "Low", "Divorced", "Class A"}
        };
        std::vector<std::string> features = {"income", "marital_status", "age"};

        // Train the decision tree
        std::cout << "Training decision tree..." << std::endl;
        DecisionTree dt;
        dt.fit(data, features);
        std::cout << "Training completed!" << std::endl;

        // Predict a new sample
        Example new_example = {30, "High", "Single", ""};
        std::cout << "Predicting for example: Age=" << new_example.age
                  << ", Income=" << new_example.income
                  << ", Marital Status=" << new_example.marital_status << std::endl;
        std::string prediction = dt.predict(new_example);
        std::cout << "Prediction: " << prediction << std::endl;

        // Test a few more samples
        std::vector<Example> test_examples = {
            {25, "Low", "Single", ""},
            {45, "High", "Married", ""},
            {35, "Medium", "Divorced", ""}
        };
        std::cout << "\nTesting additional examples:" << std::endl;
        for (const auto& example : test_examples) {
            std::string pred = dt.predict(example);
            std::cout << "Age=" << example.age << ", Income=" << example.income
                      << ", Marital=" << example.marital_status << " -> " << pred << std::endl;
        }
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}
```
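Conclusion
That wraps up this chapter, and Xiao H is about to head off for a happy weekend. If you run into any problems with the code, feel free to contact me.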