理解 NumPy `expand_dims`:为数组增加新维度 – wiki基地


深入理解 NumPy expand_dims:为数组增加新维度

引言:NumPy 数组与维度的重要性

在科学计算、数据分析、机器学习等领域,NumPy 库是 Python 中不可或缺的工具。它的核心是多维数组(ndarray),这种数据结构提供了高效的数值计算能力。理解 NumPy 数组的关键在于理解其“形状”(shape)和“维度”(dimensions,也称为 axes)。

数组的形状是一个元组,表示数组在每个维度上的大小。例如,一个形状为 (3,) 的数组是一个包含 3 个元素的一维数组(向量);一个形状为 (2, 3) 的数组是一个包含 2 行 3 列的二维数组(矩阵);一个形状为 (4, 2, 3) 的数组则是一个三维数组,可以想象成 4 个 2×3 的矩阵堆叠在一起。

数组的维度(axes)是从 0 开始编号的。对于一个 (d0, d1, d2, ..., dn) 形状的 n 维数组,维度 0 对应大小为 d0 的轴,维度 1 对应大小为 d1 的轴,以此类推,维度 n-1 对应大小为 dn 的轴。例如,在 (2, 3) 的二维数组中,维度 0 是行(大小为 2),维度 1 是列(大小为 3)。

在实际应用中,我们经常需要改变数组的形状或增加/减少其维度。这可能是出于以下原因:

  1. 满足函数或模型的输入要求: 许多库(如深度学习框架 TensorFlow 或 PyTorch)或函数对输入数据的维度有特定要求。例如,卷积神经网络通常期望输入是 (批次大小, 高度, 宽度, 通道数) 这样的四维数组,即使你只有一个单张图片 (高度, 宽度, 通道数),也需要为其增加一个批次大小的维度,使其变为 (1, 高度, 宽度, 通道数)
  2. 广播(Broadcasting): NumPy 的广播机制允许在形状不同的数组之间进行算术运算,前提是它们的形状满足一定的兼容性规则。有时,为了让两个数组能够成功广播,我们需要通过增加维度来调整其中一个数组的形状。
  3. 数据处理和转换: 在数据的预处理、特征工程等步骤中,可能需要调整数组的维度以便进行特定的操作。

NumPy 提供了多种改变数组形状的方法,例如 reshape 方法。然而,当我们的目标仅仅是“在特定位置增加一个大小为 1 的新维度”时,numpy.expand_dims 函数提供了一种更加直观和安全的方式。

本文将详细探讨 numpy.expand_dims 函数的功能、用法、参数以及它在不同场景下的应用,并与 reshapenp.newaxis 进行比较。

什么是 numpy.expand_dims

numpy.expand_dims(a, axis) 是 NumPy 提供的一个函数,它的主要作用是在数组 a 指定的 axis 位置插入一个新维度。这个新维度的长度总是 1。

函数签名:

python
numpy.expand_dims(a, axis)

参数解释:

  • a: 输入的 NumPy 数组(ndarray)。
  • axis: 一个整数或一个由整数组成的元组。指定新维度插入的位置。
    • 如果 axis 是一个整数,它表示新维度将被插入到该索引位置之前。例如,axis=0 会在最前面插入一个新维度,axis=1 会在原来的维度 0 之后、维度 1 之前插入新维度。
    • axis 的值可以为负数。负数索引的工作方式与 Python 列表类似,-1 表示倒数第一个维度之后(也就是最后面),-2 表示倒数第二个维度之前,以此类推。需要注意的是,axis 的取值范围是从 -ndim - 1ndim(其中 ndim 是输入数组 a 的原始维度数)。例如,对于一个 2 维数组(ndim=2),合法的 axis 取值范围是 -32axis=-3 等同于 axis=0axis=2 等同于 axis=-1
    • 如果 axis 是一个元组,则会在元组指定的多个位置同时插入新的维度。元组中的值必须是唯一的,并且按照插入位置的顺序给出。NumPy 官方文档提到,当 axis 是元组时,插入是按位置 从低到高 进行的。例如,对于一个形状为 (2, 3) 的数组,axis=(0, 2) 意味着先在 axis=0 处插入一个维度(变成 (1, 2, 3)),然后再在新的 axis=2 处插入一个维度(变成 (1, 2, 1, 3))。这与先在原数组的 axis=0 插入再在原数组的 axis=2 插入是不同的。实际操作时,axis=(0, 2) 会在原数组的 axis=0 插入,然后在 结果数组axis=2 插入。这个行为在使用元组时需要特别注意。不过,在大多数常见场景下,我们通常只插入一个维度,此时 axis 是一个整数。

