NumPy Squeeze vs Reshape:数组维度变换解析 – wiki基地


NumPy Squeeze vs Reshape:数组维度变换解析

在数据科学、机器学习和科学计算的领域,NumPy (Numerical Python) 库是不可或缺的基石。它提供了强大的 N 维数组对象(ndarray)以及一系列用于高效操作这些数组的函数。其中,数组的维度(或称为轴,axis)和形状(shape)是核心概念。数据的表示方式往往直接影响算法的效率和兼容性。因此,能够灵活地改变数组的形状和维度,是 NumPy 应用中的一项基本功。

在众多维度变换工具中,numpy.squeeze()numpy.reshape() (或其等效的 ndarray.reshape() 方法) 是最常用也时常引起混淆的两个函数。虽然它们都用于改变数组的形状,但其工作方式、适用场景和内在逻辑有着本质的区别。本文将深入探讨 squeezereshape 的功能、差异、使用场景和注意事项,帮助你彻底理解并熟练运用这两个强大的维度变换工具。

一、 NumPy 数组的维度与形状:基础回顾

在深入 squeezereshape 之前,我们先快速回顾一下 NumPy 数组的维度和形状。

  • 维度 (Dimension/Axis): NumPy 数组可以有多个维度。一个 1 维数组(向量)只有一个轴,一个 2 维数组(矩阵)有两个轴(通常是行和列),一个 3 维数组(张量)有三个轴,以此类推。维度的数量被称为数组的 ndim 属性。
  • 形状 (Shape): 形状是一个元组 (tuple),表示数组在每个维度上的大小(元素数量)。例如,一个形状为 (3, 4) 的数组是一个 2 维数组,有 3 行 4 列。形状信息可以通过数组的 shape 属性获取。
  • 元素总数 (Size): 数组中所有元素的总数量,等于其形状元组中所有元素的乘积。可以通过 size 属性获取。

“`python
import numpy as np

1D array (vector)

arr1d = np.array([1, 2, 3, 4])
print(f”arr1d: {arr1d}”)
print(f” ndim: {arr1d.ndim}”) # Output: 1
print(f” shape: {arr1d.shape}”) # Output: (4,)
print(f” size: {arr1d.size}”) # Output: 4

2D array (matrix)

arr2d = np.array([[1, 2, 3], [4, 5, 6]])
print(f”\narr2d:\n{arr2d}”)
print(f” ndim: {arr2d.ndim}”) # Output: 2
print(f” shape: {arr2d.shape}”) # Output: (2, 3)
print(f” size: {arr2d.size}”) # Output: 6

3D array (tensor)

arr3d = np.zeros((1, 4, 1, 5)) # Example with singleton dimensions
print(f”\narr3d initial shape: {arr3d.shape}”) # Output: (1, 4, 1, 5)
print(f” ndim: {arr3d.ndim}”) # Output: 4
print(f” size: {arr3d.size}”) # Output: 20
“`

理解维度和形状至关重要,因为许多 NumPy 操作(如广播、索引、数学运算)都依赖于数组的形状。有时,操作的结果可能会引入不必要的维度(特别是大小为 1 的维度,也称为“单一维度”或“冗余维度”),或者我们需要将数据重组成特定的形状以适应某个算法或函数库(例如,机器学习框架通常对输入数据形状有严格要求)。这时,squeezereshape 就派上了用场。

二、 numpy.squeeze():移除单一维度

numpy.squeeze() 函数的核心目标非常明确:从数组的形状中移除大小为 1 的维度(轴)。它就像是给数组“挤压”掉那些没有实际数据内容的维度。

语法:

python
numpy.squeeze(a, axis=None)

  • a: 输入的 NumPy 数组。
  • axis: 一个可选参数,指定要移除的轴。
    • 如果 axisNone (默认值),squeeze 会尝试移除 所有 大小为 1 的维度。
    • 如果 axis 是一个整数或整数元组,squeeze 只会尝试移除指定的轴。如果指定的轴的大小不为 1,则会引发 ValueError

