当前位置: 首页 > news >正文

【联邦学习——手动搭建简易联邦学习】

1. 目的

用于记录自己在手写联邦学习相关实验时碰到的一些问题,方便自己进行回顾。

2. 代码

2.1 本地模型计算梯度更新

# 比较训练前后的参数变化
def compare_weights(new_model, old_model):weight_updates = {}for layer_name, params in new_model.state_dict().items():weight_updates[layer_name] = params - old_model.state_dict().get(layer_name)return weight_updates

测试代码如下:
有意思的点在于我获得了update = model2-model1
但是我去计算model1+update==model2的时候发现不相等
最后思考了一下可能是在这个计算的过程中存在精度的丢失

import torch
from torch import nn
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoaderdef weight_init(m):if isinstance(m, nn.Linear):nn.init.xavier_normal_(m.weight)nn.init.constant_(m.bias, 0)# 也可以判断是否为conv2d,使用相应的初始化方式elif isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')# 是否为批归一化层elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if __name__ == '__main__':# 建立model1和model2作为训练前后的模型model1 = models.get_model("resnet18")model1.apply(weight_init)model2 = models.get_model("resnet18")model2.apply(weight_init)weight_updates = compare_weights(model2,model1)# 创建一个临时的模型状态字典,用于存储更新后的参数updated_params = model1.state_dict().copy()for layer_name, update in weight_updates.items():# 确保该层存在于model1中且形状匹配,避免错误if layer_name in updated_params and update.shape == updated_params[layer_name].shape:# 直接相加更新参数updated_params[layer_name] += updateelse:print(f"Warning: Layer {layer_name} not found or shape mismatch, skipping update.")# 将更新后的参数加载回model1model1.load_state_dict(updated_params)for layer_name,params in model1.state_dict().items():ts = params-model2.state_dict().get(layer_name)# 很重要,这里应该是我们在做的时候会有一点模型精度上的损失,所以不能够计算这里等于0if torch.sum(ts).item()>1e-6:print(f"{layer_name}更新后与原有的不匹配,差距为")else:print(f"{layer_name}更新后与原有的匹配")

2.2 客户端代码

import numpy as np
import torch.utils.data
from tqdm import tqdm'''
conf 配置文件
model 模型
train_dataset 数据集
class_ratios 从数据集中筛选出一部分 class_ratios = {0: 0.5, 1: 0.5,..., 8: 0.5, 9: 0.5}
id 客户端的标识
'''class Client(object):def __init__(self,conf,model,device,train_loader,id=1):self.client_id = id                     # 客户端IDself.conf = conf                        # 配置文件self.local_model = model                # 客户端本地模型self.train_loader = train_loader        # 训练数据的迭代器,需要训练的数据已经在里面了self.grad_update = dict()               # 本地训练完之后的梯度更新self.weight = conf['weight']            # 全局模型梯度更新时的权重self.device = device                    # 训练的设备self.local_model.to(self.device)        # 将模型放入训练设备def train(self, model):self._before_train(model)self._local_train()self._after_train(model)def _before_train(self, model):self._load_global_model(model)# 用服务器模型来覆盖本地模型def _load_global_model(self,model):for name,param in model.state_dict().items():# 客户端首先用服务器端下发的全局模型覆盖本地模型self.local_model.state_dict()[name].copy_(param.clone())def _local_train(self):# 定义最优化函数器,用于本地模型训练optimizer = torch.optim.SGD(self.local_model.parameters(),lr=self.conf['lr'],momentum=self.conf['momentum'])# 本地模型训练self.local_model.train()loss = 0for epoch in range(self.conf['local_epochs']):for batch in tqdm(self.train_loader, desc=f"Epoch {epoch + 1}/{self.conf['local_epochs']}"):data, target = batch# 放入相应的设备data = data.to(self.device)target = target.to(self.device)# 梯度清零optimizer.zero_grad()output = self.local_model(data)loss = torch.nn.functional.cross_entropy(output, target)# 反向传播loss.backward()optimizer.step()print(f"Client{self.client_id}----Epoch {epoch} done.Loss {loss}")def _after_train(self,model):self._cal_update_weights(model)def _cal_update_weights(self, old_model):weight_updates = dict()for layer_name, params in self.local_model.state_dict().items():weight_updates[layer_name] = params - old_model.state_dict().get(layer_name)# 更新梯度模型的权重self.grad_update = weight_updates

2.3 服务器代码

import torch.utils.data
import torchvision.datasets as datasets
from torchvision import models
from torchvision.transforms import transformsfrom utils.CommonUtils import copy_model_params# 服务端
class Server(object):def __init__(self, conf, eval_dataset, device):self.conf = conf# 全局老模型self.old_model = models.get_model(self.conf["model_name"])# 全局的新模型self.global_model = models.get_model(self.conf["model_name"])# 创建时保持新老模型的参数是一致的copy_model_params(self.old_model,self.global_model)# 根据客户端上传的梯度进行排列组合,用于测量贡献度的模型self.sub_model = models.get_model(self.conf["model_name"])self.eval_loader = torch.utils.data.DataLoader(eval_dataset,batch_size=self.conf["batch_size"],shuffle=True)self.accuracy_history = []  # 保存accuracy的数组self.loss_history = []  # 保存loss的数组self.device = deviceself.old_model.to(device)self.global_model.to(device)self.sub_model.to(device)# 模型重构def model_aggregate(self, clients, target_model):if target_model == self.global_model:print("++++++++全局模型更新++++++++")# 更新一下老模型参数copy_model_params(self.old_model,self.global_model)else:print("========子模型重构========")sum_weight = 0# 计算总的权重for client in clients:sum_weight += client.weight# 将old_model的模型参数赋值给sub_modelcopy_model_params(self.sub_model, self.old_model)# 初始化一个空字典来累积客户端的模型更新aggregated_updates = {}# 遍历每个客户端for client in clients:# 根据客户端的权重比例聚合更新for name, update in client.grad_update.items():if name not in aggregated_updates:aggregated_updates[name] = update * client.weight / sum_weightelse:aggregated_updates[name] += update * client.weight / sum_weight# 应用聚合后的更新到sub_modelfor name, param in target_model.state_dict().items():if name in aggregated_updates:param.copy_(param + aggregated_updates[name])  # 累加更新到当前层参数上# 定义模型评估函数def model_eval(self,target_model):target_model.eval()total_loss = 0.0correct = 0dataset_size = 0for batch_id,batch in enumerate(self.eval_loader):data,target = batchdataset_size += data.size()[0]# 放入和模型对应的设备data = data.to(self.device)target = target.to(self.device)# 模型预测output = target_model(data)# 把损失值聚合起来total_loss += torch.nn.functional.cross_entropy(output,target,reduction='sum').item()# 获取最大的对数概率的索引值pred = output.data.max(1)[1]correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()# 计算准确率acc = 100.0 * (float(correct) / float(dataset_size))# 计算损失值total_l = total_loss / dataset_size# 将accuracy和loss保存到数组中self.accuracy_history.append(acc)self.loss_history.append(total_l)if target_model == self.global_model:print(f"++++++++全局模型评估++++++++acc:{acc}  loss:{total_l}")else:print(f"========子模型评估========acc:{acc}  loss:{total_l}")return acc,total_ldef save_results_to_file(self):# 将accuracy和loss保存到文件中with open("fed_accuracy_history.txt", "w") as f:for acc in self.accuracy_history:f.write("{:.2f}\n".format(acc))with open("fed_loss_history.txt", "w") as f:for loss in self.loss_history:f.write("{:.4f}\n".format(loss))

2.4 Utils

def copy_model_params(target_model, source_model):for name, param in source_model.state_dict().items():target_model.state_dict()[name].copy_(param.clone())

3. 运行测试代码