返回值:

  • 返回一个新的 NumPy 数组,其维度比输入数组 a 多了 len(axis) 个(如果 axis 是元组)或 1 个(如果 axis 是整数)。新插入维度的长度均为 1。

核心特性:

  • expand_dims 不会改变数组元素的总数。
  • expand_dims 在指定位置插入一个长度为 1 的新轴。
  • expand_dims 通常返回一个视图(view)而不是数据的副本(copy),这意味着操作是高效的,不会额外消耗大量内存,除非原始数组不是 C-contiguous 或 Fortran-contiguous。

expand_dims 的工作原理及 axis 参数详解

理解 expand_dims 的关键在于理解 axis 参数如何指定新维度的位置。想象一下数组的维度就像是嵌套的括号层级。

  • 一维数组 [1, 2, 3],形状 (3,)。只有一个维度,索引为 0。
  • 二维数组 [[1, 2], [3, 4]],形状 (2, 2)。维度 0 是行(外层括号),维度 1 是列(内层括号)。

expand_dims(a, axis) 会在原始数组 aaxis 索引 之前 插入一个新的维度。

示例 1:一维数组增加维度

考虑一个一维数组 arr = np.array([1, 2, 3]),其形状为 (3,),维度索引为 0。

  • axis=0 处插入:
    “`python
    import numpy as np

    arr = np.array([1, 2, 3])
    print(“原数组:”, arr)
    print(“原形状:”, arr.shape)
    print(“原维度:”, arr.ndim)

    在 axis=0 (最前面) 插入新维度

    arr_expanded_0 = np.expand_dims(arr, axis=0)
    print(“\nexpand_dims(arr, axis=0):”)
    print(“新数组:”, arr_expanded_0)
    print(“新形状:”, arr_expanded_0.shape)
    print(“新维度:”, arr_expanded_0.ndim)
    输出:
    原数组: [1 2 3]
    原形状: (3,)
    原维度: 1

    expand_dims(arr, axis=0):
    新数组: [[1 2 3]]
    新形状: (1, 3)
    新维度: 2
    ``
    原数组
    [1, 2, 3]只有一个维度(索引 0,大小 3)。axis=0表示在索引 0 之前插入。插入后,原来的维度 0 变成了新的维度 1,而新的维度 0 被插入并具有大小 1。形状从(3,)变为(1, 3)。这相当于将一个行向量[1, 2, 3]变成一个 1x3 的矩阵[[1, 2, 3]]`。

  • axis=1 处插入:
    一个一维数组只有 axis=0,所以 axis=1 是无效的,因为它超过了原始维度数(ndim=1)。但是,根据规则,axis 的范围是 -ndim - 1ndim。对于 ndim=1,范围是 -21axis=1 实际上是合法的,它等同于 axis=-1,表示在最后一个维度之后插入。
    “`python
    # 在 axis=1 (最后面) 插入新维度
    arr_expanded_1 = np.expand_dims(arr, axis=1)
    print(“\nexpand_dims(arr, axis=1):”)
    print(“新数组:”, arr_expanded_1)
    print(“新形状:”, arr_expanded_1.shape)
    print(“新维度:”, arr_expanded_1.ndim)

    使用负数索引 axis=-1 (最后面)

    arr_expanded_neg1 = np.expand_dims(arr, axis=-1)
    print(“\nexpand_dims(arr, axis=-1):”)
    print(“新数组:”, arr_expanded_neg1)
    print(“新形状:”, arr_expanded_neg1.shape)
    print(“新维度:”, arr_expanded_neg1.ndim)
    输出:
    expand_dims(arr, axis=1):
    新数组: [[1]
    [2]
    [3]]
    新形状: (3, 1)
    新维度: 2

    expand_dims(arr, axis=-1):
    新数组: [[1]
    [2]
    [3]]
    新形状: (3, 1)
    新维度: 2
    ``axis=1axis=-1在这里都表示在原始数组的 *最后一个* 维度(即索引 0)之后插入新维度。原来的维度 0 仍然是新的维度 0,而新的维度 1 被插入并具有大小 1。形状从(3,)变为(3, 1)。这相当于将一个行向量[1, 2, 3]变成一个 3x1 的列向量[[1], [2], [3]]`。