工作原理与示例:

  1. 默认行为 (axis=None): 移除所有单一维度。

    “`python
    arr_a = np.array([[[1, 2, 3]]]) # shape: (1, 1, 3)
    squeezed_a = np.squeeze(arr_a)
    print(f”Original shape: {arr_a.shape}”) # Output: (1, 1, 3)
    print(f”Squeezed shape: {squeezed_a.shape}”) # Output: (3,)
    print(f”Squeezed array: {squeezed_a}”) # Output: [1 2 3]

    arr_b = np.array([[1], [2], [3]]) # shape: (3, 1)
    squeezed_b = np.squeeze(arr_b)
    print(f”\nOriginal shape: {arr_b.shape}”) # Output: (3, 1)
    print(f”Squeezed shape: {squeezed_b.shape}”) # Output: (3,)
    print(f”Squeezed array: {squeezed_b}”) # Output: [1 2 3]

    arr_c = np.array([[[[4]]]]) # shape: (1, 1, 1, 1)
    squeezed_c = np.squeeze(arr_c)
    print(f”\nOriginal shape: {arr_c.shape}”) # Output: (1, 1, 1, 1)
    print(f”Squeezed shape: {squeezed_c.shape}”) # Output: () – A 0D scalar array
    print(f”Squeezed array: {squeezed_c}”) # Output: 4
    print(f”Squeezed array ndim: {squeezed_c.ndim}”) # Output: 0

    arr_d = np.array([1, 2, 3]) # shape: (3,) – No dimensions of size 1
    squeezed_d = np.squeeze(arr_d)
    print(f”\nOriginal shape: {arr_d.shape}”) # Output: (3,)
    print(f”Squeezed shape: {squeezed_d.shape}”) # Output: (3,) – No change
    “`

  2. 指定 axis: 只移除特定位置的大小为 1 的维度。

    “`python
    arr_e = np.zeros((1, 5, 1, 8)) # shape: (1, 5, 1, 8)

    Remove only the first axis (axis=0)

    squeezed_e0 = np.squeeze(arr_e, axis=0)
    print(f”Original shape: {arr_e.shape}”) # Output: (1, 5, 1, 8)
    print(f”Squeezed axis 0 shape: {squeezed_e0.shape}”) # Output: (5, 1, 8)

    Remove only the third axis (axis=2)

    squeezed_e2 = np.squeeze(arr_e, axis=2)
    print(f”Squeezed axis 2 shape: {squeezed_e2.shape}”) # Output: (1, 5, 8)

    Remove both axis 0 and axis 2

    squeezed_e02 = np.squeeze(arr_e, axis=(0, 2))
    print(f”Squeezed axis (0, 2) shape: {squeezed_e02.shape}”) # Output: (5, 8)

    Attempt to squeeze a non-singleton dimension (axis=1, size=5) – Raises Error

    try:
    np.squeeze(arr_e, axis=1)
    except ValueError as e:
    print(f”\nError when squeezing non-singleton axis: {e}”)
    # Output: Error when squeezing non-singleton axis: cannot select an axis to squeeze out which has size not equal to one
    “`

视图 (View) vs. 副本 (Copy):

np.squeeze() 通常返回原始数组的 视图 (View),而不是副本 (Copy)。这意味着返回的数组与原始数组共享同一块内存数据。修改视图会影响原始数组,反之亦然。这使得 squeeze 操作非常高效,因为它避免了不必要的数据复制。只有在无法创建视图的情况下(虽然对于 squeeze 来说比较少见),它才可能返回副本。

“`python
arr_f = np.array([[[10, 20]]]) # shape (1, 1, 2)
squeezed_f = np.squeeze(arr_f) # shape (2,)

print(f”\nOriginal array (f) before modification:\n{arr_f}”)
print(f”Squeezed array (f) before modification: {squeezed_f}”)

squeezed_f[0] = 99 # Modify the squeezed array (view)

print(f”\nOriginal array (f) after modification:\n{arr_f}”) # Original is changed!
print(f”Squeezed array (f) after modification: {squeezed_f}”)
“`

使用场景:

  • 清理冗余维度: 在进行某些 NumPy 操作后(例如,使用 keepdims=True 的聚合函数,或者某些索引操作),可能会产生单一维度。squeeze 是去除这些冗余维度的理想工具,使数据结构更简洁。
  • 函数/库接口适配: 有些函数或库可能期望接收特定维度的输入,如果你的数据恰好多了几个单一维度,squeeze 可以快速调整。

