当前位置: 首页 > news >正文

在 Mac M2 上安装 PyTorch 并启用 MPS 加速的详细教程与性能对比

1. 安装torch

在官网上可以查看安装教程,Start Locally | PyTorch

作者安装了目前最新的torch版本2.5.1,需要提前安装python3.9及以上版本,作者python版本是python3.11最新版本 

使用conda安装torch,在终端进入要安装的环境,执行如下命令即可,值得一提的是,安装torch的前提条件是需要事先安装对应版本的python,以及annoconda

conda install pytorch torchvision -c pytorch

执行完如上命令后就会出现如下画面,需要等待几分钟,直到安装完毕

2. 安装MPS

使用conda安装mps

conda install torch torchvision torchaudio

3 安装是否成功测试

import torch
# 查看 torch安装是否成功 并查看其版本
print(torch.__version__)
# 查看 mps是否安装成功 是否可用
print(torch.backends.mps.is_available())
# 检查 GPU 是否可用
print(torch.cuda.is_available())  # 对于 MPS,返回 False 是正常的
print(torch.backends.mps.is_available())  # 应该返回 True
# 获取 MPS 设备
mps_device = torch.device("mps")
print(mps_device)  # 输出 "mps"

执行如上代码,能够成功打印出torch版本,证明第一章节的torch安装成功,如果能打印出True证明MPS可用,至于其中的一个False是cuda是否可用,因为作者是Mac电脑,没有安装显卡所以并无法安装cuda加速,固然为false

4 加速对比

总的来说,模型越复杂,其MPS加速越明显,如果模型太简单,只需要几秒钟就能跑完的话,MPS加速反而不如CPU,因为MPS要有一些准备工作,把数据放入图显核心里去,如果算法太简单或者数据量太少,结果运行加速节约的时间还不如数据准备的时间长,看起来就会觉得MPS反而需要更多时间来运行。

如下是作者的测试代码

import torch
import torch.nn as nn
import torch.optim as optim
import time# 设置训练参数
input_size = 4096  # 输入特征数
hidden_size = 1024  # 隐藏层神经元数
output_size = 10  # 输出类别数(例如 10 类)
num_epochs = 50  # 训练轮数
batch_size = 64  # 批量大小
learning_rate = 0.01  # 学习率# 定义一个简单的全连接神经网络
class SimpleNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return x# 函数:训练模型并记录训练时间
def train_model(device, num_epochs):# 创建数据集num_samples = 100000  # 数据集样本数量x_train = torch.randn(num_samples, input_size).to(device)y_train = torch.randint(0, output_size, (num_samples,)).to(device)# 模型、损失函数和优化器model = SimpleNN(input_size, hidden_size, output_size).to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 开始计时start_time = time.time()# 训练循环for epoch in range(num_epochs):for i in range(0, num_samples, batch_size):# 获取当前批量数据inputs = x_train[i:i+batch_size]labels = y_train[i:i+batch_size]# 前向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 结束计时end_time = time.time()# 返回训练时间return end_time - start_time# 主程序
if __name__ == "__main__":# 设备列表devices = {"CPU": torch.device("cpu"),"MPS": torch.device("mps") if torch.backends.mps.is_available() else None,}# 分别测试 CPU 和 MPSresults = {}for device_name, device in devices.items():if device is None:print(f"\nSkipping {device_name} as it is not available.")continueprint(f"\nTraining on {device_name}...")training_time = train_model(device, num_epochs)results[device_name] = training_timeprint(f"Training time on {device_name}: {training_time:.2f} seconds")# 打印对比结果print("\n--- Training Time Comparison ---")for device_name, time_taken in results.items():print(f"{device_name}: {time_taken:.2f} seconds")

本人运行的机器是Mac Mini M2(8+10)16G+1T ,

3.1 CPU和GPU占用

在使用CPU运行时, 明显看到8核心的CPU,程序几乎占用了4核心一半,GPU没有使用

 在使用MPS运行时,CPU占比下降到较低水平,开始启用GPU运行,10核心的图显也仅仅使用了1颗,感觉加速不是特别明显

3.2 温度对比 

使用CPU运行时,常年保持40度以下的CPU温度也飙升到了65度左右,及时如此也仅是window电脑静默状态的温度了 

 使用MPS运行时,温度稍有回落,在50度左右

3.3 运行时间

如图所示,MPS加速仅仅比CPU花费时间减少一半左右,说实话不是特别满意,和cuda的加速还是有一定差距

http://www.lryc.cn/news/509270.html

相关文章:

  • 生成式人工智能在生产型企业中的应用
  • Linux逻辑卷管理
  • 机器人加装电主轴【铣削、钻孔、打磨、去毛刺】更高效
  • opencv sdk for java中提示无stiching模块接口的问题
  • 今天最新早上好问候语精选大全,每天问候,相互牵挂,彼此祝福
  • 五种IO模型- 阻塞IO、非阻塞IO、多路复用IO、信号驱动IO以及异步IO
  • Vscode GStreamer插件开发环境配置
  • flask基础
  • Java日志框架:log4j、log4j2、logback
  • 鸿蒙-expandSafeArea使用
  • 【es6复习笔记】Spread 扩展运算符(8)
  • 第22天:信息收集-Web应用各语言框架安全组件联动系统数据特征人工分析识别项目
  • 后端-redis
  • 开发场景中Java 集合的最佳选择
  • golangci-lint安装与Goland集成
  • 金仓数据库安装-Kingbase v9-centos
  • 条款6:auto推导若非己愿,使用显式类型初始化惯用法
  • 蓝桥杯物联网开发板硬件组成
  • 视频汇聚融合云平台Liveweb一站式解决视频资源管理痛点
  • (aaai2025) FD2-Net: Frequency-Driven Feature Decomposition Network
  • 深度学习之目标检测——RCNN
  • 2014年IMO第3题
  • 国高材服务 | 高分子结晶动力学表征——高低温热台偏光显微镜
  • 跨站请求伪造之基本介绍
  • Hadoop集群(HDFS集群、YARN集群、MapReduce​计算框架)
  • 单元测试(UT,C++版)经验总结(gtest+gmock)
  • Mysql高级部分总结(二)
  • 纠正一下网络管理
  • homebrew,gem,cocoapod 换源,以及安装依赖
  • Java字符串的|分隔符转List实现方案