示例 2:二维数组增加维度

考虑一个二维数组 arr_2d = np.array([[1, 2], [3, 4]]),其形状为 (2, 2),维度索引为 0 和 1。

  • axis=0 处插入:
    “`python
    arr_2d = np.array([[1, 2], [3, 4]])
    print(“原数组 (2D):”)
    print(arr_2d)
    print(“原形状:”, arr_2d.shape) # (2, 2)

    在 axis=0 (最前面) 插入新维度

    arr_expanded_0_2d = np.expand_dims(arr_2d, axis=0)
    print(“\nexpand_dims(arr_2d, axis=0):”)
    print(arr_expanded_0_2d)
    print(“新形状:”, arr_expanded_0_2d.shape) # (1, 2, 2)
    print(“新维度:”, arr_expanded_0_2d.ndim)
    输出:
    原数组 (2D):
    [[1 2]
    [3 4]]
    原形状: (2, 2)

    expand_dims(arr_2d, axis=0):
    [[[1 2]
    [3 4]]]
    新形状: (1, 2, 2)
    新维度: 3
    ``
    原始形状
    (2, 2),维度索引 0 和 1。axis=0在最前面插入。新维度 0 大小为 1,原来的维度 0 (大小 2) 变成新的维度 1,原来的维度 1 (大小 2) 变成新的维度 2。形状变为(1, 2, 2)`。

  • axis=1 处插入:
    python
    # 在 axis=1 插入新维度
    arr_expanded_1_2d = np.expand_dims(arr_2d, axis=1)
    print("\nexpand_dims(arr_2d, axis=1):")
    print(arr_expanded_1_2d)
    print("新形状:", arr_expanded_1_2d.shape) # (2, 1, 2)
    print("新维度:", arr_expanded_1_2d.ndim)

    输出:
    “`
    expand_dims(arr_2d, axis=1):
    [[[1 2]]

    [[3 4]]]
    新形状: (2, 1, 2)
    新维度: 3
    ``
    原始形状
    (2, 2),维度索引 0 和 1。axis=1在索引 1 之前插入。原来的维度 0 (大小 2) 保持为新的维度 0。新的维度 1 被插入并具有大小 1。原来的维度 1 (大小 2) 变成新的维度 2。形状变为(2, 1, 2)`。

  • axis=2 处插入(等同于 axis=-1):
    对于 2D 数组,合法 axis 范围是 -32axis=2 是合法的,它表示在最后一个维度之后插入。
    “`python
    # 在 axis=2 (最后面) 插入新维度
    arr_expanded_2_2d = np.expand_dims(arr_2d, axis=2)
    print(“\nexpand_dims(arr_2d, axis=2):”)
    print(arr_expanded_2_2d)
    print(“新形状:”, arr_expanded_2_2d.shape) # (2, 2, 1)
    print(“新维度:”, arr_expanded_2_2d.ndim)

    使用负数索引 axis=-1 (最后面)

    arr_expanded_neg1_2d = np.expand_dims(arr_2d, axis=-1)
    print(“\nexpand_dims(arr_2d, axis=-1):”)
    print(arr_expanded_neg1_2d)
    print(“新形状:”, arr_expanded_neg1_2d.shape) # (2, 2, 1)
    print(“新维度:”, arr_expanded_neg1_2d.ndim)
    输出:
    expand_dims(arr_2d, axis=2):
    [[[1] [2]]

    [[3] [4]]]
    新形状: (2, 2, 1)
    新维度: 3

    expand_dims(arr_2d, axis=-1):
    [[[1] [2]]

    [[3] [4]]]
    新形状: (2, 2, 1)
    新维度: 3
    ``
    原始形状
    (2, 2),维度索引 0 和 1。axis=2axis=-1在索引 1 之后插入。原来的维度 0 (大小 2) 保持为新的维度 0。原来的维度 1 (大小 2) 保持为新的维度 1。新的维度 2 被插入并具有大小 1。形状变为(2, 2, 1)`。