三、 numpy.reshape():通用形状变换

numpy.reshape() 函数(或等效的 ndarray.reshape() 方法)是一个更通用的维度变换工具。它允许你将数组改变为任何兼容的新形状,只要新形状所包含的元素总数与原始数组相同。

语法:

“`python
numpy.reshape(a, newshape, order=’C’)

或者

ndarray.reshape(newshape, order=’C’)
“`

  • a (或 ndarray): 输入的 NumPy 数组。
  • newshape: 一个整数元组,指定目标形状。
    • 其中一个维度可以指定为 -1。在这种情况下,NumPy 会自动计算该维度的大小,以确保总元素数量不变。
  • order: 可选参数,指定元素在内存中的读取/写入顺序。
    • 'C' (默认): C 语言风格的行优先顺序 (row-major)。
    • 'F': Fortran 风格的列优先顺序 (column-major)。
    • 'A': 如果 a 在内存中是 Fortran 连续的,则按列优先顺序;否则按行优先顺序。

工作原理与示例:

核心约束:np.prod(a.shape) == np.prod(newshape) 必须为真。

  1. 基本重塑:

    “`python
    arr_g = np.arange(12) # shape: (12,) -> [0, 1, 2, …, 11]
    print(f”Original array (g): {arr_g}, shape: {arr_g.shape}”)

    reshaped_g1 = np.reshape(arr_g, (3, 4)) # Reshape to 3×4 matrix
    print(f”\nReshaped to (3, 4):\n{reshaped_g1}”)
    print(f”Shape: {reshaped_g1.shape}”) # Output: (3, 4)

    reshaped_g2 = arr_g.reshape((2, 6)) # Using the method syntax
    print(f”\nReshaped to (2, 6):\n{reshaped_g2}”)
    print(f”Shape: {reshaped_g2.shape}”) # Output: (2, 6)

    reshaped_g3 = arr_g.reshape((2, 2, 3)) # Reshape to 3D array
    print(f”\nReshaped to (2, 2, 3):\n{reshaped_g3}”)
    print(f”Shape: {reshaped_g3.shape}”) # Output: (2, 2, 3)
    “`

  2. 使用 -1 自动计算维度:

    “`python
    arr_h = np.arange(24) # shape: (24,)

    Reshape to (4, ?) -> NumPy calculates ? = 24 / 4 = 6

    reshaped_h1 = arr_h.reshape((4, -1))
    print(f”\nOriginal shape: {arr_h.shape}”) # Output: (24,)
    print(f”Reshaped to (4, -1) shape: {reshaped_h1.shape}”) # Output: (4, 6)

    Reshape to (?, 3, 2) -> NumPy calculates ? = 24 / (3 * 2) = 4

    reshaped_h2 = arr_h.reshape((-1, 3, 2))
    print(f”Reshaped to (-1, 3, 2) shape: {reshaped_h2.shape}”) # Output: (4, 3, 2)

    Error case: Total size mismatch

    try:
    arr_h.reshape((5, -1)) # 24 is not divisible by 5
    except ValueError as e:
    print(f”\nError reshaping to incompatible size: {e}”)
    # Output: Error reshaping to incompatible size: cannot reshape array of size 24 into shape (5,newaxis) or similar
    “`

  3. order 参数的影响:

    order 参数决定了元素如何从旧形状“读取”并填充到新形状。

    “`python
    arr_i = np.array([[1, 2, 3], [4, 5, 6]]) # shape: (2, 3)

    Default order=’C’ (row-major): Read row by row [1, 2, 3, 4, 5, 6]

    reshaped_i_C = arr_i.reshape((3, 2), order=’C’)
    print(f”\nOriginal array (i):\n{arr_i}”)
    print(f”Reshaped to (3, 2) with order=’C’:\n{reshaped_i_C}”)

    Output:

    [[1 2]

    [3 4]

    [5 6]]

    Order=’F’ (column-major): Read column by column [1, 4, 2, 5, 3, 6]

    reshaped_i_F = arr_i.reshape((3, 2), order=’F’)
    print(f”\nReshaped to (3, 2) with order=’F’:\n{reshaped_i_F}”)

    Output:

    [[1 5]

    [4 3]

    [2 6]]

    “`

