当前位置: 首页 > news >正文

PyTorch中flatten()函数详解以及与view()和 reshape()的对比和实战代码示例

在 PyTorch 中,flatten() 函数常用于将张量(tensor)展平成一维或多维结构,尤其在构建神经网络(如 CNN)时,从卷积层输出进入全连接层前经常使用它。


一、基本语法

torch.flatten(input, start_dim=0, end_dim=-1)

参数说明:

参数说明
input输入张量
start_dim开始展平的维度(包含该维)
end_dim结束展平的维度(包含该维)

展平操作会把 start_dimend_dim 之间的维度合并成一维。


二、常见示例

示例 1:基本使用

import torchx = torch.tensor([[[1, 2],[3, 4]],[[5, 6],[7, 8]]])  # shape = (2, 2, 2)out = torch.flatten(x)
print(out)
print(out.shape)  # torch.Size([8])

等价于 x.view(-1),即将所有维度展平成一维。


示例 2:保留前维度(常见于 CNN)

x = torch.randn(10, 3, 32, 32)  # 10张图片,3通道,32x32大小
out = torch.flatten(x, start_dim=1)print(out.shape)  # torch.Size([10, 3072])

解释:

  • 展平从第 1 维开始(channel, height, width)→ 展平成一个维度
  • 第 0 维(batch size)保留,适合连接到 nn.Linear

示例 3:多维展开(指定 end_dim)

x = torch.randn(2, 3, 4, 5)  # shape = (2, 3, 4, 5)
out = torch.flatten(x, start_dim=1, end_dim=2)print(out.shape)  # torch.Size([2, 12, 5]) -> (3*4 = 12)

三、与 .view() 的区别

函数说明
view()更底层、需要张量是连续的,手动指定形状
flatten()更高层、更安全、自动处理维度合并,常用于模型构建中

四、常见用法:在模型中使用

1、示例1

import torch.nn as nnclass MyCNN(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.pool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(16, 10)def forward(self, x):x = self.conv(x)x = self.pool(x)              # shape: (N, 16, 1, 1)x = torch.flatten(x, 1)       # shape: (N, 16)x = self.fc(x)return x

2、示例2

下面使用了 torch.flatten() 将卷积层的输出展平,并连接到全连接层。这个结构常见于 CNN 图像分类模型。


使用 flatten() 的 CNN 训练流程(以 CIFAR-10 为例)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# ==== 1. 定义 CNN 模型,使用 flatten() ====
class FlattenCNN(nn.Module):def __init__(self):super(FlattenCNN, self).__init__()self.conv = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),  # 输入: [B, 3, 32, 32]nn.ReLU(),nn.MaxPool2d(2),                # 输出: [B, 16, 16, 16]nn.Conv2d(16, 32, 3, padding=1),nn.ReLU(),nn.MaxPool2d(2)                 # 输出: [B, 32, 8, 8])self.fc = nn.Sequential(nn.Linear(32 * 8 * 8, 128),nn.ReLU(),nn.Linear(128, 10)              # CIFAR-10 共 10 类)def forward(self, x):x = self.conv(x)x = torch.flatten(x, 1)  # 👈 仅展平通道和空间维度,保留 batchx = self.fc(x)return x# ==== 2. 准备数据 ====
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# ==== 3. 模型训练设置 ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FlattenCNN().to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)# ==== 4. 训练过程 ====
def train(model, loader, epochs):model.train()for epoch in range(epochs):total_loss = 0.0for images, labels in loader:images, labels = images.to(device), labels.to(device)outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()avg_loss = total_loss / len(loader)print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")# ==== 5. 开始训练 ====
train(model, train_loader, epochs=5)

重点说明

使用 torch.flatten(x, 1) 的原因:

  • 只展平通道、高、宽三维(保留 batch size)
  • 替代 x.view(x.size(0), -1) 更安全,避免非连续张量报错
  • 推荐在模型中构建更加模块化、清晰

五、三种张量展平方式:flatten()view()reshape() 的对比

下面从功能差异使用限制和**性能对比(benchmark)**进行三者的比较。


1、三者功能对比