总结 axis 参数的行为:

对于形状为 (d0, d1, ..., dk, ..., dn-1) 的 n 维数组 a

  • np.expand_dims(a, axis=i) (当 i >= 0) 会生成一个形状为 (d0, d1, ..., di-1, 1, di, ..., dn-1) 的 n+1 维数组。新的维度被插入到索引 i 处。
  • np.expand_dims(a, axis=-i) (当 i > 0) 会生成一个形状为 (d0, d1, ..., dn-i-1, 1, dn-i, ..., dn-1) 的 n+1 维数组。新的维度被插入到距离末尾第 i 个维度之前。axis=-1 插入在最后面,axis=-2 插入在倒数第二个位置之前,以此类推。
  • 合法的 axis 整数值范围是 [-ndim - 1, ndim]。在这个范围内,axis=i 等同于 axis=i - ndim - 1i >= 0,以及 axis=i + ndim + 1i < -ndim - 1 (虽然 -ndim - 1 是最小值,所以第二个等式只适用于大于 -ndim - 1 的负数)。简单的记法是,正数 i 插入在索引 i 之前,负数 -i 插入在距离末尾第 i 个位置之前。

使用元组作为 axis

如前所述,使用元组可以在多个位置插入维度。插入的顺序是按照元组中 axis 值从小到大进行,但插入位置是相对于 当前 形状而言的。

“`python
arr_2d = np.array([[1, 2], [3, 4]]) # Shape (2, 2)

在 axis=0 和 axis=2 (相对于原始数组) 插入

步骤1: 在 axis=0 插入 (shape (1, 2, 2))

步骤2: 在新的数组 (1, 2, 2) 的 axis=2 插入

arr_expanded_tuple = np.expand_dims(arr_2d, axis=(0, 2))
print(“\nexpand_dims(arr_2d, axis=(0, 2)):”)
print(arr_expanded_tuple)
print(“新形状:”, arr_expanded_tuple.shape) # (1, 2, 1, 2)
print(“新维度:”, arr_expanded_tuple.ndim)
输出:
expand_dims(arr_2d, axis=(0, 2)):
[[[[1 2]]

[[3 4]]]]
新形状: (1, 2, 1, 2)
新维度: 4
``
原始形状
(2, 2)
1. 先在
axis=0插入:结果形状(1, 2, 2)
2. 然后,在新的
(1, 2, 2)形状的axis=2位置插入:新维度插入在索引 2 之前。原来的维度 0 (大小 1) 保持为新的维度 0。原来的维度 1 (大小 2) 保持为新的维度 1。新的维度 2 被插入 (大小 1)。原来的维度 2 (大小 2) 变成新的维度 3。最终形状(1, 2, 1, 2)`。

这种使用元组的情况相对少见,且行为可能有些反直觉(相对于原始数组轴的相对位置)。更常见的是一次只通过整数 axis 参数增加一个维度。如果需要增加多个维度,可以连续调用 expand_dims 或使用其他方法。

expand_dims 的常见应用场景

