当前位置: 首页 > news >正文

【机器学习】两大线性分类算法:逻辑回归与线性判别分析:找到分界线的艺术

文章目录

  • 一、核心概念:数据分类的"切分线"
  • 二、工作原理:从"找分界线"理解
  • 二、常见算法
    • 1、逻辑回归:二分类
    • 2、线性判别分析(LDA):分类与降维
    • 3、两种算法对比分析
  • 三、实际应用:用代码"切分"糖尿病数据
  • 四、应用场景:线性分类的"用武之地"

 

一、核心概念:数据分类的"切分线"

线性分类就像用一把刀"切蛋糕"。如果蛋糕只有两种口味(比如巧克力味和草莓味),你只需要一刀就能把它们完美地分开。这"一刀"就是我们机器学习中的 “决策边界” 。它是一条直线、一个平面,或者在更高维度上的一个超平面,用来区分不同的数据类别。

核心定义
线性分类算法通过学习数据特征,找到一个最佳的超平面,将不同类别的数据点分隔开。这个超平面就是模型的"分界线",它决定了新来的数据点属于哪个类别。

 

二、工作原理:从"找分界线"理解

线性分类算法的工作流程可以简单概括为以下几步:

  1. 数据准备:首先,我们需要收集带有明确标签的数据。比如,如果你想识别猫和狗,就需要大量的猫图片(标记为"猫")和狗图片(标记为"狗")。

  2. 特征提取:计算机无法直接理解图片或文字,所以我们需要将这些原始数据转换为算法能理解的数字特征。比如,图片的颜色、纹理、形状等都可以转化为数字向量

  3. 模型学习:算法会通过学习这些带有标签的特征,自动找到一条最佳的"切分线"(决策边界),使得不同类别的数据点尽可能地被这条线分开。

  4. 预测分类:当有新的、未知的图片(比如一张你不知道是猫还是狗的图片)进来时,算法会根据它落在"切分线"的哪一边,来判断其类别。

 

二、常见算法

在众多线性分类算法中,**逻辑回归(Logistic Regression, LR)线性判别分析(Linear Discriminant Analysis, LDA)**是两种非常经典且常用的方法。

1、逻辑回归:二分类

LR 不直接预测类别,而是预测一个事件发生的概率。它通过一个特殊的"S"形函数(Sigmoid函数),将线性模型的输出值映射到0到1之间,表示属于某个类别的概率。

LR 假设数据服从伯努利分布(即只有两种结果,成功或失败),并且特征与对数几率之间存在线性关系。当预测概率大于某个阈值(通常是0.5)时,就归为一类;否则归为另一类。这个阈值对应的就是它的决策边界。

 

2、线性判别分析(LDA):分类与降维

核心思想
LDA是一个既做分类又做降维的算法。它的核心思想是:找到一个最佳投影方向,让不同类别的数据投影后分得最清楚,同时让同一类别的数据投影后聚得最紧。想象你有一堆不同颜色的球散落在三维空间里,LDA就是找到一个角度,从这个角度看过去,不同颜色的球能分得最清楚。

LDA用公式 类间距离类内距离\frac{\text{类间距离}}{\text{类内距离}}类内距离类间距离 来衡量投影效果。它要最大化这个比值,让分子(类间距离)越大越好,分母(类内距离)越小越好。通过求解这个优化问题,LDA找到最优的投影方向。

 
分类与降维
LDA既可以做二分类,也可以做多分类。在二分类中,它找到一个投影方向将数据投影到一条直线上;在多分类中,它找到多个投影方向将数据投影到低维空间。 同时,LDA还具有降维功能:二分类从任意维度降到1维,C分类从任意维度降到(C-1)维。

 
假设与应用条件
LDA假设数据像钟形曲线一样分布,且不同类别的"钟形"形状相同。 如果这个假设成立,LDA效果很好;如果不成立,效果就会变差。与逻辑回归相比,LDA像是一个"几何学家",专注于找到最好的观察角度,而逻辑回归像是一个"概率学家",专注于计算每个样本属于各类别的概率。

 

3、两种算法对比分析

特性逻辑回归 (LR)线性判别分析 (LDA)
核心预测概率,通过 Sigmoid 函数映射寻找最佳投影方向,最大化类间散度,最小化类内散度
输出概率值 (0-1)判别函数值
假设数据服从伯努利分布,特征与对数几率线性相关数据服从高斯分布,各类别协方差矩阵相同
优点简单高效,输出概率直观,易于解释在类别区分度高时表现良好,对多分类问题处理自然
缺点对异常值敏感,对数据分布有一定要求对数据分布(高斯分布、同协方差)有较强假设

 

三、实际应用:用代码"切分"糖尿病数据

我们用 Python 的 scikit-learn 库,在著名的 Pima Indians 糖尿病数据集上,实际操作一下线性分类算法。这个数据集包含了印第安女性的健康数据,目标是预测她们是否患有糖尿病。

