The Attention Class in AlphaFold3
`Attention` is a multi-head attention module that performs the standard multi-head attention computation while supporting AlphaFold3-specific layer initialization and other features, such as query-gated output and multiple bias vectors.

Source code:
```python
import torch.nn as nn

# Linear and LinearNoBias are the project's custom linear layers; they behave like
# nn.Linear but take an `init` argument selecting an AlphaFold-style initialization
# scheme ("glorot", "final", "gating", "default", ...).


class Attention(nn.Module):
    """Standard multi-head attention using AlphaFold's default layer
    initialization. Allows multiple bias vectors."""

    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
        residual: bool = True,
        proj_q_w_bias: bool = False,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
            residual:
                If the output is residual, then the final linear layer is initialized to
                zeros so that the residual layer acts as the identity at initialization.
            proj_q_w_bias:
                Whether to project the Q vectors with a Linear layer that uses a bias
        """
        super(Attention, self).__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        # Reshapes [..., no_heads * c_hidden] -> [..., no_heads, c_hidden]
        split_heads = nn.Unflatten(dim=-1, unflattened_size=(self.no_heads, self.c_hidden))

        # The q/k/v linear layers project to no_heads * c_hidden and then split the heads
        linear_q_class = Linear if proj_q_w_bias else LinearNoBias
        self.linear_q = nn.Sequential(
            linear_q_class(self.c_q, self.c_hidden * self.no_heads, init="glorot"),
            split_heads
        )
        self.linear_k = nn.Sequential(
            LinearNoBias(self.c_k, self.c_hidden * self.no_heads, init="glorot"),
            split_heads
        )
        self.linear_v = nn.Sequential(
            LinearNoBias(self.c_v, self.c_hidden * self.no_heads, init="glorot"),
            split_heads
        )
        self.linear_o = LinearNoBias(
            self.c_hidden * self.no_heads, self.c_q,
            init="final" if residual else "default"
        )

        self.to_gamma = None
        if self.gating:
            self.to_gamma = nn.Sequential(
                LinearNoBias(self.c_q, self.c_hidden * self.no_heads, init="gating"),
                # The excerpt ends here; a sigmoid over the head-split query projection
                # (AlphaFold's standard gating) is assumed to complete the module.
                split_heads,
                nn.Sigmoid()
            )
```
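The listing above stops inside `__init__` and does not show the forward pass. As a rough orientation, the following is a minimal, self-contained sketch of the computation such a module performs: head-split Q/K/V projections, scaled dot-product attention with an optional additive bias, and a sigmoid gate derived from the query. It uses plain `torch.nn.Linear` as a stand-in for the repo's `Linear`/`LinearNoBias` primitives, and names like `toy_attention` and `split_heads` are illustrative only, not part of the AlphaFold3 codebase.

```python
import math
import torch
import torch.nn as nn

def toy_attention(q, k, v, bias=None, gate=None):
    """Scaled dot-product attention over head-split tensors.

    q, k, v: [..., no_heads, seq_len, c_hidden]
    bias:    optional additive bias broadcastable to [..., no_heads, q_len, k_len]
    gate:    optional sigmoid gate with the same shape as the output
    """
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.shape[-1])
    if bias is not None:
        scores = scores + bias
    weights = torch.softmax(scores, dim=-1)
    out = torch.matmul(weights, v)  # [..., no_heads, q_len, c_hidden]
    if gate is not None:
        out = gate * out            # query-derived sigmoid gate
    return out

# Toy dimensions (illustrative only).
batch, seq_len = 2, 16
c_q = c_k = c_v = 64
c_hidden, no_heads = 32, 4

x = torch.randn(batch, seq_len, c_q)

# Plain nn.Linear stands in for the repo's Linear/LinearNoBias primitives.
to_q = nn.Linear(c_q, no_heads * c_hidden, bias=False)
to_k = nn.Linear(c_k, no_heads * c_hidden, bias=False)
to_v = nn.Linear(c_v, no_heads * c_hidden, bias=False)
to_gamma = nn.Sequential(nn.Linear(c_q, no_heads * c_hidden, bias=False), nn.Sigmoid())
to_out = nn.Linear(no_heads * c_hidden, c_q, bias=False)

def split_heads(t):
    # [batch, seq, no_heads * c_hidden] -> [batch, no_heads, seq, c_hidden]
    return t.view(batch, seq_len, no_heads, c_hidden).transpose(-2, -3)

q, k, v = split_heads(to_q(x)), split_heads(to_k(x)), split_heads(to_v(x))
gate = split_heads(to_gamma(x))

out = toy_attention(q, k, v, gate=gate)  # [batch, no_heads, seq, c_hidden]
out = out.transpose(-2, -3).reshape(batch, seq_len, no_heads * c_hidden)
print(to_out(out).shape)                 # torch.Size([2, 16, 64])
```

Note also how `residual=True` zero-initializes the output projection (`linear_o`), so at initialization the attention block contributes nothing to a residual stream and the surrounding residual layer acts as the identity.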