
TiSASRec Code Notes


Fully annotated code: https://github.com/Guadzilla/Paper_notebook/tree/main/TiSASRec

Paper notes: https://guadzilla.github.io/2021/11/18/TiSASRec/


squeeze, unsqueeze, repeat, expand

torch.squeeze(input, dim=None) —> Tensor

squeeze: to press, to compress

The opposite of unsqueeze: removes size-1 dimensions. If dim is given, only that dimension is removed (and only if its size is 1); if dim is not given, every dimension of size 1 is removed. A short sketch after the example below shows the dim-specific case.

torch.unsqueeze(input, dim) —> Tensor

unsqueeze: the opposite of squeezing, i.e. to inflate

The opposite of squeeze: returns a new tensor with a dimension of size 1 inserted at the specified dim of the original tensor.

x = torch.tensor([[1, 2, 3, 4],          # x.shape=(2,4); there are three slots where a dim can be inserted: _,2,_,4,_
                  [5, 6, 7, 8]])

torch.unsqueeze(x, 0).shape  # (_,2,_,4,_), insert a size-1 dim at position 0 (leftmost) = (1,2,4)
# torch.Size([1, 2, 4])

torch.unsqueeze(x, 1).shape  # (_,2,_,4,_), insert a size-1 dim at position 1 (middle) = (2,1,4)
# torch.Size([2, 1, 4])

torch.unsqueeze(x, 2).shape  # (_,2,_,4,_), insert a size-1 dim at position 2 (rightmost) = (2,4,1)
# torch.Size([2, 4, 1])

y = torch.unsqueeze(x, -1).unsqueeze(-1)  # append two size-1 dims at the end
y.shape
# torch.Size([2, 4, 1, 1])
y.squeeze().shape  # squeeze without dim removes every size-1 dimension
# torch.Size([2, 4])
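
A short sketch of my own (not from the repo) showing the dim-specific behaviour of squeeze: it only drops the requested dimension, and only when that dimension has size 1.

import torch

y = torch.zeros(2, 4, 1, 1)
y.squeeze(2).shape   # only dim 2 (size 1) is removed
# torch.Size([2, 4, 1])
y.squeeze(0).shape   # dim 0 has size 2, so the tensor is returned unchanged
# torch.Size([2, 4, 1, 1])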

Tensor.repeat(*sizes)

Repeats this tensor along the specified dimensions. Similar to numpy.tile(): think of the tensor as a floor tile and lay it out repeatedly according to the given shape.

Tensor.expand(*sizes)

Expands size-1 dimensions to a larger size. Unlike repeat, it does not copy data but only returns a view, and only dimensions of size 1 can be expanded (see the sketch after the example below).

x = torch.tensor([1,2,3])
x
# tensor([1, 2, 3])
x.repeat(2,3)
# tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],   # x is the "tile", laid out (2,3) times
#         [1, 2, 3, 1, 2, 3, 1, 2, 3]])
x.expand(2,3)
# tensor([[1, 2, 3],   # x is expanded (broadcast) to shape (2,3)
#         [1, 2, 3]])
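
A small sketch of my own to make the difference concrete: expand only works on size-1 dimensions and returns a view without copying data, whereas repeat copies; -1 in expand means "keep this dimension as it is", which is exactly how it is used in the TiSASRec code below.

import torch

x = torch.tensor([[1], [2]])   # shape (2, 1)
x.expand(-1, 3)                # -1 keeps dim 0; dim 1 (size 1) is expanded to 3, no data is copied
# tensor([[1, 1, 1],
#         [2, 2, 2]])
# x.expand(4, 3)               # would raise a RuntimeError: dim 0 has size 2 (not 1) and cannot be expanded
x.repeat(2, 3).shape           # repeat works on any dimension, but it copies the data
# torch.Size([4, 3])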

The actual code in the repo:

time_mask = time_mask.unsqueeze(-1).repeat(self.head_num, 1, 1)
# 1. unsqueeze(): time_mask.shape=(batch_size, maxlen) ——> (batch_size, maxlen, 1), append a size-1 dim at the end
# 2. repeat(): (batch_size, maxlen, 1) ——> (self.head_num*batch_size, maxlen, 1), the first dim is multiplied by head_num
time_mask = time_mask.expand(-1, -1, attn_weights.shape[-1])  # here attn_weights.shape[-1] = maxlen
# 3. (self.head_num*batch_size, maxlen, 1) ——> (self.head_num*batch_size, maxlen, maxlen)


attn_mask = attn_mask.unsqueeze(0).expand(attn_weights.shape[0], -1, -1)
# (maxlen, maxlen) ——> (1, maxlen, maxlen) ——> (head_num*batch_size, maxlen, maxlen)
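
To make the shape bookkeeping easier to follow, here is a standalone trace with toy sizes of my own choosing (batch_size=2, maxlen=3, head_num=2); the way attn_mask is built from torch.tril is an assumption modelled on the repo's model code:

import torch

batch_size, maxlen, head_num = 2, 3, 2
attn_weights = torch.zeros(head_num * batch_size, maxlen, maxlen)

time_mask = torch.zeros(batch_size, maxlen, dtype=torch.bool)          # True = padding item (assumed)
time_mask = time_mask.unsqueeze(-1).repeat(head_num, 1, 1)
print(time_mask.shape)                                                 # torch.Size([4, 3, 1])
time_mask = time_mask.expand(-1, -1, attn_weights.shape[-1])
print(time_mask.shape)                                                 # torch.Size([4, 3, 3])

attn_mask = ~torch.tril(torch.ones(maxlen, maxlen, dtype=torch.bool))  # True above the diagonal = future positions
attn_mask = attn_mask.unsqueeze(0).expand(attn_weights.shape[0], -1, -1)
print(attn_mask.shape)                                                 # torch.Size([4, 3, 3])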

Hand-written multi-head attention


class TimeAwareMultiHeadAttention(torch.nn.Module):
    # required homebrewed mha layer for Ti/SASRec experiments
    def __init__(self, hidden_size, head_num, dropout_rate, dev):
        super(TimeAwareMultiHeadAttention, self).__init__()
        self.Q_w = torch.nn.Linear(hidden_size, hidden_size)
        self.K_w = torch.nn.Linear(hidden_size, hidden_size)
        self.V_w = torch.nn.Linear(hidden_size, hidden_size)

        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.softmax = torch.nn.Softmax(dim=-1)

        self.hidden_size = hidden_size
        self.head_num = head_num
        self.head_size = hidden_size // head_num
        self.dropout_rate = dropout_rate
        self.dev = dev

    def forward(self, queries, keys, time_mask, attn_mask, time_matrix_K, time_matrix_V, abs_pos_K, abs_pos_V):
        # time_mask: mask for padding items; attn_mask: mask for causality (lower triangular)
        Q, K, V = self.Q_w(queries), self.K_w(keys), self.V_w(keys)

        # head dim * batch dim for parallelization (h*N, T, C/h)
        # i.e. (batch_size, maxlen, hidden_units) ----> (head_num*batch_size, maxlen, hidden_units/head_num)
        # and  (batch_size, maxlen, maxlen, hidden_units) ----> (head_num*batch_size, maxlen, maxlen, hidden_units/head_num)
        Q_ = torch.cat(torch.split(Q, self.head_size, dim=2), dim=0)
        K_ = torch.cat(torch.split(K, self.head_size, dim=2), dim=0)
        V_ = torch.cat(torch.split(V, self.head_size, dim=2), dim=0)

        time_matrix_K_ = torch.cat(torch.split(time_matrix_K, self.head_size, dim=3), dim=0)
        time_matrix_V_ = torch.cat(torch.split(time_matrix_V, self.head_size, dim=3), dim=0)
        abs_pos_K_ = torch.cat(torch.split(abs_pos_K, self.head_size, dim=2), dim=0)
        abs_pos_V_ = torch.cat(torch.split(abs_pos_V, self.head_size, dim=2), dim=0)

        # batched channel wise matmul to gen attention weights --- Eq. (8) in the paper
        attn_weights = Q_.matmul(torch.transpose(K_, 1, 2))
        attn_weights += Q_.matmul(torch.transpose(abs_pos_K_, 1, 2))
        attn_weights += time_matrix_K_.matmul(Q_.unsqueeze(-1)).squeeze(-1)

        # seq length adaptive scaling --- Eq. (8)
        attn_weights = attn_weights / (K_.shape[-1] ** 0.5)

        # key masking, -2^32 lead to leaking, inf lead to nan
        # 0 * inf = nan, then reduce_sum([nan,...]) = nan

        # time_mask = time_mask.unsqueeze(-1).expand(attn_weights.shape[0], -1, attn_weights.shape[-1])
        # would raise an error; the reshape has to follow steps 1./2./3. below
        time_mask = time_mask.unsqueeze(-1).repeat(self.head_num, 1, 1)
        # 1. unsqueeze(): time_mask.shape=(batch_size, maxlen) ——> (batch_size, maxlen, 1), append a size-1 dim at the end
        # 2. repeat(): (batch_size, maxlen, 1) ——> (self.head_num*batch_size, maxlen, 1), the first dim is multiplied by head_num
        time_mask = time_mask.expand(-1, -1, attn_weights.shape[-1])
        # 3. (self.head_num*batch_size, maxlen, 1) ——> (self.head_num*batch_size, maxlen, maxlen)
        # tip: attn_weights has shape (B, maxlen, maxlen) with B = head_num*batch_size; within each batch the
        # (maxlen, maxlen) slice holds, row by row, the attention of one item over all other items
        # time_mask masks the padding items; originally it is (B, maxlen, 1), i.e. (maxlen, 1) per batch,
        # and it must be expanded to (B, maxlen, maxlen) so that the rows of attn_weights belonging to padding items get masked

        attn_mask = attn_mask.unsqueeze(0).expand(attn_weights.shape[0], -1, -1)
        # (maxlen, maxlen) ——> (1, maxlen, maxlen) ——> (head_num*batch_size, maxlen, maxlen)
        # masked positions get a huge negative value because softmax is applied next: e raised to a very negative power is close to 0
        paddings = torch.ones(attn_weights.shape) * (-2**32+1)  # -1e23 # float('-inf'),
        paddings = paddings.to(self.dev)

        # these two steps together mask out the attention entries that must not be used:
        # the first masks the padding items, the second masks the afterwards (future) items to enforce causality
        attn_weights = torch.where(time_mask, paddings, attn_weights)  # True: pick padding
        attn_weights = torch.where(attn_mask, paddings, attn_weights)  # enforcing causality

        attn_weights = self.softmax(attn_weights)  # --- Eq. (7)
        attn_weights = self.dropout(attn_weights)

        # --- Eq. (6), with the attention weights alpha already multiplied in
        outputs = attn_weights.matmul(V_)
        outputs += attn_weights.matmul(abs_pos_V_)
        outputs += attn_weights.unsqueeze(2).matmul(time_matrix_V_).reshape(outputs.shape)  # .squeeze(2)

        # (num_head * N, T, C / num_head) -> (N, T, C)
        outputs = torch.cat(torch.split(outputs, Q.shape[0], dim=0), dim=2)  # div batch_size

        return outputs
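
Finally, a minimal usage sketch of my own as a sanity check; the tensor shapes follow the comments above, and the mask construction is an assumption modelled on the repo's model code rather than copied from it:

import torch

batch_size, maxlen, hidden_size, head_num = 2, 5, 6, 3   # toy sizes of my own choosing
dev = 'cpu'

mha = TimeAwareMultiHeadAttention(hidden_size, head_num, dropout_rate=0.1, dev=dev)

seqs = torch.randn(batch_size, maxlen, hidden_size)                   # item embeddings
abs_pos_K = torch.randn(batch_size, maxlen, hidden_size)              # absolute position embeddings (keys)
abs_pos_V = torch.randn(batch_size, maxlen, hidden_size)              # absolute position embeddings (values)
time_matrix_K = torch.randn(batch_size, maxlen, maxlen, hidden_size)  # relative time-interval embeddings (keys)
time_matrix_V = torch.randn(batch_size, maxlen, maxlen, hidden_size)  # relative time-interval embeddings (values)

time_mask = torch.zeros(batch_size, maxlen, dtype=torch.bool)          # True = padding item (assumed)
attn_mask = ~torch.tril(torch.ones(maxlen, maxlen, dtype=torch.bool))  # True = future position, to be masked (assumed)

out = mha(seqs, seqs, time_mask, attn_mask, time_matrix_K, time_matrix_V, abs_pos_K, abs_pos_V)
print(out.shape)   # torch.Size([2, 5, 6]) = (batch_size, maxlen, hidden_size)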