from pandas import read_csv
from sklearn.model_selection import KFold, cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import pandas as pd# 1. 导入数据
filename = 'pima_data.csv' # 假设文件在当前目录下
names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
data = read_csv(filename, names=names)# 2. 将数据分为输入特征 (X) 和输出结果 (Y)
array = data.values
X = array[:, 0:8]  # 前8列是特征
Y = array[:, 8]    # 最后一列是目标变量 (是否患糖尿病)# 3. 设置交叉验证 (KFold)
n_splits = 10 # 分成10份
seed = 7      # 随机种子,确保每次运行结果一致
kfold = KFold(n_splits=n_splits, random_state=seed, shuffle=True)# 4. 创建模型并评估
models = {'LR': LogisticRegression(solver='liblinear', max_iter=200), # solver='liblinear' 适用于小数据集'LDA': LinearDiscriminantAnalysis()
}results = {}
for name, model in models.items():# 使用交叉验证评估模型性能scores = cross_val_score(model, X, Y, cv=kfold, scoring='accuracy')results[name] = (scores.mean(), scores.std())# 5. 打印结果
print("模型准确率 (平均值 ± 标准差):")
for name, (mean, std) in results.items():print(f"{name}: {mean:.4f}{std:.4f})")# 示例输出 
模型准确率 (平均值 ± 标准差):
LR: 0.7696 (±0.0495)
LDA: 0.7670 (±0.0480)

这告诉我们:

  • LR 模型在预测糖尿病方面的平均准确率约为 76.96%,其性能在不同数据子集上的波动(标准差)约为 4.95%
  • LDA 模型的平均准确率约为 76.70%,性能波动约为 4.80%

从结果来看,两种线性分类模型在 Pima Indians 糖尿病数据集上的表现非常接近,准确率都在76%左右,且波动范围不大。这意味着它们都能有效地对糖尿病进行初步预测。

你可以尝试修改 LogisticRegression 中的 solver 参数(例如改为 'saga''lbfgs'),或者调整 max_iter(最大迭代次数),观察模型性能是否会有变化。这能帮助你理解不同参数对模型训练的影响。

 

四、应用场景:线性分类的"用武之地"

线性分类算法因其简单、高效和易于解释的特点,在许多实际场景中都有广泛应用:

  • 实际案例1:垃圾邮件识别

    • 场景:你的邮箱每天都会自动将邮件分为"正常邮件"和"垃圾邮件"。
    • 选择指导:逻辑回归常用于此,因为它能给出邮件是垃圾邮件的概率。你可以根据这个概率设置一个阈值,比如概率超过80%就直接扔进垃圾箱,低于20%就肯定是正常邮件,中间的再人工判断
  • 实际案例2:客户流失预测

    • 场景:一家电信公司想知道哪些客户可能很快会停止使用他们的服务(流失)。
    • 选择指导:逻辑回归可以根据客户的通话时长、套餐类型、投诉记录等特征,预测客户流失的概率。公司可以根据这些概率,提前对高风险客户采取挽留措施,比如提供优惠套餐。
  • 实际案例3:医疗诊断辅助

    • 场景:医生根据患者的各项体征数据(如血压、血糖、年龄等),辅助诊断某种疾病(如糖尿病、心脏病)。
    • 选择指导:LR 和 LDA 都可以作为初步的诊断模型。LR 提供患病概率,便于医生评估风险;LDA 在类别区分度高时,能更好地找到疾病和健康人群之间的界限
http://www.lryc.cn/news/608291.html

相关文章:

  • uniapp倒计时计算
  • InfluxDB 与 Node.js 框架:Express 集成方案(一)
  • Oracle 11g RAC集群部署手册(一)
  • 电力系统分析学习笔记
  • Angular初学者入门第一课——搭建并改造项目(精品)
  • 学习笔记:无锁队列的原理以及c++实现
  • 基于Dockerfile 部署一个 Flask 应用
  • Orange的运维学习日记--25.Linux文件系统基本管理
  • 【BTC】挖矿
  • 优选算法 力扣1089.复写零 双指针 原地修改 C++解题思路 每日一题
  • Git 的基本使用指南(1)
  • Arpg第二章——流程逻辑
  • 自动驾驶中的传感器技术15——Camera(6)
  • 数字化转型驱动中小制造企业的质量管理升级
  • TFS-2022《A Novel Data-Driven Approach to Autonomous Fuzzy Clustering》
  • 【深度学习②】| DNN篇
  • 编译器与解释器:核心原理与工程实践
  • 基于Postman进行http的请求和响应
  • 操作系统:远程过程调用( Remote Procedure Call,RPC)
  • Jupyter notebook如何显示行号?
  • SQL Server从入门到项目实践(超值版)读书笔记 22
  • Spring事务失效场景
  • kotlin小记(1)
  • 集合框架(重点)
  • linux ext4缩容home,扩容根目录
  • 网络安全基础知识【6】
  • Ext系列文件系统
  • 【软考中级网络工程师】知识点之级联
  • 错误处理_IncompatibleKeys
  • 企业资产|企业资产管理系统|基于springboot企业资产管理系统设计与实现(源码+数据库+文档)