视图 (View) vs. 副本 (Copy):

reshape 尽可能地返回视图,以提高效率。然而,在某些情况下,它必须返回副本。这通常发生在:

  • 请求的 order (‘C’ 或 ‘F’) 与数组当前的内存布局不兼容,无法在不移动数据的情况下创建新形状的视图。
  • 数组不是 C 连续 (C-contiguous) 或 F 连续 (Fortran-contiguous) 的,并且重塑操作需要改变这种连续性以匹配 order 参数。

判断 reshape 返回的是视图还是副本的一个方法是检查其 base 属性。如果 reshaped_array.base 是原始数组 original_array,则 reshaped_array 是一个视图。如果是 None,则它是一个副本。或者使用 np.shares_memory(original_array, reshaped_array)

“`python
arr_j = np.arange(6) # shape (6,), C-contiguous
reshaped_j_view = arr_j.reshape((2, 3)) # Typically returns a view
print(f”\nIs reshaped_j_view a view? {np.shares_memory(arr_j, reshaped_j_view)}”) # Output: True

Create a non-contiguous array via slicing

arr_k = np.arange(12).reshape((3, 4))
arr_k_non_contig = arr_k[:, ::2] # Select columns 0 and 2, shape (3, 2), not C-contiguous
print(f”Is arr_k_non_contig C-contiguous? {arr_k_non_contig.flags[‘C_CONTIGUOUS’]}”) # Output: False

Reshaping a non-contiguous array might return a copy

reshaped_k_copy = arr_k_non_contig.reshape((6,)) # May return a copy
print(f”Is reshaped_k_copy a view? {np.shares_memory(arr_k_non_contig, reshaped_k_copy)}”) # Might be False

Using order=’F’ on a C-ordered array often forces a copy

reshaped_j_copy = arr_j.reshape((2, 3), order=’F’)
print(f”Is reshaped_j_copy (order=’F’) a view? {np.shares_memory(arr_j, reshaped_j_copy)}”) # Often False
“`

使用场景:

  • 数据准备: 机器学习模型通常需要特定形状的输入(如展平的向量、带有批次维度的张量)。reshape 是实现这些转换的核心工具。
  • 维度调整: 将数据从一种多维结构转换为另一种,例如将时间序列数据从 (days, hours) 转换为 (weeks, days_in_week, hours)
  • 展平数组: 使用 reshape(-1) 可以快速将任意维度的数组展平成 1D 数组。
  • 添加维度: 可以使用 reshape 添加维度,例如将 (H, W) 的图像 reshape 成 (1, H, W) 以添加批次维度(虽然使用 np.newaxis 通常更直观)。

四、 Squeeze vs. Reshape:核心差异对比

现在我们来总结一下 squeezereshape 的关键区别:

特性 numpy.squeeze() numpy.reshape()
核心功能 移除大小为 1 的维度 改变为任意兼容的新形状
目标形状 由移除单一维度后 自动确定 由用户 明确指定 (可使用 -1 推断)
维度变化 只能 减少 维度数量 (或不变) 可以 增加减少保持 维度数量不变
元素总数 保持不变 必须 保持不变
选择性 可通过 axis 参数指定移除哪些单一维度 通过 newshape 参数定义整个目标形状
order 参数 有 ('C', 'F', 'A'),影响元素排布
返回值 通常 返回视图 尽量 返回视图,但可能返回副本
主要用途 清理冗余维度、简化数据结构 通用形状变换、数据准备、维度增删改
限制 只能移除大小为 1 的维度 新旧形状的元素总数必须严格相等