expand_dims 函数的简洁性使其在多种常见场景下成为首选工具。

  1. 为深度学习模型准备输入数据:
    许多深度学习模型(如使用 TensorFlow 或 PyTorch 构建的卷积神经网络 CNN)期望输入数据包含一个“批次大小”(batch size)的维度作为第一个维度。即使你只想处理一张图片或一个数据样本,也需要将其包装在一个大小为 1 的批次中。

    • 图片数据: 单张彩色图片的数据通常是 (高度, 宽度, 通道数) 的三维数组。模型可能期望 (批次大小, 高度, 宽度, 通道数)
      “`python
      import numpy as np

      模拟一张 28×28 的灰度图片数据

      img_gray = np.random.rand(28, 28) # Shape (28, 28)

      模拟一张 28×28 的彩色图片数据

      img_color = np.random.rand(28, 28, 3) # Shape (28, 28, 3)

      为灰度图片增加批次维度和通道维度 (如果模型期望 (batch, height, width, channel))

      先增加批次维度 (axis=0)

      img_gray_batch = np.expand_dims(img_gray, axis=0) # Shape (1, 28, 28)

      再增加通道维度 (axis=3 或 axis=-1)

      img_gray_batch_channel = np.expand_dims(img_gray_batch, axis=-1) # Shape (1, 28, 28, 1)
      print(f”灰度图转模型输入形状: {img_gray_batch_channel.shape}”)

      为彩色图片增加批次维度 (axis=0)

      img_color_batch = np.expand_dims(img_color, axis=0) # Shape (1, 28, 28, 3)
      print(f”彩图转模型输入形状: {img_color_batch.shape}”)
      输出:
      灰度图转模型输入形状: (1, 28, 28, 1)
      彩图转模型输入形状: (1, 28, 28, 3)
      ``
      这里我们使用了
      expand_dims简洁地添加了批次维度。对于灰度图转(batch, height, width, channel)形状,我们分了两步,先加批次维度,再加通道维度。使用expand_dims` 更清晰地表达了“增加一个维度”的意图。

    • 序列数据: 循环神经网络 RNN 或 Transformer 模型处理序列数据时,通常期望输入是 (批次大小, 序列长度, 特征数)。单个序列数据 (序列长度, 特征数) 需要增加批次维度。
      “`python
      # 模拟一个序列数据,序列长度为 50,每个时间步有 10 个特征
      sequence_data = np.random.rand(50, 10) # Shape (50, 10)

      为序列数据增加批次维度 (axis=0)

      sequence_batch = np.expand_dims(sequence_data, axis=0) # Shape (1, 50, 10)
      print(f”序列数据转模型输入形状: {sequence_batch.shape}”)
      输出:
      序列数据转模型输入形状: (1, 50, 10)
      “`

  2. 实现特定类型的广播操作:
    广播是 NumPy 的强大特性,但有时需要手动调整数组形状才能满足广播条件。增加一个大小为 1 的维度是实现广播兼容性的常见手段。NumPy 的广播规则中重要的一条是:当两个数组的维度数不同时,维度数较小的数组会在其前部(左侧)填充大小为 1 的维度,直到两个数组维度数相同。然后比较对应维度的长度,如果相等或其中一个长度为 1,则兼容。

    考虑两个一维数组 a = np.array([10, 20, 30]) (形状 (3,)) 和 b = np.array([1, 2, 3]) (形状 (3,))。它们的乘积 a * b 是元素级的 [10, 40, 90] (形状 (3,))。

    如果我们想计算它们的“外积”(outer product),即得到一个 3×3 的矩阵,其中元素 M[i, j] = a[i] * b[j]。这可以通过广播实现,但需要将一个数组变为列向量形状 (3, 1),另一个变为行向量形状 (1, 3)

    “`python
    a = np.array([10, 20, 30]) # Shape (3,)
    b = np.array([1, 2, 3]) # Shape (3,)

    将 a 变为列向量形状 (3, 1)

    a_col = np.expand_dims(a, axis=1) # Shape (3, 1)
    print(f”a 变为列向量形状: {a_col.shape}”)

    将 b 变为行向量形状 (1, 3)

    b_row = np.expand_dims(b, axis=0) # Shape (1, 3)
    print(f”b 变为行向量形状: {b_row.shape}”)

    计算外积: (3, 1) 和 (1, 3) 进行广播

    outer_product = a_col * b_row
    print(“\n外积结果:”)
    print(outer_product)
    print(“外积形状:”, outer_product.shape) # Shape (3, 3)
    输出:
    a 变为列向量形状: (3, 1)
    b 变为行向量形状: (1, 3)

    外积结果:
    [[10 20 30]
    [20 40 60]
    [30 60 90]]
    外积形状: (3, 3)
    ``
    在这个例子中,
    expand_dims清晰地表达了我们的意图:将a转换为带有新列维度的数组,将b` 转换为带有新行维度的数组,从而实现广播。

  3. 数据可视化库的输入要求:
    一些绘图库可能对输入数据的形状有特定要求。例如,一个函数可能期望输入是一个二维数组 (样本数, 特征数),而你的原始数据是每个样本一个一维数组列表。你可以将每个一维数组通过 expand_dims(sample, axis=0) 转换为 (1, 特征数),然后使用 np.concatenate 将它们堆叠起来形成 (样本数, 特征数) 的二维数组。

