机器学习基础-k 近邻算法(从辨别水果开始)
一、生活中的 "分类难题" 与 k 近邻的灵感
你有没有这样的经历:在超市看到一种从没见过的水果,表皮黄黄的,拳头大小,形状圆滚滚。正当你犹豫要不要买时,突然想起外婆家的橘子好像就是这个样子 —— 黄色、圆形、大小和拳头差不多。于是你推断:"这应该是橘子吧!"
其实,这个看似平常的判断过程,竟然藏着机器学习中最经典的分类算法 ——k 近邻(k-Nearest Neighbors,简称 kNN)的核心思想!
1.1 现实中的解法拆解
当我们判断未知水果时,大脑会自动完成三个步骤:
- 收集特征:观察颜色(黄色)、形状(圆形)、大小(拳头大)
- 匹配经验:调动记忆中 "橘子" 的特征库(黄色、圆形、拳头大)
- 做出判断:因为新水果的特征和记忆中的橘子最像,所以归类为橘子
这和 k 近邻算法的工作流程惊人地相似!唯一的区别是:计算机需要我们把这些 "看得到的特征" 变成 "算得出的数据"。
1.2 k 近邻算法的有趣灵魂
k 近邻算法的有趣之处在于它的 "懒惰" 和 "实在":
- 懒惰:它不像其他算法那样先总结规律(比如 "黄色圆形水果都是橘子"),而是等到需要判断时才去比对已知数据
- 实在:它的判断逻辑简单粗暴 ——"少数服从多数",看新样本周围最像的 k 个样本里哪种类型占多数
就像你纠结新水果是橘子还是苹果时,会找 5 个见过这两种水果的人投票,哪种意见多就信哪种。
二、从生活到代码:k 近邻算法的实现之路
我们用一个具体案例来实现:根据 "颜色深度"(0-10,数值越大越黄)和 "大小"(0-10,数值越大越大)两个特征,判断水果是橘子(标签 1)还是苹果(标签 0)。
2.1 准备数据:把生活观察变成数字
首先,我们需要把已知的水果数据整理成计算机能理解的格式:
# 导入必要的库
import numpy as np # 用于数值计算
import matplotlib.pyplot as plt # 用于画图
# 已知水果数据:[颜色深度, 大小],标签:0=苹果,1=橘子
# 想象这些数据来自我们之前见过的水果:
# 苹果通常偏红(颜色深度小),大小不一;橘子偏黄(颜色深度大)
known_fruits = np.array([
[2, 3], # 苹果:颜色偏红(2),小个(3)
[3, 4], # 苹果:颜色较红(3),中个(4)
[1, 5], # 苹果:颜色很红(1),大个(5)
[7, 6], # 橘子:颜色较黄(7),中个(6)
[8, 5], # 橘子:颜色很黄(8),中个(5)
[9, 4] # 橘子:颜色极黄(9),小个(4)
])
# 对应的标签:0代表苹果,1代表橘子
labels = np.array([0, 0, 0, 1, 1, 1])
# 未知水果:颜色深度6,大小5(就是我们在超市看到的那个)
unknown_fruit = np.array([6, 5])
2.2 数据可视化:让计算机 "看见" 差异
我们用散点图把数据画出来,直观感受苹果和橘子的特征差异:
# 绘制已知水果
plt.scatter(known_fruits[labels==0, 0], known_fruits[labels==0, 1],
color='red', marker='o', label='苹果') # 苹果标为红色圆点
plt.scatter(known_fruits[labels==1, 0], known_fruits[labels==1, 1],
color='orange', marker='o', label='橘子') # 橘子标为橙色圆点
# 绘制未知水果(用五角星标记)
plt.scatter(unknown_fruit[0], unknown_fruit[1],
color='purple', marker='*', s=200, label='未知水果') # 紫色五角星,放大显示
# 加上坐标轴标签和标题
plt.xlabel('颜色深度(0-10,数值越大越黄)')
plt.ylabel('大小(0-10,数值越大越大)')
plt.title('水果特征分布图')
plt.legend() # 显示图例
plt.show() # 展示图像
运行这段代码,你会看到:红色圆点(苹果)集中在左侧(颜色偏红),橙色圆点(橘子)集中在右侧(颜色偏黄),而紫色五角星(未知水果)刚好在橘子群附近 —— 这就是我们肉眼判断的依据!
三、k 近邻算法的核心步骤:用数学实现 "投票选举"
计算机怎么判断未知水果的类别呢?它会执行四个关键步骤,我们一步步用代码实现:
3.1 第一步:计算距离(谁离我最近?)
生活中我们靠 "感觉" 判断相似,计算机则靠 "距离" 计算。最常用的是欧氏距离(就像直尺测量两点距离):
\(distance = \sqrt{(x_1-x_2)^2 + (y_1-y_2)^2}\)
用代码实现这个计算:
def calculate_distance(known_point, unknown_point):
"""
计算两个点之间的欧氏距离
参数:
known_point:已知点的特征(如[2,3])
unknown_point:未知点的特征(如[6,5])
返回:
两点之间的距离
"""
# 计算每个特征的差值平方,再求和,最后开平方
squared_diff = (known_point[0] - unknown_point[0])**2 + (known_point[1] - unknown_point[1])** 2
distance = np.sqrt(squared_diff)
return distance
# 计算未知水果与每个已知水果的距离
distances = []
for fruit in known_fruits:
dist = calculate_distance(fruit, unknown_fruit)
distances.append(dist)
# 打印计算过程,方便理解
print(f"已知水果特征{fruit}与未知水果的距离:{dist:.2f}")
运行后会得到类似这样的结果:
已知水果特征[2 3]与未知水果的距离:4.47
已知水果特征[3 4]与未知水果的距离:3.16
已知水果特征[1 5]与未知水果的距离:5.10
已知水果特征[7 6]与未知水果的距离:1.41 # 这个最近!
已知水果特征[8 5]与未知水果的距离:2.00
已知水果特征[9 4]与未知水果的距离:3.61
3.2 第二步:找邻居(选 k 个最像的)
k 近邻算法中的 "k" 就是要选的邻居数量。比如 k=3,就是找距离最近的 3 个已知水果:
# 把距离和对应的标签组合起来,方便排序
distance_with_label = list(zip(distances, labels))
# 按距离从小到大排序
sorted_distance = sorted(distance_with_label, key=lambda x: x[0])
# 选择k=3个最近的邻居
k = 3
nearest_neighbors = sorted_distance[:k]
print(f"\n距离最近的{k}个邻居是:")
for dist, label in nearest_neighbors:
fruit_type = "橘子" if label == 1 else "苹果"
print(f"距离{dist:.2f},类别:{fruit_type}")
此时会输出:
距离最近的3个邻居是:
距离1.41,类别:橘子
距离2.00,类别:橘子
距离3.16,类别:苹果
3.3 第三步:投票表决(少数服从多数)
看看这 3 个邻居里哪种水果占多数:
# 提取邻居的标签
neighbor_labels = [label for (dist, label) in nearest_neighbors]
# 统计每个标签出现的次数
label_counts = np.bincount(neighbor_labels)
# 找到出现次数最多的标签
predicted_label = np.argmax(label_counts)
# 输出结果
if predicted_label == 1:
print("\n根据k近邻算法判断,这个未知水果是:橘子!")
else:
print("\n根据k近邻算法判断,这个未知水果是:苹果!")
最终结果会显示 "橘子",和我们的直觉判断完全一致!
四、完整代码:可直接运行的 k 近邻分类器
把上面的步骤整合起来,再加上一些优化,就得到了一个完整的 k 近邻分类器:
import numpy as np
import matplotlib.pyplot as plt
class SimpleKNN:
"""简单的k近邻分类器"""
def __init__(self, k=3):
"""
初始化分类器
参数:
k:要选择的邻居数量,默认3个
"""
self.k = k
self.known_data = None # 用于存储已知数据
self.known_labels = None # 用于存储已知标签
def fit(self, X, y):
"""
训练模型(其实就是记住已知数据)
参数:
X:已知样本的特征数据,形状为[样本数, 特征数]
y:已知样本的标签,形状为[样本数]
"""
self.known_data = X
self.known_labels = y
print(f"模型训练完成,记住了{len(X)}个样本")
def predict(self, X):
"""
预测新样本的类别
参数:
X:新样本的特征数据,形状为[特征数]
返回:
预测的标签
"""
# 计算与所有已知样本的距离
distances = []
for data in self.known_data:
# 计算欧氏距离
dist = np.sqrt(np.sum((data - X) **2))
distances.append(dist)
# 把距离和标签绑定,按距离排序
distance_with_label = list(zip(distances, self.known_labels))
sorted_distance = sorted(distance_with_label, key=lambda x: x[0])
# 取前k个邻居的标签
nearest_labels = [label for (dist, label) in sorted_distance[:self.k]]
# 少数服从多数
return np.argmax(np.bincount(nearest_labels))
# ----------------------
# 用水果数据测试我们的分类器
# ----------------------
if __name__ == "__main__":
# 已知水果特征:[颜色深度, 大小]
fruits = np.array([
[2, 3], [3, 4], [1, 5], # 苹果(标签0)
[7, 6], [8, 5], [9, 4] # 橘子(标签1)
])
labels = np.array([0, 0, 0, 1, 1, 1])
# 创建分类器,选择5个邻居(试试把k改成1或5,看结果会不会变)
knn = SimpleKNN(k=5)
# 训练模型(其实就是记住数据)
knn.fit(fruits, labels)
# 要预测的未知水果:颜色深度6,大小5
unknown_fruit = np.array([6, 5])
prediction = knn.predict(unknown_fruit)
# 输出结果
fruit_names = {0: "苹果", 1: "橘子"}
print(f"\n未知水果的特征:颜色深度{unknown_fruit[0]},大小{unknown_fruit[1]}")
print(f"预测结果:这是一个{fruit_names[prediction]}!")
# 画图展示
plt.scatter(fruits[labels==0, 0], fruits[labels==0, 1],
color='red', marker='o', label='苹果')
plt.scatter(fruits[labels==1, 0], fruits[labels==1, 1],
color='orange', marker='o', label='橘子')
plt.scatter(unknown_fruit[0], unknown_fruit[1],
color='purple', marker='*', s=200, label='未知水果')
plt.xlabel('颜色深度(0-10,越大越黄)')
plt.ylabel('大小(0-10,越大越大)')
plt.title(f'k={knn.k}的k近邻分类结果')
plt.legend()
plt.show()
五、k 近邻算法的关键知识点
5.1 如何选择最佳的 k 值?
k 值是 k 近邻算法中最重要的参数:
- k 太小:容易被噪声干扰(比如刚好有个奇怪的苹果长得像橘子)
- k 太大:会把不相关的样本也算进来(比如远在天边的苹果也参与投票)
一个简单的方法是:从 k=3 开始尝试,逐渐增大,看哪个 k 值的预测效果最好。
5.2 特征需要 "标准化"
生活中如果特征的单位不一样(比如一个特征是厘米,一个是千克),会影响距离计算。解决办法是标准化:
# 标准化特征:让每个特征的平均值为0,标准差为1
def standardize(X):
return (X - np.mean(X, axis=0)) / np.std(X, axis=0)
5.3 k 近邻的优缺点
优点:
- 简单易懂,几乎不用数学基础就能理解
- 不需要提前训练模型,拿到新数据可以直接用
- 可以处理多种类型的数据
缺点:
- 数据量大的时候,计算距离会很慢
- 对特征的数量敏感(特征太多时会 "迷路")
六、动手实践:用 scikit-learn 实现更专业的 k 近邻
真实项目中,我们会用成熟的库来实现 k 近邻。试试用 scikit-learn(Python 最流行的机器学习库)重写上面的水果分类:
# 安装scikit-learn(如果没安装的话)
# !pip install scikit-learn
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
# 数据准备(和之前一样)
fruits = np.array([
[2, 3], [3, 4], [1, 5], # 苹果
[7, 6], [8, 5], [9, 4] # 橘子
])
labels = np.array([0, 0, 0, 1, 1, 1])
# 创建k近邻分类器,k=3
knn = KNeighborsClassifier(n_neighbors=3)
# 训练模型
knn.fit(fruits, labels)
# 预测未知水果
unknown_fruit = np.array([[6, 5]]) # 注意这里要写成二维数组
prediction = knn.predict(unknown_fruit)
print("scikit-learn预测结果:", "橘子" if prediction[0]==1 else "苹果") # 输出"橘子"
是不是更简单了?这就是专业库的力量!
七、总结:一篇博客掌握 k 近邻
通过辨别水果的例子,我们学会了:
- k 近邻算法的核心思想:"看邻居投票"
- 实现步骤:计算距离→找邻居→投票表决
- 关键参数 k 的选择方法
- 如何用代码实现(从手写简单版本到专业库)
k 近邻就像机器学习世界的 "Hello World",它简单却蕴含了机器学习的基本思想 ——从数据中找规律。下一次当你在超市辨别水果时,不妨想想:"这个过程如果写成代码,应该怎么实现呢?"
现在就动手修改代码里的参数(比如 k 值、水果特征),看看会得到什么有趣的结果吧!
祝你的机器学习之旅,从这个甜甜的 "橘子分类器" 开始,越来越精彩!