PyTorch API
PyTorch API 是 PyTorch 提供的一套编程接口(Application Programming Interface),它允许开发者用 Python 或 C++ 编写深度学习程序,涵盖了从张量操作、自动求导,到构建神经网络、优化训练、加载数据等完整的机器学习/深度学习流程。
🔧 PyTorch API 包括的核心模块:
1. torch
:基础张量操作模块
-
类似 NumPy,但支持 GPU 和自动求导
-
常用函数:
torch.tensor()
,torch.arange()
,torch.mean()
,torch.matmul()
,torch.device()
等
2. torch.nn
:构建神经网络模型
-
提供了神经网络层(如
Linear
,Conv2d
,LSTM
)、激活函数(如ReLU
、Sigmoid
)等 -
用法:继承
nn.Module
构建自定义模型类
3. torch.autograd
:自动微分模块
-
自动构建计算图,支持
.backward()
自动求梯度 -
常用:
x.requires_grad=True
,y.backward()
,x.grad
4. torch.optim
:优化器模块
-
常见优化器如:
SGD
,Adam
,RMSprop
-
用法:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
5. torch.utils.data
:数据加载工具
-
用于创建自定义数据集和数据加载器(
Dataset
,DataLoader
) -
支持批量读取、打乱、并行加载等功能
6. torch.distributions
:概率分布
-
提供多种概率分布模型,用于采样、估计概率、强化学习中的策略等
7. torchvision
、torchaudio
、torchtext
(拓展库)
-
用于计算机视觉、音频处理、自然语言处理等领域,提供数据集、模型、预处理工具
🧠 PyTorch API 示例:手写一个简单神经网络
import torch
from torch import nn# 定义模型
class MLP(nn.Module):def __init__(self):super().__init__()self.hidden = nn.Linear(784, 256)self.output = nn.Linear(256, 10)def forward(self, x):x = torch.relu(self.hidden(x))return self.output(x)# 初始化模型和数据
model = MLP()
X = torch.rand((64, 784)) # batch_size=64, 输入维度784
y = model(X)
📚 官方文档
PyTorch API 文档官网(中文/英文)提供所有模块、类、函数的详细说明:
-
英文版:Page Redirection
-
中文版(可能略旧):https://pytorch.apachecn.org/