expand_dimsreshape 的区别

reshape 是另一个常用的改变数组形状的方法,但它与 expand_dims 有本质区别。

  • expand_dims: 在指定位置插入一个大小为 1 的新维度。它不改变原始数据的排列方式(内存布局通常保持一致, unless strides change significantly, yielding a view)。它使得数组的维度数量增加 1(或更多,如果 axis 是元组)。新维度的长度总是 1。
  • reshape: 改变数组的形状,但 不改变 数组元素的总数。它可以增加或减少维度,或者改变现有维度的大小,只要新形状的总元素数与原始数组的总元素数相等。reshape 可以返回一个视图或一个副本,取决于新的形状是否与原始数组的内存布局兼容。

示例:

“`python
arr = np.array([1, 2, 3, 4, 5, 6]) # Shape (6,)

使用 expand_dims

arr_expanded = np.expand_dims(arr, axis=0) # Shape (1, 6)
print(f”expand_dims 结果形状: {arr_expanded.shape}”)

使用 reshape

arr_reshaped_row = arr.reshape(1, 6) # Shape (1, 6)
print(f”reshape(1, 6) 结果形状: {arr_reshaped_row.shape}”)

arr_reshaped_col = arr.reshape(6, 1) # Shape (6, 1)
print(f”reshape(6, 1) 结果形状: {arr_reshaped_col.shape}”)

reshape 可以完全改变形状,只要元素总数不变

arr_reshaped_2x3 = arr.reshape(2, 3) # Shape (2, 3)
print(f”reshape(2, 3) 结果形状: {arr_reshaped_2x3.shape}”)

expand_dims 只能增加大小为 1 的维度,不能将 (6,) 变为 (2, 3)

np.expand_dims(arr, axis=?) 只能得到 (1, 6) 或 (6, 1)

“`

总结区别:

特性 numpy.expand_dims(a, axis) a.reshape(new_shape)
目的 在指定位置插入一个大小为 1 的新维度 将数组重塑为指定的形状
维度变化 增加 1 个维度 (或更多,如果 axis 是元组) 可以增加、减少或保持维度数不变
元素总数 保持不变 必须保持不变
新维度大小 总是 1 new_shape 决定,可以是任意整数 (>0)
灵活性 专门用于插入大小为 1 的维度 更通用,可以实现多种形状变换
常见应用 添加批次维度、通道维度、为广播准备 改变数组布局 (如扁平化、改变行列数)

当你的目标是精确地“在某个位置加一个大小为 1 的维度”时,expand_dims 更具表达力,且风险较低(不会意外地将数据重新排列成错误的形状)。

expand_dims 与切片操作中的 Nonenp.newaxis

