import torch
from torch import nn
from d2l import torch as d2l


def transpose_qkv(X, num_heads):
    """Reshape the input for parallel computation of multiple attention heads."""
    # Shape of input X:
    # (batch_size, no. of queries or key-value pairs, num_hiddens)
    # Shape of output X:
    # (batch_size, no. of queries or key-value pairs, num_heads, num_hiddens / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # Shape of output X:
    # (batch_size, num_heads, no. of queries or key-value pairs, num_hiddens / num_heads)
    X = X.permute(0, 2, 1, 3)
    # Shape of the final output:
    # (batch_size * num_heads, no. of queries or key-value pairs, num_hiddens / num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    """Reverse the operation of transpose_qkv."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
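
# A minimal sanity check (an addition, not part of the original code): transpose_output
# should undo transpose_qkv exactly. The tensor sizes below are arbitrary placeholders.
X_check = torch.ones((2, 6, 100))
X_heads = transpose_qkv(X_check, num_heads=5)
print(X_heads.shape)                      # torch.Size([10, 6, 20])
X_restored = transpose_output(X_heads, num_heads=5)
print(X_restored.shape)                   # torch.Size([2, 6, 100])
print(torch.equal(X_check, X_restored))   # True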

class MultiHeadAttention(nn.Module):
    """Multi-head attention."""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        # Scaled dot-product attention is used for every head
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens:
        # (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of the output queries, keys, values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) num_heads times,
            # then copy the second item, and so on
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)
        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

def attention_nhead():
    num_hiddens, num_heads = 100, 5
    attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
    attention.eval()
    batch_size, num_queries = 2, 4
    num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
    X = torch.ones((batch_size, num_queries, num_hiddens))
    Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
    output = attention(X, Y, Y, valid_lens)
    print(output)


attention_nhead()
# Output:
tensor([[[-3.6907e-01, -1.1405e-04, 3.2671e-01, -1.7356e-01, -8.1225e-01,
-3.7096e-01, 2.7797e-01, -2.6977e-01, -2.5845e-01, -2.3081e-01,
3.0618e-01, 2.7673e-01, -2.6381e-01, -8.4385e-02, 6.8697e-01,
-3.0869e-01, -2.6311e-01, 3.3698e-01, 2.0350e-02, -1.1740e-01,
-2.9579e-01, -2.3887e-01, -1.3595e-01, 1.6481e-01, 3.6974e-01,
-1.2254e-01, -4.8702e-01, -3.3106e-01, 1.9889e-01, 4.6272e-04,
-3.0664e-01, 1.0336e-01, 1.5175e-01, 5.1327e-02, -1.7456e-01,
1.0848e-01, -2.1586e-01, -1.3530e-01, 1.4878e-01, 2.2182e-01,
-1.8205e-01, 4.2394e-02, -1.2869e-01, -6.1095e-02, 1.1372e-01,
-2.4854e-01, 9.8994e-02, -4.2462e-01, -1.9857e-02, -1.0072e-01,
7.6214e-01, 1.4569e-01, 2.4027e-01, -1.4111e-01, -3.5483e-01,
1.2154e-02, -4.0619e-01, -1.7395e-01, 1.2091e-02, 1.2583e-01,
4.5608e-01, -2.2189e-01, 1.1187e-01, -2.2936e-01, 2.6352e-01,
-2.1522e-02, 1.7198e-01, 2.4890e-01, -5.9914e-01, -3.3339e-01,
-5.0526e-03, 2.5246e-01, -5.5496e-02, 8.2241e-02, 2.3885e-01,
-6.4767e-02, 4.5753e-01, 1.4007e-01, 3.2348e-01, -2.9186e-01,
-2.0273e-01, 7.9331e-01, 2.4528e-01, -2.3202e-01, 6.0938e-01,
-3.4037e-01, -3.0914e-01, 2.0632e-01, -1.1952e-01, -1.4625e-01,
5.5157e-01, -1.5517e-01, 5.0877e-01, 1.9026e-01, -3.7252e-02,
-1.7278e-01, -2.9345e-01, -1.2168e-01, 1.7021e-01, 7.7886e-01],
...
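
# The printed tensor above is truncated; with batch_size=2, num_queries=4 and
# num_hiddens=100, output.shape is torch.Size([2, 4, 100]).
#
# As a cross-check (an addition, not part of the original code), PyTorch's built-in
# nn.MultiheadAttention produces the same shapes. Its weights are initialized
# differently, so only the shape, not the values, matches the output above.
# valid_lens is expressed here as a boolean key_padding_mask (True = position ignored).
builtin = nn.MultiheadAttention(embed_dim=100, num_heads=5, batch_first=True)
builtin.eval()
X = torch.ones((2, 4, 100))
Y = torch.ones((2, 6, 100))
key_padding_mask = torch.arange(6)[None, :] >= torch.tensor([3, 2])[:, None]
out, _ = builtin(X, Y, Y, key_padding_mask=key_padding_mask)
print(out.shape)  # torch.Size([2, 4, 100])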