import jsonimport torch
from torch import nn
from torch.utils.data import DataLoader, Subset
from torchvision import models
from torchvision import transforms, datasets
from client.Client import Client
from server.Server import Serverwith open("../conf/client1.json",'r') as f:conf = json.load(f)with open("../conf/server1.json",'r') as f:serverConf = json.load(f)transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.CIFAR10(root='../data/', train=True, download=True,transform=transform)
eval_dataset = datasets.CIFAR10(root='../data/', train=False, download=True,transform=transform)# train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32, num_workers=2)# 计算数据集长度
total_samples = len(train_dataset)
# 确保可以平均分配,否则需要调整逻辑以处理余数
assert total_samples % 2 == 0, "数据集样本数需为偶数以便完全平分"# 分割点
split_point = total_samples // 2# 创建两个子集
train_dataset_first_half = Subset(train_dataset, range(0, split_point))
train_dataset_second_half = Subset(train_dataset, range(split_point, total_samples))# 然后为每个子集创建DataLoader
batch_size = 32train_loader_first_half = DataLoader(train_dataset_first_half, shuffle=True, batch_size=batch_size, num_workers=2)
train_loader_second_half = DataLoader(train_dataset_second_half, shuffle=True, batch_size=batch_size, num_workers=2)# 检查CUDA是否可用
if torch.cuda.is_available():device = torch.device("cuda")  # 如果CUDA可用,选择GPU
else:device = torch.device("cpu")   # 如果CUDA不可用,选择CPUlocal_model = models.get_model("resnet18")
local_model2 = models.get_model("resnet18")server = Server(serverConf,eval_dataset,device)
client1 = Client(conf, local_model, device,train_loader_first_half, 1)
client2 = Client(conf, local_model2, device,train_loader_second_half, 2)
for i in range(2):client1.train(server.global_model)client2.train(server.global_model)server.model_aggregate([client1,client2], server.global_model)server.model_eval(server.global_model)server.model_aggregate([client1], server.sub_model)server.model_eval(server.sub_model)server.model_aggregate([client2], server.sub_model)server.model_eval(server.sub_model)
http://www.lryc.cn/news/348484.html

相关文章:

  • Springboot项目如何创建单元测试
  • Win10 如何同时保留两个CUDA版本并自由切换使用
  • 实验室纳新宣讲会(java后端)
  • class常量池、运行时常量池和字符串常量池的关系
  • Java | Leetcode Java题解之第88题合并两个有序数组
  • 韵搜坊(全栈)-- 前后端初始化
  • Android:资源的管理,Glide图片加载框架的使用
  • conll-2012-formatted-ontonotes-5.0中文数据格式说明
  • SpringBoot集成Seata分布式事务OpenFeign远程调用
  • 视觉检测系统,是否所有产品都可以进行视觉检测?
  • 通过金山和微软虚拟打印机转换PDF文件,流程方法及优劣对比
  • 采用java+B/S开发的全套医院绩效考核系统源码springboot+mybaits 医院绩效考核系统优势
  • 驱动开发-用户空间和内核空间数据传输
  • 【408精华知识】速看!各种排序的大总结!
  • 【STM32 |程序实例】按键控制、光敏传感器控制蜂鸣器
  • Spring boot使用websocket实现在线聊天
  • 品牌设计理念和logo设计方法
  • Python | Leetcode Python题解之第88题合并两个有序数组
  • vscode新版本remotessh服务端报`GLIBC_2.28‘ not found解决方案
  • 盘他系列——oj!!!
  • 洛谷 P2657 [SCOI2009] windy 数 题解 数位dp
  • Python爬虫入门:网络世界的宝藏猎人
  • 【NodeMCU实时天气时钟温湿度项目 6】解析天气信息JSON数据并显示在 TFT 屏幕上(心知天气版)
  • 重构四要素:目的、对象、时机和方法
  • 基于Echarts的大数据可视化模板:服务器运营监控
  • Python3 笔记:Python的常量
  • 【Linux】自动化构建工具make/Makefile和git介绍
  • C语言—关于字符串(编程实现部分函数功能)
  • picoCTF-Web Exploitation-Trickster
  • SSH 免密登录,设置好仍然需要密码登录解决方法