在 NumPy 中,使用切片操作符结合 Nonenp.newaxis 是另一种在特定位置添加大小为 1 维度的常用方法。np.newaxis 实际上是 None 的别名,两者效果完全一样,使用 np.newaxis 通常被认为更具可读性。

语法是:arr[..., np.newaxis, ...],其中 ... 是省略号,表示保留原始的所有维度,np.newaxis 所在的位置就是新维度插入的位置。

示例:

“`python
arr = np.array([10, 20, 30]) # Shape (3,)

使用 expand_dims 在 axis=0 插入

arr_expanded_0 = np.expand_dims(arr, axis=0) # Shape (1, 3)

使用 np.newaxis 在 axis=0 位置 (逗号前) 插入

arr_newaxis_0 = arr[np.newaxis, :] # Shape (1, 3)

print(f”expand_dims(arr, 0) 形状: {arr_expanded_0.shape}”)
print(f”arr[np.newaxis, :] 形状: {arr_newaxis_0.shape}”)
print(f”两者是否相等: {np.array_equal(arr_expanded_0, arr_newaxis_0)}”) # 值相等

使用 expand_dims 在 axis=1 插入

arr_expanded_1 = np.expand_dims(arr, axis=1) # Shape (3, 1)

使用 np.newaxis 在 axis=1 位置 (逗号后) 插入

arr_newaxis_1 = arr[:, np.newaxis] # Shape (3, 1)

print(f”\nexpand_dims(arr, 1) 形状: {arr_expanded_1.shape}”)
print(f”arr[:, np.newaxis] 形状: {arr_newaxis_1.shape}”)
print(f”两者是否相等: {np.array_equal(arr_expanded_1, arr_newaxis_1)}”) # 值相等

arr_2d = np.array([[1, 2], [3, 4]]) # Shape (2, 2)

在 axis=1 插入

arr_expanded_1_2d = np.expand_dims(arr_2d, axis=1) # Shape (2, 1, 2)
arr_newaxis_1_2d = arr_2d[:, np.newaxis, :] # Shape (2, 1, 2)
print(f”\nexpand_dims(arr_2d, 1) 形状: {arr_expanded_1_2d.shape}”)
print(f”arr_2d[:, np.newaxis, :] 形状: {arr_newaxis_1_2d.shape}”)
print(f”两者是否相等: {np.array_equal(arr_expanded_1_2d, arr_newaxis_1_2d)}”)

在 axis=2 插入

arr_expanded_2_2d = np.expand_dims(arr_2d, axis=2) # Shape (2, 2, 1)
arr_newaxis_2_2d = arr_2d[:, :, np.newaxis] # Shape (2, 2, 1)
print(f”\nexpand_dims(arr_2d, 2) 形状: {arr_expanded_2_2d.shape}”)
print(f”arr_2d[:, :, np.newaxis] 形状: {arr_newaxis_2_2d.shape}”)
print(f”两者是否相等: {np.array_equal(arr_expanded_2_2d, arr_newaxis_2_2d)}”)

使用 np.newaxis 可以同时添加多个维度

arr_newaxis_multi = arr_2d[np.newaxis, :, :, np.newaxis] # Shape (1, 2, 2, 1)
print(f”\narr_2d[np.newaxis, :, :, np.newaxis] 形状: {arr_newaxis_multi.shape}”)

等价于连续调用 expand_dims

arr_expanded_multi = np.expand_dims(np.expand_dims(arr_2d, axis=0), axis=3) # 或者 axis=-1 (相对于第二次expand_dims的结果)
print(f”等效的连续 expand_dims 形状: {arr_expanded_multi.shape}”)
print(f”两者是否相等: {np.array_equal(arr_newaxis_multi, arr_expanded_multi)}”)
输出:
expand_dims(arr, 0) 形状: (1, 3)
arr[np.newaxis, :] 形状: (1, 3)
两者是否相等: True

expand_dims(arr, 1) 形状: (3, 1)
arr[:, np.newaxis] 形状: (3, 1)
两者是否相等: True

expand_dims(arr_2d, 1) 形状: (2, 1, 2)
arr_2d[:, np.newaxis, :] 形状: (2, 1, 2)
两者是否相等: True

expand_dims(arr_2d, 2) 形状: (2, 2, 1)
arr_2d[:, :, np.newaxis] 形状: (2, 2, 1)
两者是否相等: True

arr_2d[np.newaxis, :, :, np.newaxis] 形状: (1, 2, 2, 1)
等效的连续 expand_dims 形状: (1, 2, 2, 1)
两者是否相等: True
“`

