当前位置: 首页 > news >正文

pytorch中一些最基本函数和类

1.Tensor操作

Tensor是PyTorch中最基本的数据结构,类似于NumPy的数组,但可以在GPU上运行加速计算。

  示例:创建和操作Tensor

import torch# 创建一个零填充的Tensor
x = torch.zeros(3, 3)
print(x)# 加法操作
y = torch.ones(3, 3)
z = x + y
print(z)# 在GPU上创建Tensor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.zeros(3, 3, device=device)
print(x)
运行结果:

2. nn.Module和自定义模型

  nn.Module是PyTorch中定义神经网络模型的基类,所有的自定义模型都应该继承自它。

示例:定义一个简单的全连接神经网络模型

import torch
import torch.nn as nn# 自定义模型类
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc = nn.Linear(10, 5)  # 线性层:输入维度为10,输出维度为5def forward(self, x):x = self.fc(x)return x# 创建模型实例
model = SimpleNet()
print(model)
运行结果:

3. DataLoader和Dataset

 DataLoader用于批量加载数据Dataset定义了数据集的接口,自定义数据集需继承自它。

示例:加载自定义数据集

import torch
from torch.utils.data import Dataset, DataLoader# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, data, targets):self.data = dataself.targets = targetsdef __len__(self):return len(self.data)def __getitem__(self, index):x = self.data[index]y = self.targets[index]return x, y# 假设有一些数据和标签
data = torch.randn(100, 10)  # 100个样本,每个样本10维
targets = torch.randint(0, 2, (100,))  # 100个随机标签,0或1# 创建数据集实例
dataset = CustomDataset(data, targets)# 创建数据加载器
batch_size = 10
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 打印一个batch的数据
for batch in dataloader:inputs, labels = batchprint(inputs.shape, labels.shape)break
运行结果: 

4. 优化器和损失函数

   优化器用于更新模型参数以减少损失,损失函数用于计算预测值与实际值之间的差异。

示例:使用优化器和损失函数

import torch
import torch.nn as nn
import torch.optim as optim# 定义模型(假设已定义好)
model = SimpleNet()# 定义损失函数
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)# 前向传播、损失计算、反向传播和优化过程请参考前面完整示例的训练循环部分。
运行结果: 

5. nn.functional中的函数

  nn.functional提供了各种用于构建神经网络的函数,如激活函数池化操作等。

示例:使用ReLU激活函数

import torch
import torch.nn.functional as F# 创建一个Tensor
x = torch.randn(3, 3)# 使用ReLU激活函数
output = F.relu(x)
print(output)
运行结果: 

http://www.lryc.cn/news/401618.html

相关文章:

  • 排序——归并排序及排序章节总结
  • python的readline()和readlines()
  • 【ARM】使用JasperGold和Cadence IFV科普
  • 深入探讨极限编程(XP):技术实践与频繁发布的艺术
  • 【代码随想录_Day30】1049. 最后一块石头的重量 II 494. 目标和 474.一和零
  • 【时时三省】tessy 集成测试:小白入门指导手册
  • 通过vagrant与VirtualBox 创建虚拟机
  • 第13章 更多的结构化命令《Linux命令行与Shell脚本编程大全笔记》
  • 【计算机网络】学习指南及导论
  • 安装mitmproxy失败
  • 安装adb和常用命令
  • C++ 几何计算库
  • 云动态摘要 2024-07-16
  • 数仓工具—Hive基础之临时表及示例
  • 机体坐标系和导航坐标系
  • 软件测试——web单功能测试
  • django-ckeditor富文本编辑器
  • 鸿蒙模拟器(HarmonyOS Emulator)Beta申请审核流程
  • VUE:跨域配置代理服务器
  • Redis实战—附近商铺、用户签到、UV统计
  • 小程序里面使用vant ui中的vant-field组件,如何使得输入框自动获取焦点
  • Html_Css问答集(12)
  • 【C语言】条件运算符详解 - 《 A ? B : C 》
  • 乘积量化pq:将高维向量压缩 97%
  • 解决一下git clone失败的问题
  • 【 香橙派 AIpro评测】烧系统运行部署LLMS大模型跑开源yolov5物体检测并体验Jupyter Lab AI 应用样例(新手入门)
  • Azure Repos 仓库管理
  • Day71 代码随想录打卡|回溯算法篇---全排列
  • 开源科学工程技术软件
  • 甄选范文“论软件维护方法及其应用”软考高级论文,系统架构设计师论文