函数特点说明
flatten()高级 API,自动处理维度合并,不要求张量连续。推荐模型中使用。
view()底层操作,速度快,但要求张量是连续(tensor.is_contiguous()True
reshape()更灵活,如果张量不连续,会自动复制为连续版本。性能略慢但更安全

2、代码功能对比

x = torch.randn(32, 3, 64, 64)  # batch of images# flatten
f1 = torch.flatten(x, 1)# view
f2 = x.view(32, -1)# reshape
f3 = x.reshape(32, -1)print(f1.shape, f2.shape, f3.shape)

输出一致:torch.Size([32, 12288])


3、非连续张量对比(view 会报错)

x = torch.randn(2, 3, 4)
y = x.permute(0, 2, 1)  # 非连续张量try:y.view(-1)  # 会报错
except RuntimeError as e:print("view error:", e)print("reshape:", y.reshape(-1).shape)   # reshape 正常
print("flatten:", torch.flatten(y).shape)  # flatten 正常

4、性能测试(benchmark)

import torch
import timex = torch.randn(1024, 512, 28, 28)# 保证是连续的
x_contig = x.contiguous()N = 1000def benchmark(op, name):torch.cuda.synchronize()start = time.time()for _ in range(N):_ = op(x_contig)torch.cuda.synchronize()end = time.time()print(f"{name}: {(end - start)*1000:.2f} ms")benchmark(lambda x: torch.flatten(x, 1), "flatten()")
benchmark(lambda x: x.view(x.size(0), -1), "view()")
benchmark(lambda x: x.reshape(x.size(0), -1), "reshape()")

示例结果(A100 GPU):

flatten(): 58.12 ms
view():    41.76 ms
reshape(): 47.32 ms

总结view()最快,但要求张量连续;flatten()最安全但稍慢;reshape()是折中方案。


5、 建议总结

场景推荐方式原因
模型中展平 CNN 输出flatten()简洁、安全,尤其在复杂网络中
确保连续张量、追求速度view()性能最佳
张量可能非连续reshape()自动处理不连续情况,代码更鲁棒

六、小结

用法效果
torch.flatten(x)将所有维展平成一维
torch.flatten(x, 1)保留 batch 维,常用于 CNN
torch.flatten(x, 1, 2)展平指定维度区间

http://www.lryc.cn/news/602996.html

相关文章:

  • 【代码解读】通义万相最新视频生成模型 Wan 2.2 实现解析
  • AR技术赋能工业设备维护:效率与智能的飞跃
  • 一个典型的微控制器MCU包含哪些模块?
  • 安宝特方案丨AI算法能力开放平台:适用于人工装配质检、点检、实操培训
  • Java学习-----如何创建线程
  • 基于黑马教程——微服务架构解析(二):雪崩防护+分布式事务
  • Qt:盒子模型的理解
  • 2025.7.28总结
  • 嵌入式分享合集186
  • JavaScript 回调函数讲解_callback
  • 关于xshell的一些基本内容讲解
  • tsc命令深入全面讲解
  • jQuery 最新语法大全详解(2025版)
  • python对象的__dict__属性详解
  • 防水医用无人机市场报告:现状、趋势与洞察
  • Java 笔记 serialVersionUID
  • 分布式IO详解:2025年分布式无线远程IO采集控制方案选型指南
  • 生物信息学数据技能-学习系列001
  • 秒级构建消息驱动架构:描述事件流程,生成 Spring Cloud Stream+RabbitMQ 代码
  • Java 大视界 -- Java 大数据在智能安防入侵检测系统中的多源数据融合与误报率降低策略(369)
  • 分布式高可用架构核心:复制、冗余与生死陷阱——从主从灾难到无主冲突的避坑指南
  • redis getshell的三种方法
  • 从释永信事件看“积善“与“积恶“的人生辩证法
  • CMake、CMakeLists.txt 基础语法
  • CTF-Web学习笔记:信息泄露篇
  • docker 入门,运行上传自己的首个镜像
  • 降低焊接机器人保护气体消耗的措施
  • Docker 部署 Supabase并连接
  • 记录自己第n次面试(n>3)
  • DAY-13 数组与指针