何时使用哪个?

  • 当你需要去除那些大小为 1 的“假”维度时,请使用 squeeze 这是它的专长,代码意图清晰,且通常高效(返回视图)。例如,处理 keepdims=True 的聚合结果。

    “`python
    data = np.random.rand(1, 10, 1, 5) # shape (1, 10, 1, 5)
    result = np.mean(data, axis=(0, 2), keepdims=True) # shape (1, 1, 1, 5)

    Need shape (5,)

    final_result = np.squeeze(result) # Correct tool, shape (5,)
    print(f”Shape after mean with keepdims: {result.shape}”)
    print(f”Shape after squeeze: {final_result.shape}”)
    “`

  • 当你需要将数组变成一个 特定 的新形状(无论是否涉及单一维度)时,请使用 reshape 只要元素总数匹配,reshape 就能完成任务。例如,为神经网络准备输入数据。

    “`python
    images = np.random.rand(100, 28, 28) # 100 images of 28×28 pixels

    Need to flatten each image for a Dense layer -> shape (100, 784)

    flattened_images = images.reshape((100, -1)) # Correct tool

    Or equivalently: flattened_images = images.reshape(100, 784)

    print(f”\nOriginal images shape: {images.shape}”)
    print(f”Flattened images shape: {flattened_images.shape}”)

    Example: Reshape can also mimic squeeze, but less directly

    arr_to_squeeze = np.array([[[1, 2, 3]]]) # shape (1, 1, 3)
    reshaped_to_mimic_squeeze = arr_to_squeeze.reshape((3,)) # Works, but squeeze is clearer
    print(f”Shape using reshape to mimic squeeze: {reshaped_to_mimic_squeeze.shape}”)
    “`

  • 不要用 reshape 来替代 squeeze 的主要功能。 虽然 reshape 可以通过手动计算目标形状来达到 squeeze 的效果(例如,如果 arr.shape(1, 5, 1)arr.reshape(5) 可以得到 (5,)),但这不如 np.squeeze(arr) 来得直接和清晰。squeeze 明确表达了“移除单一维度”的意图。

  • squeeze 不能替代 reshape 的通用性。 squeeze 无法实现任意形状变换,例如将 (6,) 变成 (2, 3)

五、 性能考量与相关函数

  • 性能: 由于 squeezereshape 都倾向于返回视图,它们通常都是非常快速的操作,因为它们主要修改的是元数据(形状、步长 stride 信息),而不是复制底层数据。当 reshape 被迫返回副本时(例如,由于 order 参数或内存不连续),会有额外的内存分配和数据复制开销。
  • 相关函数:
    • ravel()flatten(): 这两个函数都用于将多维数组“展平”为一维数组。ravel() 倾向于返回视图(类似于 reshape(-1)),而 flatten() 总是 返回一个副本。
    • np.newaxis (或 None): 用于在特定位置 增加 一个大小为 1 的新维度。这在某种程度上是 squeeze 的逆操作。常用于索引操作中。
      python
      arr = np.array([1, 2, 3]) # shape (3,)
      arr_row_vec = arr[np.newaxis, :] # shape (1, 3)
      arr_col_vec = arr[:, np.newaxis] # shape (3, 1)
      print(f"\nUsing np.newaxis for row vector: {arr_row_vec.shape}")
      print(f"Using np.newaxis for column vector: {arr_col_vec.shape}")
    • .Tnp.transpose(): 用于转置数组,即交换轴的顺序,但不改变每个轴的大小。这与 reshape 不同,reshape 可以改变轴的大小。

六、 总结

NumPy 的 squeezereshape 是数组维度变换工具箱中的两把利器,各自扮演着不同的角色:

  • np.squeeze() 是一个 专用的 工具,用于 移除数组形状中大小为 1 的维度。它简洁、意图明确,并且通常返回高效的视图。当你需要清理由其他操作产生的冗余单一维度时,它是首选。
  • np.reshape() 是一个 通用的 工具,用于将数组 变换为任何元素总数相同的新形状。它灵活、强大,是数据准备和结构重组的核心函数。它尽可能返回视图,但在某些情况下会返回副本。

理解它们之间的核心差异——squeeze 的条件性移除与 reshape 的通用性变换——是关键。虽然 reshape 在功能上可以模拟 squeeze,但在表达“移除单一维度”这一特定意图时,squeeze 更加清晰和直接。

熟练掌握 squeezereshape,并根据具体需求选择合适的函数,将使你的 NumPy 代码更加高效、可读,并能更好地应对数据处理和算法实现中各种复杂的维度变换挑战。它们是 NumPy 赋予我们的强大能力的一部分,帮助我们自如地塑造和操控多维数据。


发表评论

您的邮箱地址不会被公开。 必填项已用 * 标注

滚动至顶部