当前位置: 首页 > news >正文

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo

使用 PyTorch 和 Horovod 来编写一个简单的分布式训练 demo,可以帮助你理解如何在多GPU或多节点环境中高效地训练深度学习模型。Horovod 是 Uber 开发的一个用于分布式训练的框架,它支持 TensorFlow、Keras、PyTorch 等多个机器学习库。下面是一个基于 PyTorch 的简单例子,演示了如何用 Horovod 进行分布式训练。

安装依赖

首先确保你已经安装了 PyTorch 和 Horovod。你可以通过 pip 或者 conda 来安装这些包。对于 Horovod,推荐使用 MPI(Message Passing Interface)进行通信,因此你也需要安装 MPI 和相应的开发工具。

pip install torch torchvision horovod

或者如果你使用的是 Anaconda:

conda install pytorch torchvision -c pytorch
horovodrun --check
# 如果没有安装 horovod, 可以使用以下命令安装:
pip install horovod[pytorch]

编写 PyTorch + Horovod 代码

创建一个新的 Python 文件 train.py,然后添加如下代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import horovod.torch as hvd
from torchvision import datasets, transforms# 初始化 Horovod
hvd.init()# 设置随机种子,确保结果可复现
torch.manual_seed(42)# 如果有 GPU 可用,则使用 GPU
if torch.cuda.is_available():torch.cuda.set_device(hvd.local_rank())# 加载数据集 (以 MNIST 数据集为例)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST('.', train=False, transform=transform)# 分布式采样器
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=1000, sampler=val_sampler)# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = nn.functional.relu(x)x = self.conv2(x)x = nn.functional.relu(x)x = nn.functional.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = nn.functional.relu(x)x = self.dropout2(x)x = self.fc2(x)output = nn.functional.log_softmax(x, dim=1)return outputmodel = Net()# 如果有 GPU 可用,则将模型转移到 GPU 上
if torch.cuda.is_available():model.cuda()# 定义损失函数和优化器,并应用 Horovod 的 DistributedOptimizer 包装
optimizer = optim.Adam(model.parameters(), lr=0.001 * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer,named_parameters=model.named_parameters(),op=hvd.Average)# 损失函数
criterion = nn.CrossEntropyLoss()# 训练模型
def train(epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):if torch.cuda.is_available():data, target = data.cuda(), target.cuda()optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_sampler),100. * batch_idx / len(train_loader), loss.item()))# 验证模型
def validate():model.eval()validation_loss = 0correct = 0with torch.no_grad():for data, target in val_loader:if torch.cuda.is_available():data, target = data.cuda(), target.cuda()output = model(data)validation_loss += criterion(output, target).item() # sum up batch losspred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probabilitycorrect += pred.eq(target.view_as(pred)).sum().item()validation_loss /= len(val_loader.dataset)accuracy = 100. * correct / len(val_loader.dataset)print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(validation_loss, correct, len(val_loader.dataset), accuracy))# 调用训练和验证函数
for epoch in range(1, 4):train_sampler.set_epoch(epoch)train(epoch)validate()# 广播模型状态从rank 0到其他进程
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

运行代码

要运行这段代码,你需要使用 horovodrun 命令来启动多个进程。例如,在单个节点上的 4 个 GPU 上运行该脚本,可以这样做:

horovodrun -np 4 -H localhost:4 python train.py

这将在本地主机上启动四个进程,每个进程都会占用一个 GPU。如果你是在多台机器上运行,你需要指定每台机器的地址和可用的 GPU 数量。

请注意,这个例子是一个非常基础的实现,实际应用中可能还需要考虑更多的细节,比如更复杂的模型结构、数据预处理、超参数调整等。

http://www.lryc.cn/news/499305.html

相关文章:

  • SQL复杂查询功能介绍及示例
  • shell基础用法
  • C#设计模式--策略模式(Strategy Pattern)
  • 【opencv入门教程】15. 访问像素的十四种方式
  • 【MySQL调优】如何进行MySQL调优?从参数、数据建模、索引、SQL语句等方向,三万字详细解读MySQL的性能优化方案(2024版)
  • 根据html的段落长度设置QtextBrowser的显示内容,最少显示一个段落
  • 基于Huffman编码的GPS定位数据无损压缩算法
  • php:完整部署Grid++Report到php项目,并实现模板打印
  • C标签和 EL表达式的在前端界面的应用
  • Linux絮絮叨(四) 系统目录结构
  • Java基于SpringBoot的网上订餐系统,附源码
  • 《Java核心技术I》死锁
  • 【Windows11系统局域网共享文件数据】
  • MCU、ARM体系结构,单片机基础,单片机操作
  • 在办公室环境中用HMD替代传统显示器的优势
  • ssm 多数据源 注解版本
  • selenium常见接口函数使用
  • STM32F103单片机使用STM32CubeMX新建IAR工程步骤
  • 刷题重开:找出字符串中第一个匹配项的下标——解题思路记录
  • product/admin/list?page=0size=10field=jancodevalue=4562249292272
  • 人工智能机器学习无监督学习概念及应用详解
  • APM装机教程(五):测绘无人船
  • 微信小程序 运行出错 弹出提示框(获取token失败,请重试 或者 请求失败)
  • IDEA的service窗口中启动类是灰色且容易消失
  • R中利用ggplot2绘制气泡图
  • CID引流电商
  • 在google cloud虚拟机上配置anaconda虚拟环境简单教程
  • windows下用vs搭配clang一起生成抽象语法树
  • 输入法:点三下输入一个汉字
  • URL访问网址的全过程