Python 数据科学与可视化工具箱 - 数组形状操作:reshape(), flatten()
接下来,我们将学习如何改变数组的形状,这是数据预处理和模型输入准备中至关重要的一步。本篇将详细介绍两个核心的形状操作方法:reshape()
和 flatten()
。
1. 为什么需要改变数组的形状?
在数据科学和机器学习中,数据很少以我们所需的精确形状出现。例如:
- 数据预处理:你可能有一个包含 100 张 28x28 像素的灰度图像的数组,它的形状是
(100, 28, 28)
。但在将这些图像送入某些机器学习模型(如全连接神经网络)之前,你需要将每张图像展平为长度为 784 的一维向量,从而得到一个(100, 784)
的数据集。 - 模型输入:许多模型期望特定的输入形状。例如,一个 LSTM 模型可能需要一个
(batch_size, time_steps, features)
形状的输入,而你的数据可能只是一个简单的(batch_size, features)
矩阵。你需要添加一个维度来适配模型。 - 可视化:你可能需要将一个一维数组重塑成一个二维矩阵,以便用热力图或图像的形式进行可视化。
改变数组形状的操作在 NumPy 中高效且灵活,是数据处理管道中不可或缺的一环。
2. reshape()
:重塑数组
reshape()
方法用于在不改变数组数据的情况下,返回一个具有新形状的数组。它的核心思想是:新形状下的元素总数必须与原始数组的元素总数相同。
- 功能: 为数组赋予一个新形状,并返回新数组的视图(通常是)。
- 语法:
arr.reshape(new_shape)
或np.reshape(arr, new_shape)
- 重要特性:
- 如果可能,
reshape()
会返回一个视图(View),这意味着新数组和原始数组共享底层数据。如果你修改其中一个,另一个也会受影响。 - 如果新形状的维度顺序与原始数组的内存布局不兼容,
reshape()
可能会返回一个副本(Copy)。
- 如果可能,
示例与说明:
import numpy as np# 创建一个一维数组
arr = np.arange(12) # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
print("原始数组:", arr)
print("原始形状:", arr.shape)# 将一维数组重塑为 2x6 的二维数组
arr_reshaped = arr.reshape(2, 6)
print("\n重塑为 2x6 矩阵:\n", arr_reshaped)
print("新形状:", arr_reshaped.shape)# 将一维数组重塑为 3x2x2 的三维数组
arr_3d = arr.reshape(3, 2, 2)
print("\n重塑为 3x2x2 三维数组:\n", arr_3d)
print("新形状:", arr_3d.shape)# 使用 -1 自动推断维度大小
# 当你只确定一个维度的大小,而另一个维度由元素总数决定时,-1 很有用。
# 元素总数是 12,如果一个维度是 3,那么另一个维度一定是 12 / 3 = 4
arr_auto_reshaped = arr.reshape(3, -1)
print("\n使用 -1 自动推断形状 (3x4):\n", arr_auto_reshaped)
print("新形状:", arr_auto_reshaped.shape)# 视图示例
# 修改重塑后的数组,原始数组也会改变
arr_reshaped[0, 0] = 99
print(