知识蒸馏 Jensen-Shannon散度
知识蒸馏 Jensen-Shannon散度
flyfish
Jensen-Shannon散度(简称JSD)是衡量两个概率分布“差异程度”的工具,它最大的特点是对称的——也就是说,衡量P和Q的差异,与衡量Q和P的差异,结果完全一样。
先理解什么是“概率分布”?
在说JSD之前,先明确“概率分布”的含义。
掷一个公平骰子,每个面(1-6)出现的概率都是1/6,这就是一个概率分布P:P = [1/6, 1/6, 1/6, 1/6, 1/6, 1/6]
。
如果骰子被做了手脚,掷出1的概率是1/2,其他面各1/10,这是另一个分布Q:Q = [1/2, 1/10, 1/10, 1/10, 1/10, 1/10]
。
常需要比较两个分布(比如上面的P和Q)有多“像”或多“不像”,JSD就是干这个的。
概率分布(Probability Distribution)是对一个随机变量的所有可能结果及其对应概率的完整描述。
它不是单个事件的概率(比如“掷骰子出1点的概率是1/6”),而是把所有可能发生的事件,以及它们各自的概率,系统地列出来。
特点(以离散概率分布为例):
- 包含所有可能结果:比如掷六面骰子,结果只能是1、2、3、4、5、6,没有其他可能。
- 每个结果的概率≥0:概率不能是负数(不可能发生的事件概率为0)。
- 所有概率的总和=1:因为“所有可能结果中必然有一个发生”,整体概率是100%。
P=[1/6,1/6,1/6,1/6,1/6,1/6]P = [1/6, 1/6, 1/6, 1/6, 1/6, 1/6]P=[1/6,1/6,1/6,1/6,1/6,1/6]就是一个典型的离散概率分布:
- 对应随机变量“掷骰子的结果”;
- 结果有6个(1到6);
- 每个结果的概率都是1/6;
- 所有概率相加:1/6+1/6+...+1/6=11/6 + 1/6 + ... + 1/6 = 11/6+1/6+...+1/6=1,符合概率分布的要求。
抛一枚不均匀的硬币,正面(H)概率0.6,反面(T)概率0.4,它的概率分布可以表示为:
P={H:0.6,T:0.4}P = \{ H: 0.6, T: 0.4 \}P={H:0.6,T:0.4} (同样满足“所有概率非负、总和为1”)
连续概率分布
如果随机变量的结果是连续的(比如“人的身高”“温度”),概率分布会用概率密度函数(PDF) 描述,核心思想类似:函数在某一区间的积分表示该区间内事件发生的概率,且整个定义域上的积分等于1。
概率是“单个事件的可能性”,而概率分布是“所有可能事件的概率的整体清单”,它完整刻画了随机变量的行为。
为什么需要JSD?先看它的“前辈”KL散度
在JSD之前,常用KL散度(Kullback-Leibler散度)衡量分布差异,但KL散度有个明显缺点:不对称。
1. KL散度的定义(简单)
对于两个离散概率分布P和Q(比如上面的骰子分布),KL散度的公式是:
KL(P∣∣Q)=∑iP(i)⋅log(P(i)Q(i))\text{KL}(P||Q) = \sum_{i} P(i) \cdot \log\left( \frac{P(i)}{Q(i)} \right)KL(P∣∣Q)=i∑P(i)⋅log(Q(i)P(i))
- 遍历每个可能的结果(比如骰子的每个面i);
- 对每个i,计算P的概率乘以“P的概率除以Q的概率”的对数;
- 把这些值加起来,就是P相对于Q的KL散度。
2. KL散度的问题:不对称
比如用上面的P(公平骰子)和Q(作弊骰子)计算:
- KL(P||Q) 是“以P为基准,看Q有多不同”;
- KL(Q||P) 是“以Q为基准,看P有多不同”。
这两个结果不一样!比如Q中1的概率很高(1/2),而P中1的概率低(1/6),所以KL(Q||P)会比KL(P||Q)大很多。
但很多时候,需要“无偏”地比较两个分布(比如“P和Q的差异”,而不是“以谁为基准的差异”),这时候就需要JSD了。
Jensen-Shannon散度(JSD):对称的差异度量
JSD的核心思路是:通过一个“中间分布”,把KL散度的不对称性修正为对称。
1. 步骤1:定义“中间分布M”
先找一个P和Q的“平均分布”M,公式很简单:
M=12(P+Q)M = \frac{1}{2}(P + Q)M=21(P+Q)
M中每个结果的概率,是P和Q对应概率的平均值。
比如上面的P和Q:
- P中1的概率是1/6,Q中1的概率是1/2,所以M中1的概率是 (1/6 + 1/2)/2 = (2/3)/2 = 1/3;
- 其他面(比如2):P中是1/6,Q中是1/10,所以M中是 (1/6 + 1/10)/2 = (4/15)/2 = 2/15。
2. 步骤2:计算两个KL散度
分别计算P到M的KL散度,和Q到M的KL散度:
- KL(P||M):P相对于M的差异;
- KL(Q||M):Q相对于M的差异。
3. 步骤3:取平均,得到JSD
JSD的公式就是这两个KL散度的平均值:
JSD(P∣∣Q)=12[KL(P∣∣M)+KL(Q∣∣M)]\text{JSD}(P||Q) = \frac{1}{2} \left[ \text{KL}(P||M) + \text{KL}(Q||M) \right]JSD(P∣∣Q)=21[KL(P∣∣M)+KL(Q∣∣M)]
因为M是P和Q的平均(交换P和Q,M不变),所以:
JSD(P∣∣Q)=JSD(Q∣∣P)\text{JSD}(P||Q) = \text{JSD}(Q||P)JSD(P∣∣Q)=JSD(Q∣∣P)
这就是JSD的核心优势——它衡量的是“P和Q之间的客观差异”,不依赖谁是“基准”。
例子:抛硬币的概率分布差异
- 分布 PPP:硬币正面概率0.8,反面概率0.2(即 P=[0.8,0.2]P = [0.8, 0.2]P=[0.8,0.2],对应“正面”“反面”两个结果);
- 分布 QQQ:硬币正面概率0.2,反面概率0.8(即 Q=[0.2,0.8]Q = [0.2, 0.8]Q=[0.2,0.8])。
步骤1:计算中间分布 MMM
中间分布 MMM是 PPP和 QQQ的平均,公式为:
M=12(P+Q) M = \frac{1}{2}(P + Q) M=21(P+Q)
代入具体值:
正面概率:12×(0.8+0.2)=0.5\frac{1}{2} \times (0.8 + 0.2) = 0.521×(0.8+0.2)=0.5;
反面概率:12×(0.2+0.8)=0.5\frac{1}{2} \times (0.2 + 0.8) = 0.521×(0.2+0.8)=0.5;
因此,M=[0.5,0.5]M = [0.5, 0.5]M=[0.5,0.5]。
步骤2:计算KL散度 KL(P∥M)\text{KL}(P\|M)KL(P∥M)和 KL(Q∥M)\text{KL}(Q\|M)KL(Q∥M)
KL散度的公式(以2为底的对数,单位为“比特”)为:
KL(A∥B)=∑iA(i)⋅log2(A(i)B(i)) \text{KL}(A\|B) = \sum_{i} A(i) \cdot \log_2\left( \frac{A(i)}{B(i)} \right) KL(A∥B)=i∑A(i)⋅log2(B(i)A(i))
其中 A(i)A(i)A(i)和 B(i)B(i)B(i)分别表示分布 AAA和 BBB中第 iii个结果的概率。
计算 KL(P∥M)\text{KL}(P\|M)KL(P∥M):
- 正面(第1个结果):0.8×log2(0.80.5)=0.8×log2(1.6)≈0.8×0.678≈0.5420.8 \times \log_2\left( \frac{0.8}{0.5} \right) = 0.8 \times \log_2(1.6) \approx 0.8 \times 0.678 \approx 0.5420.8×log2(0.50.8)=0.8×log2(1.6)≈0.8×0.678≈0.542;
- 反面(第2个结果):0.2×log2(0.20.5)=0.2×log2(0.4)≈0.2×(−1.322)≈−0.2640.2 \times \log_2\left( \frac{0.2}{0.5} \right) = 0.2 \times \log_2(0.4) \approx 0.2 \times (-1.322) \approx -0.2640.2×log2(0.50.2)=0.2×log2(0.4)≈0.2×(−1.322)≈−0.264;
总和:KL(P∥M)≈0.542−0.264=0.278\text{KL}(P\|M) \approx 0.542 - 0.264 = 0.278KL(P∥M)≈0.542−0.264=0.278。
计算 KL(Q∥M)\text{KL}(Q\|M)KL(Q∥M):
- 正面(第1个结果):0.2×log2(0.20.5)=0.2×log2(0.4)≈0.2×(−1.322)≈−0.2640.2 \times \log_2\left( \frac{0.2}{0.5} \right) = 0.2 \times \log_2(0.4) \approx 0.2 \times (-1.322) \approx -0.2640.2×log2(0.50.2)=0.2×log2(0.4)≈0.2×(−1.322)≈−0.264;
- 反面(第2个结果):0.8×log2(0.80.5)=0.8×log2(1.6)≈0.8×0.678≈0.5420.8 \times \log_2\left( \frac{0.8}{0.5} \right) = 0.8 \times \log_2(1.6) \approx 0.8 \times 0.678 \approx 0.5420.8×log2(0.50.8)=0.8×log2(1.6)≈0.8×0.678≈0.542;
总和:KL(Q∥M)≈−0.264+0.542=0.278\text{KL}(Q\|M) \approx -0.264 + 0.542 = 0.278KL(Q∥M)≈−0.264+0.542=0.278。
(由于 PPP和 QQQ对称,KL(P∥M)=KL(Q∥M)\text{KL}(P\|M) = \text{KL}(Q\|M)KL(P∥M)=KL(Q∥M))
步骤3:计算Jensen-Shannon散度(JSD)
JSD的公式为:
JSD(P∥Q)=12[KL(P∥M)+KL(Q∥M)] \text{JSD}(P\|Q) = \frac{1}{2} \left[ \text{KL}(P\|M) + \text{KL}(Q\|M) \right] JSD(P∥Q)=21[KL(P∥M)+KL(Q∥M)]
代入结果:
JSD(P∥Q)≈12×(0.278+0.278)=0.278 \text{JSD}(P\|Q) \approx \frac{1}{2} \times (0.278 + 0.278) = 0.278 JSD(P∥Q)≈21×(0.278+0.278)=0.278
JSD的取值特点
- 当 PPP和 QQQ完全相同时(如 P=Q=[0.5,0.5]P = Q = [0.5, 0.5]P=Q=[0.5,0.5]),JSD(P∥Q)=0\text{JSD}(P\|Q) = 0JSD(P∥Q)=0(两个分布无差异);
- 当 PPP和 QQQ完全对立时(如 P=[1,0]P = [1, 0]P=[1,0]表示“必然正面”,Q=[0,1]Q = [0, 1]Q=[0,1]表示“必然反面”):
- 中间分布 M=[0.5,0.5]M = [0.5, 0.5]M=[0.5,0.5];
- KL(P∥M)=1×log2(10.5)=1\text{KL}(P\|M) = 1 \times \log_2\left( \frac{1}{0.5} \right) = 1KL(P∥M)=1×log2(0.51)=1,KL(Q∥M)=1×log2(10.5)=1\text{KL}(Q\|M) = 1 \times \log_2\left( \frac{1}{0.5} \right) = 1KL(Q∥M)=1×log2(0.51)=1;
- 此时 JSD(P∥Q)=12×(1+1)=1\text{JSD}(P\|Q) = \frac{1}{2} \times (1 + 1) = 1JSD(P∥Q)=21×(1+1)=1,为以2为底时的最大值。
Jensen-Shannon散度(JSD)是一种对称的分布差异度量,通过“中间分布M”把KL散度的不对称性修正。原理是通过“对称化”的方式,基于KL散度(Kullback-Leibler散度)来度量两个概率分布之间的差异,解决了KL散度不对称、取值范围无界的问题,更适合作为“分布距离”的度量。
拆解:
-
先解决KL散度的缺陷
KL散度(KL(P∣∣Q)KL(P||Q)KL(P∣∣Q))用于度量“用分布Q近似分布P”时的信息损失,但它有两个明显缺陷:- 不对称:KL(P∣∣Q)≠KL(Q∣∣P)KL(P||Q) \neq KL(Q||P)KL(P∣∣Q)=KL(Q∣∣P);
- 取值无界:可能为0到+∞+\infty+∞,不便于直接比较。
JSD的目标是改进这一点,让“分布差异”的度量更合理。
-
引入“中间分布”实现对称化
为了让度量对称,JSD先构造一个“中间分布”MMM,定义为两个分布PPP和QQQ的平均:
M=12(P+Q) M = \frac{1}{2}(P + Q) M=21(P+Q)
(即MMM中每个事件的概率是PPP和QQQ对应概率的平均值) -
用KL散度度量“两个分布到中间分布的距离”
计算PPP到MMM的KL散度(KL(P∣∣M)KL(P||M)KL(P∣∣M))和QQQ到MMM的KL散度(KL(Q∣∣M)KL(Q||M)KL(Q∣∣M)),这两个值分别代表PPP和QQQ与中间分布MMM的差异。 -
取平均得到对称的“总差异”
JSD定义为这两个KL散度的平均值:
JSD(P∣∣Q)=12[KL(P∣∣M)+KL(Q∣∣M)] JSD(P||Q) = \frac{1}{2}\left[ KL(P||M) + KL(Q||M) \right] JSD(P∣∣Q)=21[KL(P∣∣M)+KL(Q∣∣M)]由于MMM是PPP和QQQ的平均,KL(P∣∣M)KL(P||M)KL(P∣∣M)和KL(Q∣∣M)KL(Q||M)KL(Q∣∣M)的平均自然满足对称性:JSD(P∣∣Q)=JSD(Q∣∣P)JSD(P||Q) = JSD(Q||P)JSD(P∣∣Q)=JSD(Q∣∣P)。
-
性质决定其合理性
- 对称性:JSD(P∣∣Q)=JSD(Q∣∣P)JSD(P||Q) = JSD(Q||P)JSD(P∣∣Q)=JSD(Q∣∣P),符合“距离”的直观认知;
- 取值有界:当用以2为底的对数(单位为比特)时,0≤JSD≤10 \leq JSD \leq 10≤JSD≤1,便于量化比较(0表示分布完全相同,1表示分布完全对立);
- 更稳健:即使PPP或QQQ中存在概率为0的事件(如“不可能事件”),JSD也能稳定计算(而KL散度可能出现无穷大)。
“抛硬币”例子理解:
- 分布PPP(正面0.8,反面0.2)和QQQ(正面0.2,反面0.8);
- 中间分布MMM是两者的平均:(0.5,0.5)(相当于“公平硬币”);
- KL(P∣∣M)KL(P||M)KL(P∣∣M)度量“非公平硬币PPP与公平硬币MMM的差异”,KL(Q∣∣M)KL(Q||M)KL(Q∣∣M)度量“非公平硬币QQQ与公平硬币MMM的差异”;
- JSD是这两个差异的平均,最终结果≈0.278\approx 0.278≈0.278,表示PPP和QQQ的差异程度(越接近0越相似,越接近1越不同)。