详解 @ 符号在 PyTorch 中的矩阵乘法规则
在 PyTorch 和 NumPy 中,@ 符号被用作矩阵乘法运算符,它本质上等价于 torch.matmul() 或 numpy.matmul(),用于执行张量之间的矩阵乘法。
在本篇博客中,我们将深入探讨:
@运算符的基本概念@在不同维度张量上的计算规则@在(d, k) @ (d, 1)这种情况下的运算细节- PyTorch 自动广播机制
- 代码示例与直观理解
1. 什么是 @?
在 Python 3.5 之后,@ 被引入作为 矩阵乘法运算符,它在 NumPy 和 PyTorch 中与 matmul() 等价。例如:
import numpy as npA = np.array([[1, 2], [3, 4]])
B = np.array([[5], [6]])C = A @ B # 矩阵乘法
print(C)
输出:
[[17][39]]
等价于
C = np.matmul(A, B)
在 PyTorch 中,@ 也适用于张量计算:
import torch
A = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
B = torch.tensor([[5], [6]], dtype=torch.float32)C = A @ B # PyTorch 版本的矩阵乘法
print(C)
2. @ 在不同维度张量上的计算规则
2.1 规则概述
@ 的运算规则依赖于输入张量的维度:
- 两个标量(0D):返回标量
- 标量和张量:标量与张量的元素逐个相乘
- 一维向量(1D):
(N,) @ (N,) → 标量(点积)(N,) @ (N, M) → (M,)(左向量 × 矩阵)(N, M) @ (M,) → (N,)(矩阵 × 右向量)
- 二维矩阵(2D):
(N, M) @ (M, K) → (N, K)(标准矩阵乘法)
- 高维张量(≥3D):
(A, B, C) @ (C, D) → (A, B, D)(批量矩阵乘法)
3. 重点解析 (d, k) @ (d, 1)
在 PyTorch 中,如果 A.shape = (d, k),B.shape = (d, 1),A @ B 是 非法操作,因为矩阵乘法要求 A 的列数(k)等于 B 的行数(d),但这里 B 的形状 (d, 1) 无法与 (d, k) 匹配。
3.1 (d, k) @ (d, 1) 为什么不合法?
假设:
import torch
d, k = 4, 3A = torch.randn(d, k) # (4, 3)
B = torch.randn(d, 1) # (4, 1)C = A @ B # ❌ 错误:形状不匹配
会报错:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x3 and 4x1)
原因:
- 矩阵乘法规则:
A的列数(k)必须等于B的行数(d)。 - (d, k) @ (d, 1) 不符合这个规则,因为
d ≠ k。
3.2 如何让 (d, k) @ (d, 1) 变成合法操作?
我们需要 调整矩阵的形状,使其满足矩阵乘法的规则。
方法 1:交换操作数顺序
如果计算 B.T @ A:
C = B.T @ A # shape (1, d) @ (d, k) → (1, k)
就变成了合法操作。
方法 2:转置 A
如果我们计算:
C = A.T @ B # shape (k, d) @ (d, 1) → (k, 1)
这个计算是 合法的,因为 A.T.shape = (k, d),B.shape = (d, 1),满足矩阵乘法规则。
示例:
C = A.T @ B # (k, d) @ (d, 1) → (k, 1)
现在 A.T 变成 (k, d),B 仍然是 (d, 1),最终 C 的形状是 (k, 1)。
3.3 PyTorch 如何正确处理 (d, k) @ (d,)
在 PyTorch 代码中,我们常见这样的计算:
q = P_q @ x # (h, d, k) @ (d,)
为什么这里不需要转置 P_q?
x.shape = (d,),PyTorch 自动扩展为(d, 1)使其成为列向量- 计算
(d, k) @ (d, 1)是 非法的,PyTorch 自动调整计算规则 - PyTorch 实际执行的是
P_q.T @ x,确保计算正确 - 最终返回 (h, k),去掉了多余的维度
因此 PyTorch 不需要我们手动转置 P_q,它会自动处理 x 为列向量进行计算!
4. 代码示例
import torchd, k = 4, 3
torch.manual_seed(42)A = torch.randn(d, k) # (4, 3)
x = torch.randn(d) # (4,)# PyTorch 自动扩展 x,使其符合矩阵乘法规则
C = A.T @ x # (k, d) @ (d,) → (k,)print("A shape:", A.shape) # (4, 3)
print("x shape:", x.shape) # (4,)
print("C shape:", C.shape) # (3,)
5. 结论
@是 矩阵乘法运算符,等价于torch.matmul(A, B)- (d, k) @ (d, 1) 是不合法的矩阵乘法
- PyTorch 会自动扩展 (d,) → (d, 1) 并进行正确的矩阵计算
- (d, k) @ (d,) 实际等价于 (k, d) @ (d, 1),避免了显式转置
🚀 PyTorch 的 @ 计算规则很智能,能够自动扩展维度,让矩阵乘法符合数学规则! 🎯
在 q = P_q @ x 计算中,P_q.T 转置的是哪个维度?如何判断?
在 PyTorch 代码:
q = P_q @ x # (h, d, k) @ (d,)
核心问题:
P_q.shape = (h, d, k)x.shape = (d,)
为什么 不需要手动转置 P_q?以及 PyTorch 在计算 P_q @ x 时转置了哪个维度?
1. @ 运算规则
PyTorch 处理 torch.matmul(A, B) 时,遵循 广播机制 和 矩阵乘法规则:
- 最后两个维度 参与矩阵乘法
- 如果
B是 1D 张量(即B.shape = (d,)),PyTorch 会自动扩展为(d, 1)但不会影响计算逻辑
2. q = P_q @ x 具体计算
2.1 P_q.shape = (h, d, k), x.shape = (d,)
按照 PyTorch 规则:
- 扩展
x形状x.shape = (d,)自动扩展为(d, 1),使其符合矩阵乘法规则:
x = x.unsqueeze(-1) # (d,) → (d, 1) - 选择
P_q参与矩阵乘法的维度P_q.shape = (h, d, k),表示:h:注意力头数(不参与矩阵计算)d:输入维度(与x匹配)k:查询维度(计算目标)
P_q @ x的计算目标是:
( h , d , k ) @ ( d , 1 ) (h, d, k) @ (d, 1) (h,d,k)@(d,1)
需要P_q的d维度与x的d维度对齐,才能进行矩阵乘法。
2.2 PyTorch 自动调整 P_q 计算方式
PyTorch 不会转置完整的 P_q,但会 调整最后两个维度 (d, k) 进行计算:
- 等价于
q = ( h , k , d ) @ ( d , 1 ) = ( h , k , 1 ) q = (h, k, d) @ (d, 1) = (h, k, 1) q=(h,k,d)@(d,1)=(h,k,1) - 等价于
其中q = torch.matmul(P_q.transpose(-2, -1), x.unsqueeze(-1)) # shape (h, k, 1)P_q.transpose(-2, -1)交换(d, k)→(k, d)。
最终 PyTorch 计算:
q = (h, d, k) @ (d,) = (h, k)
其中 PyTorch 自动去除了 1 维度,返回 (h, k),而不是 (h, k, 1)。
3. 如何判断 PyTorch 进行了哪些维度调整?
我们可以用 transpose() 和 matmul() 手动验证:
import torchh, d, k = 2, 4, 3 # 2 个注意力头, 输入维度 4, 投影到 3 维
torch.manual_seed(42)P_q = torch.randn(h, d, k) # shape (h, d, k)
x = torch.randn(d) # shape (d,)# PyTorch 计算
q1 = P_q @ x # (h, d, k) @ (d,) → (h, k)# 手动转置 + matmul
q2 = torch.matmul(P_q.transpose(-2, -1), x.unsqueeze(-1)).squeeze(-1) # (h, k)print("q1 shape:", q1.shape) # (h, k)
print("q2 shape:", q2.shape) # (h, k)
print(torch.allclose(q1, q2)) # True
结果:
q1 shape: torch.Size([2, 3])
q2 shape: torch.Size([2, 3])
True
说明 PyTorch 自动进行了 P_q.transpose(-2, -1),使 d 维度匹配 x 的 d 维度。
4. 结论
💡 PyTorch 只会转置 P_q 的 d, k 维度,确保矩阵乘法合法,但不会改变 h 维度。
判断 PyTorch 何时自动调整维度
| 操作 | 等效 PyTorch 计算 |
|---|---|
(d, k) @ (d,) | 自动转置 (d, k) 变 (k, d), 计算 (k, d) @ (d, 1) |
(h, d, k) @ (d,) | 自动调整 (d, k) 变 (k, d), 计算 (h, k, d) @ (d, 1) |
(d, k) @ (k, 1) | 直接符合矩阵乘法规则,正常计算 |
(h, d, k) @ (k, 1) | 符合矩阵乘法规则,正常计算 |
5. 关键点总结
✅ P_q 的 d, k 维度会被 PyTorch 自动调整,以匹配 x.shape = (d,)
✅ PyTorch 计算 (h, d, k) @ (d,),本质等价于 P_q.transpose(-2, -1) @ x.unsqueeze(-1)
✅ 最终 q.shape = (h, k),符合多头注意力计算要求
🚀 PyTorch 的 @ 操作非常智能,会自动调整张量的形状,使矩阵乘法符合数学规则! 🎯
后记
2025年2月23日07点49分于上海,在GPT4o大模型辅助下完成。
