
Hand-Rolling an ID3 Decision Tree in C++

Background

        Picking up from the previous post: after finishing KNN, Xiao H kept going with machine learning, and this time landed on decision trees. Building a tree to carry out a classification task really is an intuitive picture.

Concepts

        A decision tree is a supervised learning algorithm commonly used for classification. The ID3 algorithm picks the best feature to split on by computing information gain, eventually producing a tree in which each internal node represents a feature (attribute) and each leaf node represents a class. Information gain is a concept from information theory that measures how much a split on a given feature reduces the dataset's entropy.

        The core of ID3 is information gain: compute each feature's "contribution" to classifying the dataset, choose the feature with the largest contribution as the split criterion at the current node, and repeat until every sample is correctly classified or no further split is possible.

Key Concepts

  1. Entropy. Entropy measures how "mixed" a dataset is: the higher the entropy, the more mixed the data (and the less clear-cut the classification).

  2. Conditional Entropy. When feature A is used to split dataset D, the weighted average entropy of the resulting subsets is called the conditional entropy.

  3. Information Gain. Information gain is the difference between the entropy of the original dataset and the conditional entropy after splitting on feature A; it measures A's contribution to the classification. The larger the gain, the more the split on A reduces the data's "mixedness", and the better suited A is as the split criterion at the current node.

(The standard formulas for these three quantities are reproduced below for reference.)
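With D the dataset, p_k the proportion of samples belonging to class k, and D_j the subset of D on which feature A takes its j-th value:

$$H(D) = -\sum_{k} p_k \log_2 p_k$$

$$H(D \mid A) = \sum_{j} \frac{|D_j|}{|D|}\, H(D_j)$$

$$\mathrm{Gain}(D, A) = H(D) - H(D \mid A)$$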

ID3 Algorithm Flow

  1. Initialize: use all training samples as the root node's dataset.
  2. Check the termination conditions
    • If every sample in the current dataset belongs to the same class, mark the node as a leaf and return that class.
    • If no features remain to split on, mark the node as a leaf and return the most common class among the samples (majority vote).
  3. Choose the best feature
    • Compute the entropy H(D) of the current dataset.
    • For each unused feature A, compute its information gain Gain(D, A).
    • Pick the feature A with the largest information gain as the split feature for the current node.
  4. Split the dataset
    • Partition the dataset into subsets D_j by the values of feature A (one subset per value).
    • Create a child node for each subset and recursively apply steps 2-4 until a termination condition is met.

Data Preparation

        Assume we have a dataset whose columns are Age, income, Marital Status, and Label, with some number of rows; the exact rows used for testing are listed below.
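For concreteness, these are the six rows used by the test in main at the end of this post:

  Age   Income   Marital Status   Label
  30    High     Single           Class A
  35    Low      Married          Class B
  40    Medium   Divorced         Class A
  25    Low      Single           Class C
  50    High     Married          Class B
  45    Low      Divorced         Class A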

Data Structures

        We encapsulate the whole decision tree in a DecisionTree class.

The DecisionTree Class

class DecisionTree {
public:
    // Node type enum
    enum class NodeType { INTERNAL, LEAF };

    // Node structure
    struct Node {
        NodeType type;
        std::string feature;                                    // split feature (internal nodes only)
        std::map<std::string, std::unique_ptr<Node>> children;  // feature value -> subtree
        std::optional<std::string> label;                       // class label (leaf nodes only)
    };

    void fit(const Dataset& data, const std::vector<std::string>& features);
    std::string predict(const Example& example) const;

private:
    std::unique_ptr<Node> build_tree(const Dataset& data, const std::vector<std::string>& features);
    double entropy(const Dataset& data) const;
    double information_gain(const Dataset& data, const std::string& feature) const;
    std::string choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const;
    std::string get_feature_value(const Example& example, const std::string& feature) const;
    std::string predict_helper(const Node* node, const Example& example) const;

    std::unique_ptr<Node> root_;
};

Method Implementations

Computing Entropy

// Compute the entropy of a dataset
double DecisionTree::entropy(const Dataset& data) const {
    if (data.empty()) return 0.0;

    // Count how many samples carry each class label
    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }

    double total = static_cast<double>(data.size());
    double entropy_value = 0.0;
    for (const auto& [label, count] : label_counts) {
        double p = static_cast<double>(count) / total;
        if (p > 0) { // avoid log(0)
            entropy_value -= p * std::log2(p);
        }
    }
    return entropy_value;
}
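As a quick sanity check of this function: for the six training samples in main below (three Class A, two Class B, one Class C), the entropy works out to

$$H(D) = -\tfrac{3}{6}\log_2\tfrac{3}{6} - \tfrac{2}{6}\log_2\tfrac{2}{6} - \tfrac{1}{6}\log_2\tfrac{1}{6} \approx 1.459 \text{ bits}$$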

Computing Information Gain

// Compute the information gain of splitting on a feature
double DecisionTree::information_gain(const Dataset& data, const std::string& feature) const {
    double initial_entropy = entropy(data);
    double weighted_entropy = 0.0;

    // Partition the dataset by the feature's values
    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, feature);
        split_data[feature_value].push_back(ex);
    }

    // Weighted average entropy of the subsets (the conditional entropy)
    for (const auto& [value, subset] : split_data) {
        double weight = static_cast<double>(subset.size()) / data.size();
        weighted_entropy += weight * entropy(subset);
    }
    return initial_entropy - weighted_entropy;
}
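Continuing the sanity check on the same six samples: splitting on marital_status yields the subsets Married = {Class B, Class B}, Divorced = {Class A, Class A}, and Single = {Class A, Class C}, so

$$H(D \mid \text{marital\_status}) = \tfrac{2}{6}\cdot 0 + \tfrac{2}{6}\cdot 0 + \tfrac{2}{6}\cdot 1 = \tfrac{1}{3} \approx 0.333, \qquad \mathrm{Gain} \approx 1.459 - 0.333 = 1.126$$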

Choosing the Best Feature

// Choose the feature with the highest information gain
std::string DecisionTree::choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const {
    double max_gain = -1.0;
    std::string best_feature;
    std::cout << "  Information gains:" << std::endl;
    for (const auto& feature : features) {
        double gain = information_gain(data, feature);
        std::cout << "    " << feature << ": " << gain << std::endl;
        if (gain > max_gain) { // strict '>', so ties go to the feature listed first
            max_gain = gain;
            best_feature = feature;
        }
    }
    std::cout << "  Best feature: " << best_feature << " (gain: " << max_gain << ")" << std::endl;
    return best_feature;
}

Building the Decision Tree

// Recursively build the decision tree
std::unique_ptr<DecisionTree::Node> DecisionTree::build_tree(const Dataset& data, const std::vector<std::string>& features) {
    if (data.empty()) {
        return nullptr;
    }

    // Count labels
    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }

    // If all samples belong to one class, create a leaf node
    if (label_counts.size() == 1) {
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = label_counts.begin()->first;
        return node;
    }

    // If no features remain, fall back to the most common label
    if (features.empty()) {
        auto most_common_label = std::max_element(
            label_counts.begin(), label_counts.end(),
            [](const auto& a, const auto& b) { return a.second < b.second; });
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = most_common_label->first;
        return node;
    }

    // Choose the best feature
    std::string best_feature = choose_best_feature(data, features);
    auto node = std::make_unique<Node>();
    node->type = NodeType::INTERNAL;
    node->feature = best_feature;

    // Debug output: show the chosen feature
    std::cout << "Selected best feature: " << best_feature << " for " << data.size() << " samples" << std::endl;

    // Split the data by the best feature
    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, best_feature);
        split_data[feature_value].push_back(ex);
    }

    // Build the list of remaining features
    std::vector<std::string> remaining_features;
    for (const auto& feature : features) {
        if (feature != best_feature) {
            remaining_features.push_back(feature);
        }
    }

    // Recursively build the subtrees
    for (const auto& [value, subset] : split_data) {
        node->children[value] = build_tree(subset, remaining_features);
    }
    return node;
}
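Hand-tracing build_tree on the six samples from main (my arithmetic, worth double-checking): marital_status has the largest gain at the root (≈1.126, versus ≈0.918 for the bucketed age and ≈0.333 for income), Married and Divorced become pure leaves, and the Single branch splits again on income (income and age tie there at gain 1; income wins because it comes first in the feature list and the comparison is strict). The resulting tree should look like:

marital_status?
├── Married  → Class B
├── Divorced → Class A
└── Single   → income?
                 ├── High → Class A
                 └── Low  → Class C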

Prediction

// Prediction entry point
std::string DecisionTree::predict(const Example& example) const {
    if (!root_) {
        throw std::runtime_error("Decision tree has not been trained");
    }
    return predict_helper(root_.get(), example);
}

// Recursive prediction helper
std::string DecisionTree::predict_helper(const Node* node, const Example& example) const {
    if (node->type == NodeType::LEAF) {
        return node->label.value();
    }
    const std::string& feature = node->feature;
    std::string feature_value = get_feature_value(example, feature);
    auto it = node->children.find(feature_value);
    if (it == node->children.end()) {
        throw std::runtime_error("Feature value '" + feature_value + "' not found in tree for feature '" + feature + "'");
    }
    return predict_helper(it->second.get(), example);
}
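One thing to note: predict_helper throws as soon as it meets a feature value that never occurred under that node during training. A more forgiving variant, sketched below (this is not what the class above implements; the majority_label member is hypothetical and would have to be filled in during build_tree), falls back to the node's majority class instead:

// Sketch only: assumes Node gains a hypothetical extra member
//   std::string majority_label; // most common class seen at this node during build_tree
std::string DecisionTree::predict_helper(const Node* node, const Example& example) const {
    if (node->type == NodeType::LEAF) {
        return node->label.value();
    }
    auto it = node->children.find(get_feature_value(example, node->feature));
    if (it == node->children.end()) {
        return node->majority_label; // degrade gracefully instead of throwing
    }
    return predict_helper(it->second.get(), example);
}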

Code

        The core functionality is all shown above; here is a complete listing, including the helper functions and a simple test in main:

#include <iostream>
#include <string>
#include <vector>
#include <map>
#include <optional>
#include <unordered_map>
#include <algorithm>
#include <cmath>
#include <memory>
#include <cassert>

struct Example {
    int age;
    std::string income;
    std::string marital_status;
    std::string label;
};

using Dataset = std::vector<Example>;

class DecisionTree {
public:
    // Node type enum
    enum class NodeType { INTERNAL, LEAF };

    // Node structure
    struct Node {
        NodeType type;
        std::string feature;
        std::map<std::string, std::unique_ptr<Node>> children;
        std::optional<std::string> label;
    };

    void fit(const Dataset& data, const std::vector<std::string>& features);
    std::string predict(const Example& example) const;

private:
    std::unique_ptr<Node> build_tree(const Dataset& data, const std::vector<std::string>& features);
    double entropy(const Dataset& data) const;
    double information_gain(const Dataset& data, const std::string& feature) const;
    std::string choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const;
    std::string get_feature_value(const Example& example, const std::string& feature) const;
    std::string predict_helper(const Node* node, const Example& example) const;

    std::unique_ptr<Node> root_;
};

// Compute the entropy of a dataset
double DecisionTree::entropy(const Dataset& data) const {
    if (data.empty()) return 0.0;
    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }
    double total = static_cast<double>(data.size());
    double entropy_value = 0.0;
    for (const auto& [label, count] : label_counts) {
        double p = static_cast<double>(count) / total;
        if (p > 0) { // avoid log(0)
            entropy_value -= p * std::log2(p);
        }
    }
    return entropy_value;
}

// Helper: map a feature name to the example's (discretized) value
std::string DecisionTree::get_feature_value(const Example& example, const std::string& feature) const {
    if (feature == "income") {
        return example.income;
    } else if (feature == "marital_status") {
        return example.marital_status;
    } else if (feature == "age") {
        // Bucket the numeric age into categorical values, since ID3 needs discrete features
        if (example.age < 30) {
            return "young";
        } else if (example.age >= 30 && example.age < 50) {
            return "middle";
        } else {
            return "old";
        }
    } else {
        throw std::runtime_error("Unknown feature: " + feature);
    }
}

// Compute the information gain of splitting on a feature
double DecisionTree::information_gain(const Dataset& data, const std::string& feature) const {
    double initial_entropy = entropy(data);
    double weighted_entropy = 0.0;
    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, feature);
        split_data[feature_value].push_back(ex);
    }
    for (const auto& [value, subset] : split_data) {
        double weight = static_cast<double>(subset.size()) / data.size();
        weighted_entropy += weight * entropy(subset);
    }
    return initial_entropy - weighted_entropy;
}

// Choose the feature with the highest information gain
std::string DecisionTree::choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const {
    double max_gain = -1.0;
    std::string best_feature;
    std::cout << "  Information gains:" << std::endl;
    for (const auto& feature : features) {
        double gain = information_gain(data, feature);
        std::cout << "    " << feature << ": " << gain << std::endl;
        if (gain > max_gain) {
            max_gain = gain;
            best_feature = feature;
        }
    }
    std::cout << "  Best feature: " << best_feature << " (gain: " << max_gain << ")" << std::endl;
    return best_feature;
}

// Recursively build the decision tree
std::unique_ptr<DecisionTree::Node> DecisionTree::build_tree(const Dataset& data, const std::vector<std::string>& features) {
    if (data.empty()) {
        return nullptr;
    }
    // Count labels
    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }
    // If all samples belong to one class, create a leaf node
    if (label_counts.size() == 1) {
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = label_counts.begin()->first;
        return node;
    }
    // If no features remain, fall back to the most common label
    if (features.empty()) {
        auto most_common_label = std::max_element(
            label_counts.begin(), label_counts.end(),
            [](const auto& a, const auto& b) { return a.second < b.second; });
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = most_common_label->first;
        return node;
    }
    // Choose the best feature
    std::string best_feature = choose_best_feature(data, features);
    auto node = std::make_unique<Node>();
    node->type = NodeType::INTERNAL;
    node->feature = best_feature;
    // Debug output: show the chosen feature
    std::cout << "Selected best feature: " << best_feature << " for " << data.size() << " samples" << std::endl;
    // Split the data by the best feature
    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, best_feature);
        split_data[feature_value].push_back(ex);
    }
    // Build the list of remaining features
    std::vector<std::string> remaining_features;
    for (const auto& feature : features) {
        if (feature != best_feature) {
            remaining_features.push_back(feature);
        }
    }
    // Recursively build the subtrees
    for (const auto& [value, subset] : split_data) {
        node->children[value] = build_tree(subset, remaining_features);
    }
    return node;
}

void DecisionTree::fit(const Dataset& data, const std::vector<std::string>& features) {
    root_ = build_tree(data, features);
}

std::string DecisionTree::predict(const Example& example) const {
    if (!root_) {
        throw std::runtime_error("Decision tree has not been trained");
    }
    return predict_helper(root_.get(), example);
}

std::string DecisionTree::predict_helper(const Node* node, const Example& example) const {
    if (node->type == NodeType::LEAF) {
        return node->label.value();
    }
    const std::string& feature = node->feature;
    std::string feature_value = get_feature_value(example, feature);
    auto it = node->children.find(feature_value);
    if (it == node->children.end()) {
        throw std::runtime_error("Feature value '" + feature_value + "' not found in tree for feature '" + feature + "'");
    }
    return predict_helper(it->second.get(), example);
}

int main() {
    try {
        Dataset data = {
            {30, "High", "Single", "Class A"},
            {35, "Low", "Married", "Class B"},
            {40, "Medium", "Divorced", "Class A"},
            {25, "Low", "Single", "Class C"},
            {50, "High", "Married", "Class B"},
            {45, "Low", "Divorced", "Class A"}
        };
        std::vector<std::string> features = {"income", "marital_status", "age"};

        // Train the decision tree
        std::cout << "Training decision tree..." << std::endl;
        DecisionTree dt;
        dt.fit(data, features);
        std::cout << "Training completed!" << std::endl;

        // Predict a new sample
        Example new_example = {30, "High", "Single", ""};
        std::cout << "Predicting for example: Age=" << new_example.age
                  << ", Income=" << new_example.income
                  << ", Marital Status=" << new_example.marital_status << std::endl;
        std::string prediction = dt.predict(new_example);
        std::cout << "Prediction: " << prediction << std::endl;

        // Test more samples
        std::vector<Example> test_examples = {
            {25, "Low", "Single", ""},
            {45, "High", "Married", ""},
            {35, "Medium", "Divorced", ""}
        };
        std::cout << "\nTesting additional examples:" << std::endl;
        for (const auto& example : test_examples) {
            std::string pred = dt.predict(example);
            std::cout << "Age=" << example.age << ", Income=" << example.income
                      << ", Marital=" << example.marital_status << " -> " << pred << std::endl;
        }
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}

  

Conclusion

        That wraps up this installment. Xiao H is off to enjoy the weekend; if you spot any problems with the code, feel free to contact the author.
