深入理解 NumPy expand_dims
:为数组增加新维度
引言:NumPy 数组与维度的重要性
在科学计算、数据分析、机器学习等领域,NumPy 库是 Python 中不可或缺的工具。它的核心是多维数组(ndarray
),这种数据结构提供了高效的数值计算能力。理解 NumPy 数组的关键在于理解其“形状”(shape)和“维度”(dimensions,也称为 axes)。
数组的形状是一个元组,表示数组在每个维度上的大小。例如,一个形状为 (3,)
的数组是一个包含 3 个元素的一维数组(向量);一个形状为 (2, 3)
的数组是一个包含 2 行 3 列的二维数组(矩阵);一个形状为 (4, 2, 3)
的数组则是一个三维数组,可以想象成 4 个 2×3 的矩阵堆叠在一起。
数组的维度(axes)是从 0 开始编号的。对于一个 (d0, d1, d2, ..., dn)
形状的 n 维数组,维度 0 对应大小为 d0
的轴,维度 1 对应大小为 d1
的轴,以此类推,维度 n-1 对应大小为 dn
的轴。例如,在 (2, 3)
的二维数组中,维度 0 是行(大小为 2),维度 1 是列(大小为 3)。
在实际应用中,我们经常需要改变数组的形状或增加/减少其维度。这可能是出于以下原因:
- 满足函数或模型的输入要求: 许多库(如深度学习框架 TensorFlow 或 PyTorch)或函数对输入数据的维度有特定要求。例如,卷积神经网络通常期望输入是
(批次大小, 高度, 宽度, 通道数)
这样的四维数组,即使你只有一个单张图片(高度, 宽度, 通道数)
,也需要为其增加一个批次大小的维度,使其变为(1, 高度, 宽度, 通道数)
。 - 广播(Broadcasting): NumPy 的广播机制允许在形状不同的数组之间进行算术运算,前提是它们的形状满足一定的兼容性规则。有时,为了让两个数组能够成功广播,我们需要通过增加维度来调整其中一个数组的形状。
- 数据处理和转换: 在数据的预处理、特征工程等步骤中,可能需要调整数组的维度以便进行特定的操作。
NumPy 提供了多种改变数组形状的方法,例如 reshape
方法。然而,当我们的目标仅仅是“在特定位置增加一个大小为 1 的新维度”时,numpy.expand_dims
函数提供了一种更加直观和安全的方式。
本文将详细探讨 numpy.expand_dims
函数的功能、用法、参数以及它在不同场景下的应用,并与 reshape
和 np.newaxis
进行比较。
什么是 numpy.expand_dims
?
numpy.expand_dims(a, axis)
是 NumPy 提供的一个函数,它的主要作用是在数组 a
指定的 axis
位置插入一个新维度。这个新维度的长度总是 1。
函数签名:
python
numpy.expand_dims(a, axis)
参数解释:
a
: 输入的 NumPy 数组(ndarray
)。axis
: 一个整数或一个由整数组成的元组。指定新维度插入的位置。- 如果
axis
是一个整数,它表示新维度将被插入到该索引位置之前。例如,axis=0
会在最前面插入一个新维度,axis=1
会在原来的维度 0 之后、维度 1 之前插入新维度。 axis
的值可以为负数。负数索引的工作方式与 Python 列表类似,-1
表示倒数第一个维度之后(也就是最后面),-2
表示倒数第二个维度之前,以此类推。需要注意的是,axis
的取值范围是从-ndim - 1
到ndim
(其中ndim
是输入数组a
的原始维度数)。例如,对于一个 2 维数组(ndim=2),合法的axis
取值范围是-3
到2
。axis=-3
等同于axis=0
,axis=2
等同于axis=-1
。- 如果
axis
是一个元组,则会在元组指定的多个位置同时插入新的维度。元组中的值必须是唯一的,并且按照插入位置的顺序给出。NumPy 官方文档提到,当axis
是元组时,插入是按位置 从低到高 进行的。例如,对于一个形状为(2, 3)
的数组,axis=(0, 2)
意味着先在axis=0
处插入一个维度(变成(1, 2, 3)
),然后再在新的axis=2
处插入一个维度(变成(1, 2, 1, 3)
)。这与先在原数组的axis=0
插入再在原数组的axis=2
插入是不同的。实际操作时,axis=(0, 2)
会在原数组的axis=0
插入,然后在 结果数组 的axis=2
插入。这个行为在使用元组时需要特别注意。不过,在大多数常见场景下,我们通常只插入一个维度,此时axis
是一个整数。
- 如果
返回值:
- 返回一个新的 NumPy 数组,其维度比输入数组
a
多了len(axis)
个(如果axis
是元组)或 1 个(如果axis
是整数)。新插入维度的长度均为 1。
核心特性:
expand_dims
不会改变数组元素的总数。expand_dims
在指定位置插入一个长度为 1 的新轴。expand_dims
通常返回一个视图(view)而不是数据的副本(copy),这意味着操作是高效的,不会额外消耗大量内存,除非原始数组不是 C-contiguous 或 Fortran-contiguous。
expand_dims
的工作原理及 axis
参数详解
理解 expand_dims
的关键在于理解 axis
参数如何指定新维度的位置。想象一下数组的维度就像是嵌套的括号层级。
- 一维数组
[1, 2, 3]
,形状(3,)
。只有一个维度,索引为 0。 - 二维数组
[[1, 2], [3, 4]]
,形状(2, 2)
。维度 0 是行(外层括号),维度 1 是列(内层括号)。
expand_dims(a, axis)
会在原始数组 a
的 axis
索引 之前 插入一个新的维度。
示例 1:一维数组增加维度
考虑一个一维数组 arr = np.array([1, 2, 3])
,其形状为 (3,)
,维度索引为 0。
-
在
axis=0
处插入:
“`python
import numpy as nparr = np.array([1, 2, 3])
print(“原数组:”, arr)
print(“原形状:”, arr.shape)
print(“原维度:”, arr.ndim)在 axis=0 (最前面) 插入新维度
arr_expanded_0 = np.expand_dims(arr, axis=0)
print(“\nexpand_dims(arr, axis=0):”)
print(“新数组:”, arr_expanded_0)
print(“新形状:”, arr_expanded_0.shape)
print(“新维度:”, arr_expanded_0.ndim)
输出:
原数组: [1 2 3]
原形状: (3,)
原维度: 1expand_dims(arr, axis=0):
新数组: [[1 2 3]]
新形状: (1, 3)
新维度: 2
``
[1, 2, 3]
原数组只有一个维度(索引 0,大小 3)。
axis=0表示在索引 0 之前插入。插入后,原来的维度 0 变成了新的维度 1,而新的维度 0 被插入并具有大小 1。形状从
(3,)变为
(1, 3)。这相当于将一个行向量
[1, 2, 3]变成一个 1x3 的矩阵
[[1, 2, 3]]`。 -
在
axis=1
处插入:
一个一维数组只有axis=0
,所以axis=1
是无效的,因为它超过了原始维度数(ndim=1)。但是,根据规则,axis
的范围是-ndim - 1
到ndim
。对于 ndim=1,范围是-2
到1
。axis=1
实际上是合法的,它等同于axis=-1
,表示在最后一个维度之后插入。
“`python
# 在 axis=1 (最后面) 插入新维度
arr_expanded_1 = np.expand_dims(arr, axis=1)
print(“\nexpand_dims(arr, axis=1):”)
print(“新数组:”, arr_expanded_1)
print(“新形状:”, arr_expanded_1.shape)
print(“新维度:”, arr_expanded_1.ndim)使用负数索引 axis=-1 (最后面)
arr_expanded_neg1 = np.expand_dims(arr, axis=-1)
print(“\nexpand_dims(arr, axis=-1):”)
print(“新数组:”, arr_expanded_neg1)
print(“新形状:”, arr_expanded_neg1.shape)
print(“新维度:”, arr_expanded_neg1.ndim)
输出:
expand_dims(arr, axis=1):
新数组: [[1]
[2]
[3]]
新形状: (3, 1)
新维度: 2expand_dims(arr, axis=-1):
新数组: [[1]
[2]
[3]]
新形状: (3, 1)
新维度: 2
``
axis=1和
axis=-1在这里都表示在原始数组的 *最后一个* 维度(即索引 0)之后插入新维度。原来的维度 0 仍然是新的维度 0,而新的维度 1 被插入并具有大小 1。形状从
(3,)变为
(3, 1)。这相当于将一个行向量
[1, 2, 3]变成一个 3x1 的列向量
[[1], [2], [3]]`。
示例 2:二维数组增加维度
考虑一个二维数组 arr_2d = np.array([[1, 2], [3, 4]])
,其形状为 (2, 2)
,维度索引为 0 和 1。
-
在
axis=0
处插入:
“`python
arr_2d = np.array([[1, 2], [3, 4]])
print(“原数组 (2D):”)
print(arr_2d)
print(“原形状:”, arr_2d.shape) # (2, 2)在 axis=0 (最前面) 插入新维度
arr_expanded_0_2d = np.expand_dims(arr_2d, axis=0)
print(“\nexpand_dims(arr_2d, axis=0):”)
print(arr_expanded_0_2d)
print(“新形状:”, arr_expanded_0_2d.shape) # (1, 2, 2)
print(“新维度:”, arr_expanded_0_2d.ndim)
输出:
原数组 (2D):
[[1 2]
[3 4]]
原形状: (2, 2)expand_dims(arr_2d, axis=0):
[[[1 2]
[3 4]]]
新形状: (1, 2, 2)
新维度: 3
``
(2, 2)
原始形状,维度索引 0 和 1。
axis=0在最前面插入。新维度 0 大小为 1,原来的维度 0 (大小 2) 变成新的维度 1,原来的维度 1 (大小 2) 变成新的维度 2。形状变为
(1, 2, 2)`。 -
在
axis=1
处插入:
python
# 在 axis=1 插入新维度
arr_expanded_1_2d = np.expand_dims(arr_2d, axis=1)
print("\nexpand_dims(arr_2d, axis=1):")
print(arr_expanded_1_2d)
print("新形状:", arr_expanded_1_2d.shape) # (2, 1, 2)
print("新维度:", arr_expanded_1_2d.ndim)
输出:
“`
expand_dims(arr_2d, axis=1):
[[[1 2]][[3 4]]]
新形状: (2, 1, 2)
新维度: 3
``
(2, 2)
原始形状,维度索引 0 和 1。
axis=1在索引 1 之前插入。原来的维度 0 (大小 2) 保持为新的维度 0。新的维度 1 被插入并具有大小 1。原来的维度 1 (大小 2) 变成新的维度 2。形状变为
(2, 1, 2)`。 -
在
axis=2
处插入(等同于axis=-1
):
对于 2D 数组,合法axis
范围是-3
到2
。axis=2
是合法的,它表示在最后一个维度之后插入。
“`python
# 在 axis=2 (最后面) 插入新维度
arr_expanded_2_2d = np.expand_dims(arr_2d, axis=2)
print(“\nexpand_dims(arr_2d, axis=2):”)
print(arr_expanded_2_2d)
print(“新形状:”, arr_expanded_2_2d.shape) # (2, 2, 1)
print(“新维度:”, arr_expanded_2_2d.ndim)使用负数索引 axis=-1 (最后面)
arr_expanded_neg1_2d = np.expand_dims(arr_2d, axis=-1)
print(“\nexpand_dims(arr_2d, axis=-1):”)
print(arr_expanded_neg1_2d)
print(“新形状:”, arr_expanded_neg1_2d.shape) # (2, 2, 1)
print(“新维度:”, arr_expanded_neg1_2d.ndim)
输出:
expand_dims(arr_2d, axis=2):
[[[1] [2]][[3] [4]]]
新形状: (2, 2, 1)
新维度: 3expand_dims(arr_2d, axis=-1):
[[[1] [2]][[3] [4]]]
新形状: (2, 2, 1)
新维度: 3
``
(2, 2)
原始形状,维度索引 0 和 1。
axis=2或
axis=-1在索引 1 之后插入。原来的维度 0 (大小 2) 保持为新的维度 0。原来的维度 1 (大小 2) 保持为新的维度 1。新的维度 2 被插入并具有大小 1。形状变为
(2, 2, 1)`。
总结 axis
参数的行为:
对于形状为 (d0, d1, ..., dk, ..., dn-1)
的 n 维数组 a
:
np.expand_dims(a, axis=i)
(当i >= 0
) 会生成一个形状为(d0, d1, ..., di-1, 1, di, ..., dn-1)
的 n+1 维数组。新的维度被插入到索引i
处。np.expand_dims(a, axis=-i)
(当i > 0
) 会生成一个形状为(d0, d1, ..., dn-i-1, 1, dn-i, ..., dn-1)
的 n+1 维数组。新的维度被插入到距离末尾第i
个维度之前。axis=-1
插入在最后面,axis=-2
插入在倒数第二个位置之前,以此类推。- 合法的
axis
整数值范围是[-ndim - 1, ndim]
。在这个范围内,axis=i
等同于axis=i - ndim - 1
当i >= 0
,以及axis=i + ndim + 1
当i < -ndim - 1
(虽然-ndim - 1
是最小值,所以第二个等式只适用于大于-ndim - 1
的负数)。简单的记法是,正数i
插入在索引i
之前,负数-i
插入在距离末尾第i
个位置之前。
使用元组作为 axis
:
如前所述,使用元组可以在多个位置插入维度。插入的顺序是按照元组中 axis
值从小到大进行,但插入位置是相对于 当前 形状而言的。
“`python
arr_2d = np.array([[1, 2], [3, 4]]) # Shape (2, 2)
在 axis=0 和 axis=2 (相对于原始数组) 插入
步骤1: 在 axis=0 插入 (shape (1, 2, 2))
步骤2: 在新的数组 (1, 2, 2) 的 axis=2 插入
arr_expanded_tuple = np.expand_dims(arr_2d, axis=(0, 2))
print(“\nexpand_dims(arr_2d, axis=(0, 2)):”)
print(arr_expanded_tuple)
print(“新形状:”, arr_expanded_tuple.shape) # (1, 2, 1, 2)
print(“新维度:”, arr_expanded_tuple.ndim)
输出:
expand_dims(arr_2d, axis=(0, 2)):
[[[[1 2]]
[[3 4]]]]
新形状: (1, 2, 1, 2)
新维度: 4
``
(2, 2)
原始形状。
axis=0
1. 先在插入:结果形状
(1, 2, 2)。
(1, 2, 2)
2. 然后,在新的形状的
axis=2位置插入:新维度插入在索引 2 之前。原来的维度 0 (大小 1) 保持为新的维度 0。原来的维度 1 (大小 2) 保持为新的维度 1。新的维度 2 被插入 (大小 1)。原来的维度 2 (大小 2) 变成新的维度 3。最终形状
(1, 2, 1, 2)`。
这种使用元组的情况相对少见,且行为可能有些反直觉(相对于原始数组轴的相对位置)。更常见的是一次只通过整数 axis
参数增加一个维度。如果需要增加多个维度,可以连续调用 expand_dims
或使用其他方法。
expand_dims
的常见应用场景
expand_dims
函数的简洁性使其在多种常见场景下成为首选工具。
-
为深度学习模型准备输入数据:
许多深度学习模型(如使用 TensorFlow 或 PyTorch 构建的卷积神经网络 CNN)期望输入数据包含一个“批次大小”(batch size)的维度作为第一个维度。即使你只想处理一张图片或一个数据样本,也需要将其包装在一个大小为 1 的批次中。-
图片数据: 单张彩色图片的数据通常是
(高度, 宽度, 通道数)
的三维数组。模型可能期望(批次大小, 高度, 宽度, 通道数)
。
“`python
import numpy as np模拟一张 28×28 的灰度图片数据
img_gray = np.random.rand(28, 28) # Shape (28, 28)
模拟一张 28×28 的彩色图片数据
img_color = np.random.rand(28, 28, 3) # Shape (28, 28, 3)
为灰度图片增加批次维度和通道维度 (如果模型期望 (batch, height, width, channel))
先增加批次维度 (axis=0)
img_gray_batch = np.expand_dims(img_gray, axis=0) # Shape (1, 28, 28)
再增加通道维度 (axis=3 或 axis=-1)
img_gray_batch_channel = np.expand_dims(img_gray_batch, axis=-1) # Shape (1, 28, 28, 1)
print(f”灰度图转模型输入形状: {img_gray_batch_channel.shape}”)为彩色图片增加批次维度 (axis=0)
img_color_batch = np.expand_dims(img_color, axis=0) # Shape (1, 28, 28, 3)
print(f”彩图转模型输入形状: {img_color_batch.shape}”)
输出:
灰度图转模型输入形状: (1, 28, 28, 1)
彩图转模型输入形状: (1, 28, 28, 3)
``
expand_dims
这里我们使用了简洁地添加了批次维度。对于灰度图转
(batch, height, width, channel)形状,我们分了两步,先加批次维度,再加通道维度。使用
expand_dims` 更清晰地表达了“增加一个维度”的意图。 -
序列数据: 循环神经网络 RNN 或 Transformer 模型处理序列数据时,通常期望输入是
(批次大小, 序列长度, 特征数)
。单个序列数据(序列长度, 特征数)
需要增加批次维度。
“`python
# 模拟一个序列数据,序列长度为 50,每个时间步有 10 个特征
sequence_data = np.random.rand(50, 10) # Shape (50, 10)为序列数据增加批次维度 (axis=0)
sequence_batch = np.expand_dims(sequence_data, axis=0) # Shape (1, 50, 10)
print(f”序列数据转模型输入形状: {sequence_batch.shape}”)
输出:
序列数据转模型输入形状: (1, 50, 10)
“`
-
-
实现特定类型的广播操作:
广播是 NumPy 的强大特性,但有时需要手动调整数组形状才能满足广播条件。增加一个大小为 1 的维度是实现广播兼容性的常见手段。NumPy 的广播规则中重要的一条是:当两个数组的维度数不同时,维度数较小的数组会在其前部(左侧)填充大小为 1 的维度,直到两个数组维度数相同。然后比较对应维度的长度,如果相等或其中一个长度为 1,则兼容。考虑两个一维数组
a = np.array([10, 20, 30])
(形状(3,)
) 和b = np.array([1, 2, 3])
(形状(3,)
)。它们的乘积a * b
是元素级的[10, 40, 90]
(形状(3,)
)。如果我们想计算它们的“外积”(outer product),即得到一个 3×3 的矩阵,其中元素
M[i, j] = a[i] * b[j]
。这可以通过广播实现,但需要将一个数组变为列向量形状(3, 1)
,另一个变为行向量形状(1, 3)
。“`python
a = np.array([10, 20, 30]) # Shape (3,)
b = np.array([1, 2, 3]) # Shape (3,)将 a 变为列向量形状 (3, 1)
a_col = np.expand_dims(a, axis=1) # Shape (3, 1)
print(f”a 变为列向量形状: {a_col.shape}”)将 b 变为行向量形状 (1, 3)
b_row = np.expand_dims(b, axis=0) # Shape (1, 3)
print(f”b 变为行向量形状: {b_row.shape}”)计算外积: (3, 1) 和 (1, 3) 进行广播
outer_product = a_col * b_row
print(“\n外积结果:”)
print(outer_product)
print(“外积形状:”, outer_product.shape) # Shape (3, 3)
输出:
a 变为列向量形状: (3, 1)
b 变为行向量形状: (1, 3)外积结果:
[[10 20 30]
[20 40 60]
[30 60 90]]
外积形状: (3, 3)
``
expand_dims
在这个例子中,清晰地表达了我们的意图:将
a转换为带有新列维度的数组,将
b` 转换为带有新行维度的数组,从而实现广播。 -
数据可视化库的输入要求:
一些绘图库可能对输入数据的形状有特定要求。例如,一个函数可能期望输入是一个二维数组(样本数, 特征数)
,而你的原始数据是每个样本一个一维数组列表。你可以将每个一维数组通过expand_dims(sample, axis=0)
转换为(1, 特征数)
,然后使用np.concatenate
将它们堆叠起来形成(样本数, 特征数)
的二维数组。
expand_dims
与 reshape
的区别
reshape
是另一个常用的改变数组形状的方法,但它与 expand_dims
有本质区别。
expand_dims
: 只 在指定位置插入一个大小为 1 的新维度。它不改变原始数据的排列方式(内存布局通常保持一致, unless strides change significantly, yielding a view)。它使得数组的维度数量增加 1(或更多,如果axis
是元组)。新维度的长度总是 1。reshape
: 改变数组的形状,但 不改变 数组元素的总数。它可以增加或减少维度,或者改变现有维度的大小,只要新形状的总元素数与原始数组的总元素数相等。reshape
可以返回一个视图或一个副本,取决于新的形状是否与原始数组的内存布局兼容。
示例:
“`python
arr = np.array([1, 2, 3, 4, 5, 6]) # Shape (6,)
使用 expand_dims
arr_expanded = np.expand_dims(arr, axis=0) # Shape (1, 6)
print(f”expand_dims 结果形状: {arr_expanded.shape}”)
使用 reshape
arr_reshaped_row = arr.reshape(1, 6) # Shape (1, 6)
print(f”reshape(1, 6) 结果形状: {arr_reshaped_row.shape}”)
arr_reshaped_col = arr.reshape(6, 1) # Shape (6, 1)
print(f”reshape(6, 1) 结果形状: {arr_reshaped_col.shape}”)
reshape 可以完全改变形状,只要元素总数不变
arr_reshaped_2x3 = arr.reshape(2, 3) # Shape (2, 3)
print(f”reshape(2, 3) 结果形状: {arr_reshaped_2x3.shape}”)
expand_dims 只能增加大小为 1 的维度,不能将 (6,) 变为 (2, 3)
np.expand_dims(arr, axis=?) 只能得到 (1, 6) 或 (6, 1)
“`
总结区别:
特性 | numpy.expand_dims(a, axis) |
a.reshape(new_shape) |
---|---|---|
目的 | 在指定位置插入一个大小为 1 的新维度 | 将数组重塑为指定的形状 |
维度变化 | 增加 1 个维度 (或更多,如果 axis 是元组) |
可以增加、减少或保持维度数不变 |
元素总数 | 保持不变 | 必须保持不变 |
新维度大小 | 总是 1 | 由 new_shape 决定,可以是任意整数 (>0) |
灵活性 | 专门用于插入大小为 1 的维度 | 更通用,可以实现多种形状变换 |
常见应用 | 添加批次维度、通道维度、为广播准备 | 改变数组布局 (如扁平化、改变行列数) |
当你的目标是精确地“在某个位置加一个大小为 1 的维度”时,expand_dims
更具表达力,且风险较低(不会意外地将数据重新排列成错误的形状)。
expand_dims
与切片操作中的 None
或 np.newaxis
在 NumPy 中,使用切片操作符结合 None
或 np.newaxis
是另一种在特定位置添加大小为 1 维度的常用方法。np.newaxis
实际上是 None
的别名,两者效果完全一样,使用 np.newaxis
通常被认为更具可读性。
语法是:arr[..., np.newaxis, ...]
,其中 ...
是省略号,表示保留原始的所有维度,np.newaxis
所在的位置就是新维度插入的位置。
示例:
“`python
arr = np.array([10, 20, 30]) # Shape (3,)
使用 expand_dims 在 axis=0 插入
arr_expanded_0 = np.expand_dims(arr, axis=0) # Shape (1, 3)
使用 np.newaxis 在 axis=0 位置 (逗号前) 插入
arr_newaxis_0 = arr[np.newaxis, :] # Shape (1, 3)
print(f”expand_dims(arr, 0) 形状: {arr_expanded_0.shape}”)
print(f”arr[np.newaxis, :] 形状: {arr_newaxis_0.shape}”)
print(f”两者是否相等: {np.array_equal(arr_expanded_0, arr_newaxis_0)}”) # 值相等
使用 expand_dims 在 axis=1 插入
arr_expanded_1 = np.expand_dims(arr, axis=1) # Shape (3, 1)
使用 np.newaxis 在 axis=1 位置 (逗号后) 插入
arr_newaxis_1 = arr[:, np.newaxis] # Shape (3, 1)
print(f”\nexpand_dims(arr, 1) 形状: {arr_expanded_1.shape}”)
print(f”arr[:, np.newaxis] 形状: {arr_newaxis_1.shape}”)
print(f”两者是否相等: {np.array_equal(arr_expanded_1, arr_newaxis_1)}”) # 值相等
arr_2d = np.array([[1, 2], [3, 4]]) # Shape (2, 2)
在 axis=1 插入
arr_expanded_1_2d = np.expand_dims(arr_2d, axis=1) # Shape (2, 1, 2)
arr_newaxis_1_2d = arr_2d[:, np.newaxis, :] # Shape (2, 1, 2)
print(f”\nexpand_dims(arr_2d, 1) 形状: {arr_expanded_1_2d.shape}”)
print(f”arr_2d[:, np.newaxis, :] 形状: {arr_newaxis_1_2d.shape}”)
print(f”两者是否相等: {np.array_equal(arr_expanded_1_2d, arr_newaxis_1_2d)}”)
在 axis=2 插入
arr_expanded_2_2d = np.expand_dims(arr_2d, axis=2) # Shape (2, 2, 1)
arr_newaxis_2_2d = arr_2d[:, :, np.newaxis] # Shape (2, 2, 1)
print(f”\nexpand_dims(arr_2d, 2) 形状: {arr_expanded_2_2d.shape}”)
print(f”arr_2d[:, :, np.newaxis] 形状: {arr_newaxis_2_2d.shape}”)
print(f”两者是否相等: {np.array_equal(arr_expanded_2_2d, arr_newaxis_2_2d)}”)
使用 np.newaxis 可以同时添加多个维度
arr_newaxis_multi = arr_2d[np.newaxis, :, :, np.newaxis] # Shape (1, 2, 2, 1)
print(f”\narr_2d[np.newaxis, :, :, np.newaxis] 形状: {arr_newaxis_multi.shape}”)
等价于连续调用 expand_dims
arr_expanded_multi = np.expand_dims(np.expand_dims(arr_2d, axis=0), axis=3) # 或者 axis=-1 (相对于第二次expand_dims的结果)
print(f”等效的连续 expand_dims 形状: {arr_expanded_multi.shape}”)
print(f”两者是否相等: {np.array_equal(arr_newaxis_multi, arr_expanded_multi)}”)
输出:
expand_dims(arr, 0) 形状: (1, 3)
arr[np.newaxis, :] 形状: (1, 3)
两者是否相等: True
expand_dims(arr, 1) 形状: (3, 1)
arr[:, np.newaxis] 形状: (3, 1)
两者是否相等: True
expand_dims(arr_2d, 1) 形状: (2, 1, 2)
arr_2d[:, np.newaxis, :] 形状: (2, 1, 2)
两者是否相等: True
expand_dims(arr_2d, 2) 形状: (2, 2, 1)
arr_2d[:, :, np.newaxis] 形状: (2, 2, 1)
两者是否相等: True
arr_2d[np.newaxis, :, :, np.newaxis] 形状: (1, 2, 2, 1)
等效的连续 expand_dims 形状: (1, 2, 2, 1)
两者是否相等: True
“`
总结比较 expand_dims
和 np.newaxis
:
两者在功能上是等效的,都可以用于在指定位置插入大小为 1 的维度。选择使用哪个通常取决于个人偏好和上下文:
-
np.newaxis
/None
:- 优点: 语法简洁紧凑,特别适合在进行切片或索引操作的同时添加维度。可以在一个操作中添加多个维度 (
arr[None, :, None]
)。 - 缺点: 语法可能不如函数调用直观,特别是对于初学者。当轴索引计算复杂时容易出错。
- 优点: 语法简洁紧凑,特别适合在进行切片或索引操作的同时添加维度。可以在一个操作中添加多个维度 (
-
np.expand_dims
:- 优点: 函数调用形式,意图明确,可读性强。
axis
参数明确指定了新维度的位置,尤其适合动态计算轴索引的场景。 - 缺点: 一次函数调用只能添加一个维度(除非
axis
是元组,但元组axis
的行为需要额外理解)。添加多个维度需要连续调用。
- 优点: 函数调用形式,意图明确,可读性强。
在许多情况下,arr[..., np.newaxis]
的切片语法更为流行和常见,因为它与 NumPy 的索引系统紧密结合。然而,np.expand_dims
作为独立的函数,在强调“增加维度”这一操作本身时,提供了更好的可读性,并且在编写更复杂的、涉及到动态轴操作的代码时可能更易于管理。例如,如果你有一个变量 target_axis
存储了需要在哪个轴插入维度,使用 np.expand_dims(arr, axis=target_axis)
比构建一个复杂的切片元组要简单得多。
潜在的陷阱与注意事项
- 理解
axis
参数: 这是使用expand_dims
最容易出错的地方。记住axis=i
是在当前索引i
之前插入,负数索引-i
是在当前倒数第i
个位置之前插入。绘制数组的维度图可以帮助理解。 axis
的合法范围: 确保axis
的值在[-ndim - 1, ndim]
的范围内。超出范围会导致ValueError
。- 元组
axis
的行为: 如前所述,当axis
是元组时,维度是按元组值升序插入的,但插入位置是相对于 上一步操作后的数组形状。这可能与想象中相对于原始数组轴的位置不同。通常建议避免使用元组axis
,除非你完全理解其工作原理,或者通过连续调用expand_dims
来替代。 - 视图 vs 副本: 尽管
expand_dims
通常返回视图,但在某些复杂情况下或当原始数组的内存布局不标准时,它可能会返回副本。如果对性能有严格要求,并且需要修改原始数据,应该注意这一点。可以使用arr.flags['OWNDATA']
来检查返回的数组是否拥有自己的数据。
总结
numpy.expand_dims
是 NumPy 中一个虽小但功能强大的函数,专门用于在数组的指定位置插入一个大小为 1 的新维度。这个操作在准备数据输入、实现广播、调整数组布局等方面非常有用。
理解 axis
参数是掌握 expand_dims
的关键。通过明确指定新维度插入的前一个索引位置,我们可以精确控制数组形状的变化。
虽然 expand_dims
的功能可以通过 reshape
或使用 np.newaxis
进行切片来实现,但 expand_dims
作为独立的函数,在表达“增加一个维度”这一特定意图时具有更高的可读性和简洁性,尤其是在 axis
索引是变量时。
掌握 expand_dims
以及其他 NumPy 数组形状操作函数,是高效进行数值计算和数据处理的基础。通过本文的详细介绍和示例,希望你能够更深入地理解 expand_dims
,并在实际应用中灵活运用它。