- 作者:老汪软件技巧
- 发表时间:2024-10-09 11:02
- 浏览量:
简单记录一下。
推荐系统中存在各式各样的序列特征,比如用户点击行为序列、曝光序列、加购序列等等,如何建模这些序列与target item之间的关系是一个比较重要的研究问题。
引入注意力机制的方法有如下几种:
self attention
self attention 是在同一个序列内部计算注意力权重的机制。每个位置的元素都可以关注同序列中的其他位置,让每个位置的词都能够获取整个序列的全局信息。对于self attention而言,其QKV均从行为序列中抽取而来,通过注意力机制表征用户交互历史的迁移性和相关性。
import torch
import torch.nn.functional as F
class SelfAttention(torch.nn.Module):
def __init__(self, d_model):
super(SelfAttention, self).__init__()
self.scale = d_model ** -0.5
def forward(self, query, key, value):
# query, key, value: (batch_size, seq_len, d_model)
scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, value)
return output, attention_weights
target attention
target attention 是在不同的序列间进行的注意力机制。它关注输入序列(源序列)与输出序列之间的关系。对于target attention而言,其Q从target item抽取得到,KV从行为序列中抽取得到,刻画的是target item和候选物品的交叉,忽略了行为序列内部各元素的依赖关系。
import torch
import torch.nn.functional as F
class TargetAttention(torch.nn.Module):
def __init__(self, d_model):
super(TargetAttention, self).__init__()
self.scale = d_model ** -0.5
def forward(self, query, key, value):
# query: (batch_size, tgt_len, d_model)
# key: (batch_size, src_len, d_model)
# value: (batch_size, src_len, d_model)
scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, value)
return output, attention_weights
dual attention
dual attention弥补以上两种注意力机制的缺点,先self attention,后target attention,从行为序列中提取出来的用户兴趣向量,既能反映候选物料与历史记忆之间的相关性,又能反映不同历史记忆之间的依赖性。
import torch
import torch.nn.functional as F
class MultiLevelAttention(torch.nn.Module):
def __init__(self, d_model):
super(MultiLevelAttention, self).__init__()
self.self_attention = SelfAttention(d_model)
self.target_attention = TargetAttention(d_model)
def forward(self, target_seq, source_seq):
# 先进行self attention,再进行target attention
self_attn_output, _ = self.self_attention(target_seq, target_seq, target_seq)
target_attn_output, _ = self.target_attention(self_attn_output, source_seq, source_seq)
return target_attn_output
参考文献: