在 PyTorch 中,维度(dimension) 是描述张量形状的一种方式。维度操作是 PyTorch 中非常重要的功能,常用于调整张量的形状以适配各种计算需求。以下是常见的维度操作及其示例。
1. 维度的概念回顾
- 一个二维张量(矩阵)的形状是
(行数, 列数)
。 - 一个三维张量的形状是
(深度, 行数, 列数)
。 - 维度的索引从
0
开始,最外层是axis=0
,向内依次递增。
2. 维度的操作
(1) 求和(Sum)
sum(dim)
的作用是沿着指定的维度对张量求和,并 移除该维度。
示例:二维张量
import torchtensor = torch.tensor([[1, 2, 3],[4, 5, 6]])
print("原始张量:\n", tensor)
print("形状:", tensor.shape) # (2, 3)
-
沿维度 0 求和(
dim=0
):sum_axis0 = tensor.sum(dim=0) print("沿维度 0 求和:\n", sum_axis0)
输出:
沿维度 0 求和:tensor([5, 7, 9])
解释:
- 维度 0 是行方向。
- 对每一列的元素求和:
- 第 0 列:
1 + 4 = 5
- 第 1 列:
2 + 5 = 7
- 第 2 列:
3 + 6 = 9
- 第 0 列:
-
沿维度 1 求和(
dim=1
):sum_axis1 = tensor.sum(dim=1) print("沿维度 1 求和:\n", sum_axis1)
输出:
沿维度 1 求和:tensor([6, 15])
解释:
- 维度 1 是列方向。
- 对每一行的元素求和:
- 第 0 行:
1 + 2 + 3 = 6
- 第 1 行:
4 + 5 + 6 = 15
- 第 0 行:
示例:三维张量
tensor = torch.tensor([[[1, 2], [3, 4]],[[5, 6], [7, 8]]])
print("原始张量:\n", tensor)
print("形状:", tensor.shape) # (2, 2, 2)
-
沿维度 0 求和(
dim=0
):sum_axis0 = tensor.sum(dim=0) print("沿维度 0 求和:\n", sum_axis0)
输出:
沿维度 0 求和:tensor([[ 6, 8],[10, 12]])
解释:
- 维度 0 是最外层(矩阵的数量)。
- 对两个矩阵的对应位置元素求和:
[1, 2] + [5, 6] = [6, 8]
[3, 4] + [7, 8] = [10, 12]
(2) 增加维度(Unsqueeze)
unsqueeze(dim)
的作用是在指定维度上增加一个大小为 1 的维度。
示例:二维张量
tensor = torch.tensor([[1, 2, 3],[4, 5, 6]])
print("原始张量:\n", tensor)
print("形状:", tensor.shape) # (2, 3)
-
在维度 0 增加维度:
unsqueeze_axis0 = tensor.unsqueeze(0) print("在维度 0 增加维度:\n", unsqueeze_axis0) print("形状:", unsqueeze_axis0.shape) # (1, 2, 3)
输出:
在维度 0 增加维度:tensor([[[1, 2, 3],[4, 5, 6]]]) 形状: (1, 2, 3)
-
在维度 1 增加维度:
unsqueeze_axis1 = tensor.unsqueeze(1) print("在维度 1 增加维度:\n", unsqueeze_axis1) print("形状:", unsqueeze_axis1.shape) # (2, 1, 3)
输出:
在维度 1 增加维度:tensor([[[1, 2, 3]],[[4, 5, 6]]]) 形状: (2, 1, 3)
(3) 移除维度(Squeeze)
squeeze(dim)
的作用是移除指定维度上大小为 1 的维度。
示例:三维张量
tensor = torch.tensor([[[1, 2, 3]]])
print("原始张量:\n", tensor)
print("形状:", tensor.shape) # (1, 1, 3)
-
移除所有大小为 1 的维度:
squeeze_tensor = tensor.squeeze() print("移除所有大小为 1 的维度:\n", squeeze_tensor) print("形状:", squeeze_tensor.shape) # (3,)
输出:
移除所有大小为 1 的维度:tensor([1, 2, 3]) 形状: (3,)
-
移除指定维度:
squeeze_dim0 = tensor.squeeze(0) print("移除维度 0:\n", squeeze_dim0) print("形状:", squeeze_dim0.shape) # (1, 3)
输出:
移除维度 0:tensor([[1, 2, 3]]) 形状: (1, 3)
3. 总结
- 求和(
sum
):沿指定维度对元素求和,并移除该维度。 - 增加维度(
unsqueeze
):在指定维度上增加一个大小为 1 的维度。 - 移除维度(
squeeze
):移除指定维度上大小为 1 的维度。
通过这些操作,可以灵活调整张量的形状,使其适配各种计算需求!