《PEFLL: Personalized Federated Learning by Learning to Learn》——论文阅读
研究背景
个性化联邦学习(pFL)试图为每个客户端训练专属模型,但现有方法普遍存在以下问题:
新客户端需要本地微调或训练,延迟高、计算重;
对低数据客户端不友好,容易过拟合;
通信开销大,客户端需多次与服务器交互;
扩展性差,如客户端数量巨大时难以管理。
核心思想:
学习一个“学习算法”,即通过两个神经网络协同工作:
- 嵌入网络(Embedding Network):
将客户端的数据映射为一个低维向量,捕捉该客户端的数据分布特征。以客户端本地数据样本进行输入,转化为固定维度的向量进行输出。
如果两个客户端的数据分布相似,它们的向量也会很接近,从而让超网络为它们生成相似的模型。
- 超网络(Hypernetwork):
以嵌入向量作为输入,一次性输出该客户端的完整个性化模型参数(无需再训练)。
客户端数据 → 嵌入网络 → 向量
向量→ 超网络 → 个性化模型参数 θ
客户端直接使用 θ,无需训练
【嵌入网络负责“看懂”客户端的数据分布,超网络负责“定制”出专属于这个客户端的模型。】
步骤:
预测阶段:
服务器每轮随机选一批客户端,把当前嵌入网络参数广播给这几个客户端。
客户端用本地数据计算嵌入向量,回传服务器;
服务器用超网络为每个客户端生成模型,再把模型传给客户端;
客户端用本地数据训练模型几步(知道之后要往哪个方向进行调节);
客户端将梯度传给服务器;
服务器使用链式法则反向传播 更新嵌入网络和超网络参数。
服务器将更新的个性化模型参数分给所抽选的客户端
用所有客户端的数据来“教会”服务器端的超网络,让它以后仅凭任何客户端上传的嵌入向量,就能立刻吐出专属模型参数。
推理阶段:
新客户端用本地数据计算嵌入向量(一次前向);
服务器用超网络生成个性化模型,并传给客户端;
客户端直接使用该模型,无需训练。
一个传来传去 比较绕的个性化联邦模型