总结比较 expand_dimsnp.newaxis

两者在功能上是等效的,都可以用于在指定位置插入大小为 1 的维度。选择使用哪个通常取决于个人偏好和上下文:

  • np.newaxis / None

    • 优点: 语法简洁紧凑,特别适合在进行切片或索引操作的同时添加维度。可以在一个操作中添加多个维度 (arr[None, :, None])。
    • 缺点: 语法可能不如函数调用直观,特别是对于初学者。当轴索引计算复杂时容易出错。
  • np.expand_dims

    • 优点: 函数调用形式,意图明确,可读性强。axis 参数明确指定了新维度的位置,尤其适合动态计算轴索引的场景。
    • 缺点: 一次函数调用只能添加一个维度(除非 axis 是元组,但元组 axis 的行为需要额外理解)。添加多个维度需要连续调用。

在许多情况下,arr[..., np.newaxis] 的切片语法更为流行和常见,因为它与 NumPy 的索引系统紧密结合。然而,np.expand_dims 作为独立的函数,在强调“增加维度”这一操作本身时,提供了更好的可读性,并且在编写更复杂的、涉及到动态轴操作的代码时可能更易于管理。例如,如果你有一个变量 target_axis 存储了需要在哪个轴插入维度,使用 np.expand_dims(arr, axis=target_axis) 比构建一个复杂的切片元组要简单得多。

潜在的陷阱与注意事项

  • 理解 axis 参数: 这是使用 expand_dims 最容易出错的地方。记住 axis=i 是在当前索引 i 之前插入,负数索引 -i 是在当前倒数第 i 个位置之前插入。绘制数组的维度图可以帮助理解。
  • axis 的合法范围: 确保 axis 的值在 [-ndim - 1, ndim] 的范围内。超出范围会导致 ValueError
  • 元组 axis 的行为: 如前所述,当 axis 是元组时,维度是按元组值升序插入的,但插入位置是相对于 上一步操作后的数组形状。这可能与想象中相对于原始数组轴的位置不同。通常建议避免使用元组 axis,除非你完全理解其工作原理,或者通过连续调用 expand_dims 来替代。
  • 视图 vs 副本: 尽管 expand_dims 通常返回视图,但在某些复杂情况下或当原始数组的内存布局不标准时,它可能会返回副本。如果对性能有严格要求,并且需要修改原始数据,应该注意这一点。可以使用 arr.flags['OWNDATA'] 来检查返回的数组是否拥有自己的数据。

总结

numpy.expand_dims 是 NumPy 中一个虽小但功能强大的函数,专门用于在数组的指定位置插入一个大小为 1 的新维度。这个操作在准备数据输入、实现广播、调整数组布局等方面非常有用。

理解 axis 参数是掌握 expand_dims 的关键。通过明确指定新维度插入的前一个索引位置,我们可以精确控制数组形状的变化。

虽然 expand_dims 的功能可以通过 reshape 或使用 np.newaxis 进行切片来实现,但 expand_dims 作为独立的函数,在表达“增加一个维度”这一特定意图时具有更高的可读性和简洁性,尤其是在 axis 索引是变量时。

掌握 expand_dims 以及其他 NumPy 数组形状操作函数,是高效进行数值计算和数据处理的基础。通过本文的详细介绍和示例,希望你能够更深入地理解 expand_dims,并在实际应用中灵活运用它。


发表评论

您的邮箱地址不会被公开。 必填项已用 * 标注

滚动至顶部