• A Self-Attentive model for Knowledge Tracing论文笔记和代码解析


    原文链接和代码链接A Self-Attentive model for Knowledge Tracing | Papers With Code

    motivation:传统方法面临着处理稀疏数据时不能很好地泛化的问题。

    本文提出了一种基于自注意力机制的知识追踪模型 Self Attentive Knowledge Tracing (SAKT)。其本质是用 Transformer 的 encoder 部分来做序列任务。具体从学生过去的活动中识别出与给定的KC相关的KC,并根据所选KC相对较少的KC预测他/她的掌握情况。由于预测是基于相对较少的过去活动,它比基于RNN的方法更好地处理数据稀疏性问题。

    模型结构

     

    输入编码

    交互信息 x_{t}=\left (e_{t},r_{t} \right ) 通过公式y_{t}=e_{t}+r_{t}\times E 转变成一个数字,总量为 2E。

    我们 用Interaction embedding matrix 训练一个交互嵌入矩阵,M\in R^{2E \times d}
    被用来为序列中的每个元素s_{i}

     Exercise 编码 利用 exercise embedding matrix训练练习嵌入矩阵,E\in R^{E \times d},每行代表一个题目ei 

    Position Encoding

    自动学习P \in R^{E \times d},n 是序列长度。

    最终编码层的输出如下

    1. self.qa_embedding = nn.Embedding(
    2. 2 * n_skill + 2, self.qa_embed_dim, padding_idx=2 * n_skill + 1
    3. )
    4. self.pos_embedding = nn.Embedding(self.max_len, self.pos_embed_dim)
    5. #定义
    6. #计算
    7. qa = self.qa_embedding(qa)
    8. pos_id = torch.arange(qa.size(1)).unsqueeze(0).to(self.device)
    9. pos_x = self.pos_embedding(pos_id)
    10. qa = qa + pos_x

     注意力机制

    Self-attention layer采用scaled dotproduct attention mechanism。

    Self-attention的query、key和value分别为:

    1. self.multi_attention = nn.MultiheadAttention(
    2. embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout
    3. )
    4. attention_out, _ = self.multi_attention(q, qa, qa, attn_mask=attention_mask)

     

     Causality:因果关系也是mask 避免未来交互对现在的

    Feed Forward layer

    用一个简单的前向传播网络将self-attention的输出进行前向传播。

    1. class FFN(nn.Module):
    2. def __init__(self, state_size=200, dropout=0.2):
    3. super(FFN, self).__init__()
    4. self.state_size = state_size
    5. self.dropout = dropout
    6. self.lr1 = nn.Linear(self.state_size, self.state_size)
    7. self.relu = nn.ReLU()
    8. self.lr2 = nn.Linear(self.state_size, self.state_size)
    9. self.dropout = nn.Dropout(self.dropout)
    10. def forward(self, x):
    11. x = self.lr1(x)
    12. x = self.relu(x)
    13. x = self.lr2(x)
    14. return self.dropout(x)

    剩余连接:剩余连接[2]用于将底层特征传播到高层。因此,如果低层特征对于预测很重要,那么剩余连接将有助于将它们传播到执行预测的最终层。在KT的背景下,学生尝试练习属于某个特定概念的练习来强化这个概念。因此,剩余连接有助于将最近解决的练习的嵌入传播到最终层,使模型更容易利用低层信息。在自我注意层和前馈层之后应用剩余连接。

    层标准化:在[1]中,研究表明,规范化特征输入有助于稳定和加速神经网络。我们在我们的架构中使用了层规范化目的层在自我注意层和前馈层也应用了归一化。

    1. attention_out = self.layer_norm1(attention_out + q)# Residual connection ; added excercise embd as residual because previous ex may have imp info, suggested in paper.
    2. attention_out = attention_out.permute(1, 0, 2)
    3. x = self.ffn(attention_out)
    4. x = self.dropout_layer(x)
    5. x = self.layer_norm2(x + attention_out)# Layer norm and Residual connection

    Prediction layer

    self-attention的输出经过前向传播后得到矩阵F,预测层是一个全连接层,最后经过sigmod激活函数,输出每个question的概率

     模型的目标是预测用户答题的对错情况,利用cross entropy loss计算(y_true, y_pred)

    实验

    1. # -*- coding:utf-8 -*-
    2. """
    3. Reference: A Self-Attentive model for Knowledge Tracing (https://arxiv.org/abs/1907.06837)
    4. """
    5. import torch
    6. import torch.nn as nn
    7. import deepkt.utils
    8. import deepkt.layer
    9. def future_mask(seq_length):
    10. mask = np.triu(np.ones((seq_length, seq_length)), k=1).astype("bool")
    11. return torch.from_numpy(mask)
    12. class FFN(nn.Module):
    13. def __init__(self, state_size=200, dropout=0.2):
    14. super(FFN, self).__init__()
    15. self.state_size = state_size
    16. self.dropout = dropout
    17. self.lr1 = nn.Linear(self.state_size, self.state_size)
    18. self.relu = nn.ReLU()
    19. self.lr2 = nn.Linear(self.state_size, self.state_size)
    20. self.dropout = nn.Dropout(self.dropout)
    21. def forward(self, x):
    22. x = self.lr1(x)
    23. x = self.relu(x)
    24. x = self.lr2(x)
    25. return self.dropout(x)
    26. class SAKTModel(nn.Module):
    27. def __init__(
    28. self, n_skill, embed_dim, dropout, num_heads=4, max_len=64, device="cpu"
    29. ):
    30. super(SAKTModel, self).__init__()
    31. self.n_skill = n_skill
    32. self.q_embed_dim = embed_dim
    33. self.qa_embed_dim = embed_dim
    34. self.pos_embed_dim = embed_dim
    35. self.embed_dim = embed_dim
    36. self.dropout = dropout
    37. self.num_heads = num_heads
    38. self.max_len = max_len
    39. self.device = device
    40. self.q_embedding = nn.Embedding(
    41. n_skill + 1, self.q_embed_dim, padding_idx=n_skill
    42. )
    43. self.qa_embedding = nn.Embedding(
    44. 2 * n_skill + 2, self.qa_embed_dim, padding_idx=2 * n_skill + 1
    45. )
    46. self.pos_embedding = nn.Embedding(self.max_len, self.pos_embed_dim)
    47. self.multi_attention = nn.MultiheadAttention(
    48. embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout
    49. )
    50. self.key_linear = nn.Linear(self.embed_dim, self.embed_dim)
    51. self.value_linear = nn.Linear(self.embed_dim, self.embed_dim)
    52. self.query_linear = nn.Linear(self.embed_dim, self.embed_dim)
    53. self.layer_norm1 = nn.LayerNorm(self.embed_dim)
    54. self.layer_norm2 = nn.LayerNorm(self.embed_dim)
    55. self.dropout_layer = nn.Dropout(self.dropout)
    56. self.ffn = FFN(self.embed_dim)
    57. self.pred = nn.Linear(self.embed_dim, 1, bias=True)
    58. def forward(self, q, qa):
    59. qa = self.qa_embedding(qa)
    60. pos_id = torch.arange(qa.size(1)).unsqueeze(0).to(self.device)
    61. pos_x = self.pos_embedding(pos_id)
    62. qa = qa + pos_x
    63. q = self.q_embedding(q)
    64. q = q.permute(1, 0, 2)
    65. qa = qa.permute(1, 0, 2)
    66. attention_mask = future_mask(q.size(0)).to(self.device)
    67. attention_out, _ = self.multi_attention(q, qa, qa, attn_mask=attention_mask)
    68. attention_out = self.layer_norm1(attention_out + q)# Residual connection ; added excercise embd as residual because previous ex may have imp info, suggested in paper.
    69. attention_out = attention_out.permute(1, 0, 2)
    70. x = self.ffn(attention_out)
    71. x = self.dropout_layer(x)
    72. x = self.layer_norm2(x + attention_out)# Layer norm and Residual connection
    73. x = self.pred(x)
    74. return x.squeeze(-1), None

  • 相关阅读:
    MySQL数据库技术笔记(6)
    集合框架:Set集合的特点、HashSet集合的底层原理、哈希表、实现去重复
    java培训技术自定义类型转换器示例
    【暑期集训第一周:搜索】【DFS&&BFS】
    调研主板,树莓派 VS RK3288板子,还是 RK的主板香,但是只支持 anrdoid 7系统,估计也有刷机成 armbian或者
    Python入门篇之循环结构
    线程间的调度顺序
    Android学习笔记 77. 可点击图片
    Java中的数组
    VoLTE端到端业务详解 | 应用实例二
  • 原文地址:https://blog.csdn.net/sereasuesue/article/details/128110283