import torch


class TimeAwareMultiHeadAttention(torch.nn.Module):
    def __init__(self, hidden_size, head_num, dropout_rate, dev):
        super(TimeAwareMultiHeadAttention, self).__init__()
        # Separate linear projections for queries, keys, and values
        self.Q_w = torch.nn.Linear(hidden_size, hidden_size)
        self.K_w = torch.nn.Linear(hidden_size, hidden_size)
        self.V_w = torch.nn.Linear(hidden_size, hidden_size)

        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.softmax = torch.nn.Softmax(dim=-1)

        self.hidden_size = hidden_size
        self.head_num = head_num
        self.head_size = hidden_size // head_num
        self.dropout_rate = dropout_rate
        self.dev = dev

    def forward(self, queries, keys, time_mask, attn_mask,
                time_matrix_K, time_matrix_V, abs_pos_K, abs_pos_V):
        Q, K, V = self.Q_w(queries), self.K_w(keys), self.V_w(keys)

        # Split the hidden dimension into heads and stack the heads along the batch dimension:
        # (batch, seq_len, hidden_size) -> (head_num * batch, seq_len, head_size)
        Q_ = torch.cat(torch.split(Q, self.head_size, dim=2), dim=0)
        K_ = torch.cat(torch.split(K, self.head_size, dim=2), dim=0)
        V_ = torch.cat(torch.split(V, self.head_size, dim=2), dim=0)

        # Do the same for the relative time-interval and absolute-position embeddings
        time_matrix_K_ = torch.cat(torch.split(time_matrix_K, self.head_size, dim=3), dim=0)
        time_matrix_V_ = torch.cat(torch.split(time_matrix_V, self.head_size, dim=3), dim=0)
        abs_pos_K_ = torch.cat(torch.split(abs_pos_K, self.head_size, dim=2), dim=0)
        abs_pos_V_ = torch.cat(torch.split(abs_pos_V, self.head_size, dim=2), dim=0)

        # Attention logits: item-to-item, item-to-absolute-position, and item-to-time-interval terms
        attn_weights = Q_.matmul(torch.transpose(K_, 1, 2))
        attn_weights += Q_.matmul(torch.transpose(abs_pos_K_, 1, 2))
        attn_weights += time_matrix_K_.matmul(Q_.unsqueeze(-1)).squeeze(-1)

        # Scale by sqrt(head_size)
        attn_weights = attn_weights / (K_.shape[-1] ** 0.5)

        # Broadcast the padding mask and the causal mask over all heads, then fill
        # masked positions with a large negative value before the softmax
        time_mask = time_mask.unsqueeze(-1).repeat(self.head_num, 1, 1)
        time_mask = time_mask.expand(-1, -1, attn_weights.shape[-1])
        attn_mask = attn_mask.unsqueeze(0).expand(attn_weights.shape[0], -1, -1)
        paddings = torch.ones(attn_weights.shape) * (-2 ** 32 + 1)
        paddings = paddings.to(self.dev)
        attn_weights = torch.where(time_mask, paddings, attn_weights)
        attn_weights = torch.where(attn_mask, paddings, attn_weights)

        attn_weights = self.softmax(attn_weights)
        attn_weights = self.dropout(attn_weights)

        # Weighted sums over values, absolute-position values, and time-interval values
        outputs = attn_weights.matmul(V_)
        outputs += attn_weights.matmul(abs_pos_V_)
        outputs += attn_weights.unsqueeze(2).matmul(time_matrix_V_).reshape(outputs.shape)

        # Merge the heads back: (head_num * batch, seq_len, head_size) -> (batch, seq_len, hidden_size)
        outputs = torch.cat(torch.split(outputs, Q.shape[0], dim=0), dim=2)

        return outputs
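For context, a minimal usage sketch follows. The class and argument names come from the listing above; the tensor shapes, the example dimensions, and the boolean-mask convention (True marks a position to suppress, matching how torch.where is used in forward) are assumptions for illustration, not part of the original listing.

import torch

B, L, H, heads = 2, 5, 16, 2   # assumed batch size, sequence length, hidden size, head count
dev = torch.device("cpu")
attn = TimeAwareMultiHeadAttention(H, heads, dropout_rate=0.1, dev=dev)

seqs = torch.randn(B, L, H)              # item embeddings, used as both queries and keys
time_matrix_K = torch.randn(B, L, L, H)  # relative time-interval embeddings (key side)
time_matrix_V = torch.randn(B, L, L, H)  # relative time-interval embeddings (value side)
abs_pos_K = torch.randn(B, L, H)         # absolute-position embeddings (key side)
abs_pos_V = torch.randn(B, L, H)         # absolute-position embeddings (value side)

time_mask = torch.zeros(B, L, dtype=torch.bool)              # True where the sequence is padding
attn_mask = torch.triu(torch.ones(L, L), diagonal=1).bool()  # True above the diagonal: no attention to future items

out = attn(seqs, seqs, time_mask, attn_mask,
           time_matrix_K, time_matrix_V, abs_pos_K, abs_pos_V)
print(out.shape)  # torch.Size([2, 5, 16])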