Fully commented code: https://github.com/Guadzilla/Paper_notebook/tree/main/SASRec
Paper notes: https://guadzilla.github.io/2021/11/03/SASRec/
collections.defaultdict(list)
```
class collections.defaultdict(default_factory=None, /[, ...])
```
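Returns a new dictionary-like object. `defaultdict` is a subclass of the built-in `dict` class. It overrides one method (`__missing__`) and adds one writable instance variable.
The object has an attribute named `default_factory`. At construction time, the first argument provides its initial value, which defaults to `None` (in that case a missing key raises `KeyError`, just like a plain `dict`). All remaining arguments, including keyword arguments, are passed on to the `dict` constructor unchanged.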
Instantiating the object with `defaultdict(list)` sets `default_factory=list`, which makes it easy to turn a sequence of key-value pairs into a dictionary of key-list pairs:
```python
s = [('yellow', 1), ('blue', 2), ('yellow', 3), ('blue', 4), ('red', 1)]
```
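When a key that is not yet in the dictionary is looked up for the first time, `defaultdict` calls `default_factory` (here `list`) to create an empty list, stores it under that key and returns it; `list.append()` then adds the value to the new list. When the same key appears again, the existing list is returned and `list.append()` adds the next value to it.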
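The complete grouping loop, following the example in the Python documentation, is shown below. The second half is a hypothetical illustration of the same pattern applied to a made-up "user item" interaction log, the kind of per-user sequence a sequential recommender is trained on; the log contents and variable names are invented for this sketch.

```python
from collections import defaultdict

# (key, value) pairs, as in the list s above
s = [('yellow', 1), ('blue', 2), ('yellow', 3), ('blue', 4), ('red', 1)]

d = defaultdict(list)      # default_factory=list: a missing key starts as []
for k, v in s:
    d[k].append(v)         # first access to k creates [], later accesses reuse it

print(sorted(d.items()))   # [('blue', [2, 4]), ('red', [1]), ('yellow', [1, 3])]

# Hypothetical interaction log in "user item" format, one pair per line, in time order
lines = ["1 10", "2 7", "1 12", "1 3", "2 8"]
user_items = defaultdict(list)
for line in lines:
    u, i = map(int, line.split())
    user_items[u].append(i)   # keeps each user's items in the order they appear

print(dict(user_items))       # {1: [10, 12, 3], 2: [7, 8]}
```

Checking membership with `in` does not trigger `default_factory`; only item access such as `d[k]` does.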
Custom multithreading in Python
```python
def random_neq(l, r, s):
    ...  # body not reproduced in these notes
```
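Only the signature of `random_neq` survives above. The sketch below assumes its usual meaning in SASRec-style samplers: draw a random integer in `[l, r)` that is not in the set `s`, i.e. a negative item the user has never interacted with. The helpers `sample_batch` and `start_sampler` are hypothetical: they draw simplified (user, positive, negative) triples rather than SASRec's full padded sequences, and run background producers with the standard-library `threading.Thread` and `queue.Queue`. Whether the linked repository uses threads or processes is not visible in these notes; for CPU-heavy sampling, `multiprocessing` avoids contention on the GIL.

```python
import queue
import random
import threading


def random_neq(l, r, s):
    """Draw a random integer in [l, r) that is not in the set s.

    Loops forever if s covers the whole range, so at least one value must stay free.
    """
    t = random.randrange(l, r)
    while t in s:  # rejection sampling: retry until t falls outside s
        t = random.randrange(l, r)
    return t


def sample_batch(user_items, item_num, batch_size):
    """Hypothetical helper: draw (user, positive item, negative item) triples.

    Item ids are assumed to run from 1 to item_num (0 reserved for padding).
    """
    users = list(user_items)
    batch = []
    for _ in range(batch_size):
        u = random.choice(users)
        pos = random.choice(user_items[u])
        neg = random_neq(1, item_num + 1, set(user_items[u]))
        batch.append((u, pos, neg))
    return batch


def start_sampler(user_items, item_num, batch_size, n_workers=2, maxsize=10):
    """Hypothetical helper: daemon threads keep a bounded queue filled with batches."""
    batches = queue.Queue(maxsize=maxsize)

    def worker():
        while True:
            # put() blocks while the queue is full, so workers never run far ahead
            batches.put(sample_batch(user_items, item_num, batch_size))

    for _ in range(n_workers):
        threading.Thread(target=worker, daemon=True).start()
    return batches


if __name__ == "__main__":
    toy = {1: [1, 2, 3], 2: [4, 5]}  # hypothetical user -> item history
    q = start_sampler(toy, item_num=10, batch_size=4)
    for _ in range(3):
        print(q.get())  # a training loop would consume batches here
```

Marking the workers as daemon threads lets the program exit without explicitly stopping the infinite producer loops.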
torch.tril()
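```
torch.tril(input, diagonal=0, *, out=None) → Tensor
```
Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices `input`; all other elements of the result are set to 0. `diagonal` selects the boundary: with `diagonal=0` everything on and below the main diagonal is kept, a positive value also keeps that many diagonals above it, and a negative value also drops that many diagonals below it.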
Example:
```python
a = torch.randn(3, 3)
torch.tril(a)  # keeps a[i][j] for j <= i, zeros elsewhere
```
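In SASRec, `torch.tril` is useful for building the causal mask: the paper requires that the prediction at step t only depends on the first t items. A minimal sketch, assuming a boolean mask where `True` marks "may attend": `tril` of an all-ones matrix gives exactly that, and `~` inverts it into a "blocked" mask. The sizes here are arbitrary.

```python
import torch

seq_len = 4
ones = torch.ones((seq_len, seq_len), dtype=torch.bool)

allowed = torch.tril(ones)  # True on and below the main diagonal: position i may see 0..i
blocked = ~allowed          # True strictly above the diagonal: the "future" positions

print(allowed.int())        # rows: [1,0,0,0], [1,1,0,0], [1,1,1,0], [1,1,1,1]

# diagonal shifts the cut-off: diagonal=-1 also zeros out the main diagonal
print(torch.tril(torch.ones(3, 3), diagonal=-1))  # rows: [0,0,0], [1,0,0], [1,1,0]
```

The inverted form `blocked` matches the convention of `torch.nn.MultiheadAttention`, where `True` in a boolean `attn_mask` means "not allowed to attend".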
The `~` (tilde) operator in Python
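`~` has only one use: bitwise NOT (inversion). For a Python integer `x`, `~x` equals `-x - 1` because integers behave as two's complement; for a boolean PyTorch tensor it acts as element-wise logical NOT.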
Further reading: "Python tilde and two's complement", CSDN blog (https://space.bilibili.com/59807853)
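A short demonstration of how `~` behaves on the types that show up in these notes: plain Python integers, boolean tensors (such as attention masks) and integer tensors. The values are arbitrary.

```python
import torch

# Python ints: ~x flips every bit, which in two's complement equals -x - 1
print(~5, ~0, ~-1)  # -6 -1 0

# torch.bool tensors: ~ is element-wise logical NOT (same result as torch.logical_not)
mask = torch.tensor([True, False, True])
print(~mask)        # element-wise: [False, True, False]

# Integer tensors: ~ is bitwise NOT again, element by element
print(~torch.tensor([5, 0]))  # element-wise: [-6, -1]
```

Note that a plain Python `bool` is an `int`, so `~True` evaluates to `-2` rather than `False` (and is deprecated since Python 3.12); masks meant for logical negation should therefore be `torch.bool` tensors.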
torch.nn.MultiheadAttention
```
torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)
```
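Corresponding formula (multi-head attention):
$$\mathrm{MultiHead}(Q,K,V)=\mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_h)W^O,\qquad \mathrm{head}_i=\mathrm{Attention}(QW_i^Q,\,KW_i^K,\,VW_i^V)$$
Computation formula (scaled dot-product attention inside each head, with $d_k$ the per-head dimension):
$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$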
The forward call:
```
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)
```
Q, K and V are used in the usual way; the two mask arguments need care:
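- `key_padding_mask`: shape $(N, S)$, where $N$ is the batch size and $S$ the source (key) sequence length. It marks which key elements are excluded from attention, i.e. treated as padding. Note that for a boolean mask, positions set to `True` are not attended to (padding positions are skipped).
- `attn_mask`: shape $(L, S)$ or $(N \cdot \text{num\_heads}, L, S)$, where $L$ is the target (query) sequence length. It blocks attention between specific query/key position pairs. For a boolean mask, `True` again means "not allowed to attend"; a float mask is instead added to the attention scores. In SASRec this is where the causal mask goes, so that the output at step $t$ only sees items $1,\dots,t$.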
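The sketch below combines both masks in one call, using toy sizes and random inputs chosen for illustration. It assumes `batch_first=True` and a hypothetical layout where the second sequence has two padded steps at the end. Padding at the end (rather than at the front) guarantees that every query keeps at least one visible key under the causal mask; a query row in which every key is masked would produce NaN.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, embed_dim, num_heads = 2, 5, 8, 2

mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, batch_first=True)
x = torch.randn(batch, seq_len, embed_dim)

# Causal mask of shape (L, S): True = "may NOT attend" (future positions)
attn_mask = ~torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Padding mask of shape (N, S): True = this key is padding (hypothetical layout)
key_padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
key_padding_mask[1, 3:] = True

out, weights = mha(x, x, x, key_padding_mask=key_padding_mask,
                   attn_mask=attn_mask, need_weights=True)

print(out.shape)      # torch.Size([2, 5, 8])
print(weights.shape)  # torch.Size([2, 5, 5]): (N, L, S), averaged over heads
```

Masked positions receive a score of minus infinity before the softmax, so their attention weights come out exactly 0 and each row of `weights` still sums to 1.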
torch.nn.BCEWithLogitsLoss()
```
forward(self, input: Tensor, target: Tensor) -> Tensor
```
Parameters:
- input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits).
- target: Tensor of the same shape as input with values between 0 and 1
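Denote the input (logits) by $x$ and the target by $y$. Each element contributes
$$l_n = -\left[y_n \log\sigma(x_n) + (1 - y_n)\log\left(1 - \sigma(x_n)\right)\right],$$
so when $y_n = 1$, $l_n = -\log\sigma(x_n)$, and when $y_n = 0$, $l_n = -\log(1 - \sigma(x_n))$. With the default `reduction='mean'`, the returned loss is the mean of all $l_n$.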
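A quick numerical check that `BCEWithLogitsLoss` matches the per-element formula with the default mean reduction. The logits and targets are made-up values.

```python
import torch
import torch.nn as nn

x = torch.tensor([2.0, -1.0, 0.5, -3.0])  # logits (unnormalized scores)
y = torch.tensor([1.0, 0.0, 1.0, 0.0])    # targets in [0, 1]

loss = nn.BCEWithLogitsLoss()(x, y)       # reduction='mean' by default

# l_n = -[y_n log sigma(x_n) + (1 - y_n) log(1 - sigma(x_n))], then averaged over n
p = torch.sigmoid(x)
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()

print(loss.item(), manual.item())         # the two values agree up to float rounding
```

In practice `BCEWithLogitsLoss` is preferred over a separate `Sigmoid` followed by `BCELoss` because fusing the two is more numerically stable for logits of large magnitude.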
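In the implementation, an all-ones tensor `pos_labels` and an all-zeros tensor `neg_labels` serve as the targets. Every positive example gets label 1 (it is the correct next item, so its ground-truth probability is 1), and every negative example gets label 0 (it is a wrong item, so its ground-truth probability is 0).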
```python
pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), \
                         torch.zeros(neg_logits.shape, device=args.device)
```
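A sketch of how these labels feed the loss, assuming logits of shape (batch, maxlen) and that item id 0 marks padding positions, which are dropped before computing the loss (a common choice in SASRec implementations, not confirmed by these notes). All tensors are hypothetical toy values, and `device=args.device` is omitted so the example runs on the CPU.

```python
import torch
import torch.nn as nn

# Hypothetical toy batch: 2 sequences, maxlen 4; item id 0 marks left padding
pos_seqs = torch.tensor([[0, 0, 5, 7],
                         [0, 3, 2, 9]])
pos_logits = torch.tensor([[0.1, 0.2, 1.5, 0.8],
                           [0.3, 2.0, -0.5, 1.0]])   # scores of the true next items
neg_logits = torch.tensor([[0.0, 0.4, -1.0, 0.2],
                           [0.5, -2.0, 0.1, -0.3]])  # scores of sampled negative items

pos_labels = torch.ones(pos_logits.shape)    # ground-truth probability 1
neg_labels = torch.zeros(neg_logits.shape)   # ground-truth probability 0

bce = nn.BCEWithLogitsLoss()
valid = pos_seqs != 0                        # boolean mask of non-padded positions
loss = bce(pos_logits[valid], pos_labels[valid]) + bce(neg_logits[valid], neg_labels[valid])
print(loss.item())
```

The positive term pushes $\sigma(\text{pos\_logits})$ toward 1 and the negative term pushes $\sigma(\text{neg\_logits})$ toward 0; masking keeps padded steps from contributing gradients.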
torch.argsort()
```
torch.argsort(input, dim=-1, descending=False) → LongTensor
```
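Sorts the elements along the given `dim` in ascending order (by default) and returns the indices those elements had in the original tensor.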
```python
>>> t = torch.randint(1, 10, (1, 5))
>>> torch.argsort(t)
```
Further reading: "argmax, argmin, argwhere, argsort and argpartition in NumPy", 古明地盆, cnblogs (cnblogs.com)
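For ranking-based evaluation, applying `argsort` twice turns scores into ranks: the first call lists item indices from best to worst, and the second maps each item to its position in that list. A minimal sketch with hypothetical scores for four candidates, where index 0 is assumed to be the ground-truth item:

```python
import torch

# Hypothetical scores for 4 candidate items; index 0 is the ground-truth item
scores = torch.tensor([0.1, 0.9, 0.5, 0.3])

order = torch.argsort(scores, descending=True)  # item indices, best first
print(order)                                    # tensor([1, 2, 3, 0])

# argsort of argsort gives each item's 0-based position in that order
ranks = torch.argsort(order)
print(ranks)                                    # tensor([3, 0, 1, 2])
print(ranks[0].item())                          # 3: the ground-truth item is ranked last
```

`argsort` is not guaranteed to be stable by default, so tied scores may be ranked in either order.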
Evaluation metrics: Hit Ratio and NDCG [1]
Hit Ratio
Evaluation Metrics. Given a user, each algorithm produces a ranked list of items. To assess the ranked list with the ground-truth item set (GT), we adopt Hit Ratio (HR), which has been commonly used in top-N evaluation. If a test item appears in the recommended list, it is deemed a hit. HR is calculated as:

$$HR@K = \frac{\text{Number of Hits}@K}{|GT|}$$
NDCG
As HR is a recall-based metric, it does not reflect the accuracy of getting top ranks correct, which is crucial in many real-world applications. To address this, we also adopt Normalized Discounted Cumulative Gain (NDCG), which assigns higher importance to results at top ranks, scoring successively lower ranks with marginal fractional utility:

$$NDCG@K = Z_K \sum_{i=1}^{K} \frac{2^{r_i} - 1}{\log_2(i + 1)}$$
where $Z_K$ is the normalizer to ensure the perfect ranking has a value of 1; $r_i$ is the graded relevance of the item at position $i$. We use the simple binary relevance for our work: $r_i = 1$ if the item is in the test set, and 0 otherwise. For both metrics, larger values indicate better performance. In the evaluation, we calculate both metrics for each user in the test set, and report the average score.
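In the SASRec evaluation protocol each test user has exactly one ground-truth item (the held-out most recent item), so $|GT| = 1$ and relevance is binary. Assuming the NDCG@K definition $Z_K \sum_{i=1}^{K} (2^{r_i}-1)/\log_2(i+1)$, both metrics reduce to simple functions of the 0-based rank $r$ of that item, which sits at position $i = r + 1$:

$$
\begin{aligned}
\mathrm{DCG}@K &= \frac{2^{1} - 1}{\log_2\big((r + 1) + 1\big)} = \frac{1}{\log_2(r + 2)} \quad \text{if } r < K, \text{ else } 0,\\
\mathrm{IDCG}@K &= \frac{1}{\log_2(1 + 1)} = 1 \;\Rightarrow\; Z_K = 1,\\
\mathrm{NDCG}@K &= \frac{\mathbb{1}[r < K]}{\log_2(r + 2)}, \qquad \mathrm{HR}@K = \mathbb{1}[r < K].
\end{aligned}
$$

For example, a ground-truth item ranked first ($r = 0$) scores NDCG 1, and one ranked second ($r = 1$) scores $1/\log_2 3 \approx 0.631$; the reported numbers are the averages of these per-user values.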
Code implementation:
```python
# evaluate on test set
```
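The evaluation body itself is not reproduced here. As a hedged sketch, assume a hypothetical `scores` tensor of shape (num_users, 1 + num_negatives) whose column 0 holds the ground-truth item's score (SASRec ranks the true item against 100 sampled negatives). Instead of a double `argsort`, this version counts how many candidates score strictly higher than the target, which yields the same 0-based rank when there are no ties (on ties, the target gets the better rank). It then applies HR@K $= \mathbb{1}[r < K]$ and NDCG@K $= \mathbb{1}[r < K]/\log_2(r + 2)$ per user and averages. The function name `hr_ndcg_at_k` is hypothetical.

```python
import torch


def hr_ndcg_at_k(scores, k=10):
    """scores: (num_users, 1 + num_negatives); column 0 is the ground-truth item.

    Returns the user-averaged HR@k and NDCG@k as Python floats.
    """
    # 0-based rank = number of candidates scored strictly higher than the target
    ranks = (scores[:, 1:] > scores[:, :1]).sum(dim=1)
    hits = (ranks < k).float()
    # a single relevant item at 0-based rank r contributes 1 / log2(r + 2) if it is a hit
    gains = hits / torch.log2(ranks.float() + 2)
    return hits.mean().item(), gains.mean().item()


if __name__ == "__main__":
    torch.manual_seed(0)
    demo = torch.randn(5, 101)  # hypothetical: 5 users, target + 100 sampled negatives
    print(hr_ndcg_at_k(demo, k=10))
```

Computing the ranks with one vectorized comparison avoids a full sort per user, which helps when the candidate list is long.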