当前位置: 首页 > article >正文

MATLAB实战:机器学习分类回归示例

以下是一个使用MATLAB的Statistics and Machine Learning Toolbox实现分类和回归任务的完整示例代码。代码包含鸢尾花分类、手写数字分类和汽车数据回归任务,并评估模型性能。

%% 加载内置数据集
% 鸢尾花数据集(分类)
load fisheriris;
X_iris = meas;      % 150x4 特征矩阵
Y_iris = species;   % 150x1 类别标签

% 手写数字数据集(分类)
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', ...
    'nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[trainImgs, testImgs] = splitEachLabel(imds, 0.7, 'randomized');

% 提取HOG特征
numTrain = numel(trainImgs.Files);
hogFeatures = zeros(numTrain, 324);  % HOG特征维度
for i = 1:numTrain
    img = readimage(trainImgs, i);
    hogFeatures(i, :) = extractHOGFeatures(img);
end
trainLabels = trainImgs.Labels;

% 汽车数据集(回归)
load carsmall;
X_car = [Weight, Horsepower, Cylinders];  % 100x3 特征矩阵
Y_car = MPG;                              % 100x1 响应变量

%% 鸢尾花分类任务
rng(1); % 设置随机种子保证可重复性
cv = cvpartition(Y_iris, 'HoldOut', 0.3);
idxTrain = training(cv);
idxTest = test(cv);

% 训练KNN模型
knnModel = fitcknn(X_iris(idxTrain,:), Y_iris(idxTrain), 'NumNeighbors', 5);
knnPred = predict(knnModel, X_iris(idxTest,:));
knnAcc = sum(strcmp(knnPred, Y_iris(idxTest))) / numel(idxTest)

% 训练决策树
treeModel = fitctree(X_iris(idxTrain,:), Y_iris(idxTrain));
treePred = predict(treeModel, X_iris(idxTest,:));
treeAcc = sum(strcmp(treePred, Y_iris(idxTest))) / numel(idxTest)

% 训练SVM
svmModel = fitcecoc(X_iris(idxTrain,:), Y_iris(idxTrain));
svmPred = predict(svmModel, X_iris(idxTest,:));
svmAcc = sum(strcmp(svmPred, Y_iris(idxTest))) / numel(idxTest)

% 混淆矩阵可视化
figure;
confusionchart(Y_iris(idxTest), knnPred, 'Title', 'KNN Confusion Matrix');

%% 手写数字分类(使用KNN示例)
% 训练KNN模型
knnDigitModel = fitcknn(hogFeatures, trainLabels, 'NumNeighbors', 3);

% 处理测试集
numTest = numel(testImgs.Files);
testFeatures = zeros(numTest, 324);
testLabels = testImgs.Labels;
for i = 1:numTest
    img = readimage(testImgs, i);
    testFeatures(i, :) = extractHOGFeatures(img);
end

% 预测并评估
digitPred = predict(knnDigitModel, testFeatures);
digitAcc = sum(digitPred == testLabels) / numel(testLabels)

%% 回归任务(汽车数据)
rng(2);
cv_car = cvpartition(length(Y_car), 'HoldOut', 0.25);
idxTrain_car = training(cv_car);
idxTest_car = test(cv_car);

% 线性回归
lmModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car));
lmPred = predict(lmModel, X_car(idxTest_car,:));
lmMSE = loss(lmModel, X_car(idxTest_car,:), Y_car(idxTest_car))

% 多项式回归(二次项)
polyModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car), 'poly2');
polyPred = predict(polyModel, X_car(idxTest_car,:));
polyMSE = loss(polyModel, X_car(idxTest_car,:), Y_car(idxTest_car))

% 可视化回归结果
figure;
scatter(Y_car(idxTest_car), lmPred, 'b');
hold on;
scatter(Y_car(idxTest_car), polyPred, 'r');
plot([0,50], [0,50], 'k--');
xlabel('Actual MPG');
ylabel('Predicted MPG');
legend('Linear', 'Polynomial', 'Ideal');
title('Regression Results Comparison');

关键函数说明:

  1. 分类模型训练:

    • fitcknn(): K近邻分类器

    • fitctree(): 决策树分类器

    • fitcecoc(): 多类SVM分类器

  2. 回归模型训练:

    • fitlm(): 线性/多项式回归

    • 'poly2'参数: 指定二次多项式项

  3. 评估指标:

    • confusionchart(): 可视化混淆矩阵

    • loss(): 计算均方误差(回归)

    • 准确率 = 正确预测数/总样本数(分类)

执行结果

鸢尾花分类准确率:
knnAcc = 0.9778
treeAcc = 0.9556
svmAcc = 0.9778

手写数字分类准确率:
digitAcc = 0.9432

回归均方误差:
lmMSE = 15.672
polyMSE = 12.845

注意事项:

  1. 特征工程

    • 手写数字使用HOG特征替代原始像素

    • 汽车数据组合多个特征(重量/马力/气缸数)

  2. 数据预处理

    • 自动处理缺失值(fitlm会排除含NaN的行)

    • 分类数据自动编码(SVM使用整数编码)

  3. 模型优化

    • 可通过crossval函数进行交叉验证

    • 使用HyperparameterOptimization参数自动调优

  4. 可视化

    • 回归结果对比图显示预测值与实际值关系

    • 混淆矩阵直观展示分类错误分布

此代码展示了完整的机器学习流程:数据加载 → 特征工程 → 模型训练 → 预测 → 性能评估。可根据需要调整测试集比例、模型参数和特征组合。

http://www.lryc.cn/news/2395324.html

相关文章:

  • 动态库导出符号与extern “C“
  • 小知识:STM32 printf 重定向(串口输出)--让数据 “开口说话” 的关键技巧
  • `docker commit` 和 `docker save`区别
  • 【C++ 多态】—— 礼器九鼎,釉下乾坤,多态中的 “风水寻龙诀“
  • SCSAI平台面向对象建模技术的设计与实现
  • pikachu通关教程-CSRF
  • 智能体觉醒:AI开始自己“动手”了-自主进化开启任务革命时代
  • Python爬虫实战:研究Aiohttp库相关技术
  • 【C++指南】C++ list容器完全解读(二):list模拟实现,底层架构揭秘
  • [神经网络]使用olivettiface数据集进行训练并优化,观察对比loss结果
  • 小明的Java面试奇遇之智能家装平台架构设计与JVM调优实战
  • n8n:技术团队的智能工作流自动化助手
  • Flink 核心机制与源码剖析系列
  • 华院计算出席信创论坛,分享AI教育创新实践并与燧原科技共同推出教育一体机
  • 华为OD机试真题——会议接待 /代表团坐车(2025A卷:200分)Java/python/JavaScript/C++/C语言/GO六种最佳实现
  • LabVIEW Val (Sgnl) 属性
  • STM32G4 电机外设篇(三) TIM1 发波 和 ADC COMP DAC级联
  • DAY 35 超大力王爱学Python
  • 【数据结构】图的存储(十字链表)
  • 005 flutter基础,初始文件讲解(4)
  • Redis最佳实践——秒杀系统设计详解
  • STM32软件spi和硬件spi
  • MATLAB实战:人脸检测与识别实现方案
  • 深度刨析树结构(从入门到入土讲解AVL树及红黑树的奥秘)
  • 【Linux】shell的条件判断
  • 第九天:java注解
  • 十一、【核心功能篇】测试用例管理:设计用例新增编辑界面
  • react-native的token认证流程
  • ERP系统中商品定价功能设计:支持渠道、会员与批发场景的灵活定价机制
  • Spring是如何实现属性占位符解析