• molecular-graph-bert (Part 3)

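    A BERT model built on molecular graphs. Paper: MG-BERT: leveraging unsupervised atomic representation learning for molecular property prediction. Paper walkthrough: MG-BERT | Predicting molecular properties with unsupervised atomic representation learning | Applying BERT to molecular graphs | GNN | Unsupervised learning (masked-atom pretraining) | attention. Code: Molecular-graph-BERT. Building on the analysis in the previous two parts, this post looks at the attention_visualize part.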


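    import numpy as np
    import pandas as pd
    import tensorflow as tf
    medium = {'name':'Medium','num_layers': 6, 'num_heads': 8, 'd_model': 256,'path':'medium_weights','addH':True}
    arch = medium  ## presets (num_layers num_heads d_model): small 3 4 128   medium 6 8 256   large 12 8 516
    trained_epoch = 8
    num_layers = arch['num_layers']
    num_heads = arch['num_heads']
    d_model = arch['d_model']
    addH = arch['addH']
    dff = d_model * 2
    vocab_size = 17
    dropout_rate = 0.1
    seed = 7
    task = 'logD'
    df = pd.read_csv('data/reg/logD.txt',sep='\t')
    sml_list = df['SMILES'].tolist()
    # Inference_Dataset is defined below and must be run before this cell
    inference_dataset = Inference_Dataset(['C=C(CC)C(=O)c1ccc(OCC(=O)[O-])c(Cl)c1Cl',
                                           # ... the remaining SMILES of the 9-molecule list are cut off in the source
                                           ], addH=addH).get_data()  # get_data() returns the tf.data.Dataset consumed below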
    • The task is logD regression; build an Inference_Dataset from the SMILES to be predicted


    # str2num / num2str (the atom vocabulary) and smiles2adjoin come from the Molecular-graph-BERT repo
    class Inference_Dataset(object):
        def __init__(self,sml_list,max_len=100,addH=True):
            self.vocab = str2num
            self.devocab = num2str
            self.sml_list = [i for i in sml_list if len(i)<max_len]  # drop SMILES strings of max_len characters or more
            self.addH =  addH
        def get_data(self):
            self.dataset = tf.data.Dataset.from_tensor_slices((self.sml_list,))
            # pad every component to the longest molecule in the batch (batch size 64)
            self.dataset = self.dataset.map(self.tf_numerical_smiles).padded_batch(64, padded_shapes=(
                tf.TensorShape([None]), tf.TensorShape([None,None]),tf.TensorShape([1]),tf.TensorShape([None]))).cache().prefetch(20)
            return self.dataset
        def numerical_smiles(self, smiles):
            smiles_origin = smiles
            smiles = smiles.numpy().decode()
            atoms_list, adjoin_matrix = smiles2adjoin(smiles,explicit_hydrogens=self.addH)
            atoms_list = ['<global>'] + atoms_list  # prepend the super node; '<global>' assumed to be its key in str2num
            nums_list =  [str2num.get(i,str2num['<unk>']) for i in atoms_list]  # unknown atoms fall back to '<unk>' (assumed key)
            temp = np.ones((len(nums_list),len(nums_list)))
            temp[1:,1:] = adjoin_matrix  # the super node (row/column 0) stays connected to every atom
            adjoin_matrix = (1-temp)*(-1e9)  # 0 where connected, -1e9 where not: an additive attention bias
            x = np.array(nums_list).astype('int64')
            return x, adjoin_matrix,[smiles], atoms_list
        def tf_numerical_smiles(self, smiles):
            x,adjoin_matrix,smiles,atom_list = tf.py_function(self.numerical_smiles, [smiles], [tf.int64, tf.float32,tf.string, tf.string])
            return x, adjoin_matrix,smiles,atom_list
    • Converts the SMILES strings into a dataset; the processing is essentially the same as for the training dataset
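    The encoding in numerical_smiles can be tried without RDKit. The sketch below is a minimal NumPy re-creation under stated assumptions: the token ids are the ones visible in the printed x and atom_list output (H=1, C=2, N=3, O=4, F=5, S=6, Cl=7, super node 16), UNK_ID is a hypothetical fallback, and a hand-written adjacency matrix with self-loops stands in for smiles2adjoin.

```python
import numpy as np

# Token ids observed in the printed x / atom_list output (a subset of the vocabulary)
VOCAB = {"H": 1, "C": 2, "N": 3, "O": 4, "F": 5, "S": 6, "Cl": 7}
GLOBAL_ID = 16  # id of the super node in the printed x
UNK_ID = 14     # hypothetical fallback id for atoms outside VOCAB

def encode(atoms, adjacency):
    """Mimic numerical_smiles: prepend a super node and turn adjacency into an additive bias."""
    ids = [GLOBAL_ID] + [VOCAB.get(a, UNK_ID) for a in atoms]
    n = len(ids)
    temp = np.ones((n, n))
    temp[1:, 1:] = adjacency        # row/column 0 (super node) stay 1: connected to every atom
    bias = (1 - temp) * (-1e9)      # -0.0 where connected, -1e9 where not
    return np.array(ids, dtype=np.int64), bias

# Heavy atoms of ethanol, C-C-O, with self-loops on the diagonal
atoms = ["C", "C", "O"]
adjacency = np.array([[1, 1, 0],
                      [1, 1, 1],
                      [0, 1, 1]])
ids, bias = encode(atoms, adjacency)
print(ids)         # [16  2  2  4]
print(bias[1, 3])  # -1000000000.0: atoms 0 and 2 are not bonded
```

    Keeping the bias additive (0 or -1e9) lets the model drop it straight into the attention logits instead of doing a separate masking step.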
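    x, adjoin_matrix, smiles ,atom_list = next(iter(inference_dataset.take(1)))  # the first (and only) batch
    seq = tf.cast(tf.math.equal(x, 0), tf.float32)  # 1.0 at padding positions (token id 0)
    mask = seq[:, tf.newaxis, tf.newaxis, :]  # (batch, 1, 1, seq_len), broadcasts over heads and query positions
    print(x, adjoin_matrix, smiles ,atom_list,mask)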
    • The output is:
    [[16  2  2  2  2  2  4  2  2  2  2  4  2  2  4  4  2  7  2  7  1  1  1  1
       1  1  1  1  1  1  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
       0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
     [16  2  3  2  2  2  2  4  3  2  2  2  2  2  4  2  3  4  2  2  2  2  2  2
       5  5  5  2  2  2  2  2  4  2  1  1  1  1  1  1  1  1  1  1  1  1  1  1
       1  1  1  1  1  1  1  1  1  1  1  1  1  1  0  0]
     [16  2  2  4  3  2  2  2  2  4  2  2  1  1  1  1  1  1  1  1  1  0  0  0
       0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
       0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
     [16  2  2  4  2  2  2  2  2  2  3  4  4  2  2  2  2  4  2  2  2  2  2  2
       4  2  4  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  0  0  0  0  0  0
       0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
     [16  2  2  2  2  2  2  3  2  6  4  2  2  2  2  2  2  2  2  2  2  2  1  1
       1  1  1  1  1  1  1  1  1  1  1  1  1  0  0  0  0  0  0  0  0  0  0  0
       0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
     [16  2  2  4  3  2  3  3  2  6  3  4  4  6  1  1  1  1  1  1  0  0  0  0
       0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
       0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
     [16  2  2  4  2  2  2  2  6  4  4  3  2  4  3  2  2  2  2  2  2  2  2  1
       1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  0  0  0  0  0
       0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
     [16  3  2  2  2  7  2  2  2  2  2  2  3  2  3  2  2  2  2  2  2  1  1  1
       1  1  1  1  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
       0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
     [16  2  3  2  4  2  2  2  2  2  2  2  2  7  2  3  2  2  2  2  2  4  2  2
       2  2  2  5  6  2  2  2  2  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
       1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1]], shape=(9, 64), dtype=int64)
    [[[-0.e+00 -0.e+00 -0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [-0.e+00 -0.e+00 -0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [-0.e+00 -0.e+00 -0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]]
     [[-0.e+00 -0.e+00 -0.e+00 ... -0.e+00  0.e+00  0.e+00]
      [-0.e+00 -0.e+00 -0.e+00 ... -1.e+09  0.e+00  0.e+00]
      [-0.e+00 -0.e+00 -0.e+00 ... -1.e+09  0.e+00  0.e+00]
      [-0.e+00 -1.e+09 -1.e+09 ... -0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]]
     [[-0.e+00 -0.e+00 -0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [-0.e+00 -0.e+00 -0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [-0.e+00 -0.e+00 -0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]]
     [[-0.e+00 -0.e+00 -0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [-0.e+00 -0.e+00 -0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [-0.e+00 -0.e+00 -0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]]
     [[-0.e+00 -0.e+00 -0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [-0.e+00 -0.e+00 -0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [-0.e+00 -0.e+00 -0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]
      [ 0.e+00  0.e+00  0.e+00 ...  0.e+00  0.e+00  0.e+00]]
     [[-0.e+00 -0.e+00 -0.e+00 ... -0.e+00 -0.e+00 -0.e+00]
      [-0.e+00 -0.e+00 -0.e+00 ... -1.e+09 -1.e+09 -1.e+09]
      [-0.e+00 -0.e+00 -0.e+00 ... -1.e+09 -1.e+09 -1.e+09]
      [-0.e+00 -1.e+09 -1.e+09 ... -0.e+00 -1.e+09 -1.e+09]
      [-0.e+00 -1.e+09 -1.e+09 ... -1.e+09 -0.e+00 -1.e+09]
      [-0.e+00 -1.e+09 -1.e+09 ... -1.e+09 -1.e+09 -0.e+00]]], shape=(9, 64, 64), dtype=float32)
     [b'CN(C(=O)C(Cc1ccccc1Cl)C[NH+]1CCC2(CC1)OCCc1cc(F)sc12)C1CC1']], shape=(9, 1), dtype=string)
    [[b'' b'C' b'C' b'C' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'O' b'C'
      b'C' b'O' b'O' b'C' b'Cl' b'C' b'Cl' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
      b'H' b'H' b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
      b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
      b'' b'']
     [b'' b'C' b'N' b'C' b'C' b'C' b'C' b'O' b'N' b'C' b'C' b'C' b'C'
      b'C' b'O' b'C' b'N' b'O' b'C' b'C' b'C' b'C' b'C' b'C' b'F' b'F' b'F'
      b'C' b'C' b'C' b'C' b'C' b'O' b'C' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
      b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
      b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'' b'']
     [b'' b'C' b'C' b'O' b'N' b'C' b'C' b'C' b'C' b'O' b'C' b'C' b'H'
      b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b''
      b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
      b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'']
     [b'' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'C' b'C' b'N' b'O' b'O'
      b'C' b'C' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'C' b'C' b'O' b'C' b'O'
      b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
      b'H' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
      b'' b'' b'' b'' b'' b'']
     [b'' b'C' b'C' b'C' b'C' b'C' b'C' b'N' b'C' b'S' b'O' b'C' b'C'
      b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'H' b'H' b'H' b'H' b'H'
      b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'' b'' b'' b'' b''
      b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
      b'' b'' b'' b'']
     [b'' b'C' b'C' b'O' b'N' b'C' b'N' b'N' b'C' b'S' b'N' b'O' b'O'
      b'S' b'H' b'H' b'H' b'H' b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b'' b''
      b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
      b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'']
     [b'' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'S' b'O' b'O' b'N' b'C'
      b'O' b'N' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'H' b'H' b'H' b'H'
      b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
      b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
      b'' b'' b'' b'' b'' b'']
     [b'' b'N' b'C' b'C' b'C' b'Cl' b'C' b'C' b'C' b'C' b'C' b'C'
      b'N' b'C' b'N' b'C' b'C' b'C' b'C' b'C' b'C' b'H' b'H' b'H' b'H' b'H'
      b'H' b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
      b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
      b'' b'' b'']
     [b'' b'C' b'N' b'C' b'O' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'C'
      b'Cl' b'C' b'N' b'C' b'C' b'C' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'C'
      b'F' b'S' b'C' b'C' b'C' b'C' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
      b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
      b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H']], shape=(9, 64), dtype=string)
    [[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
        0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
        1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
     [[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
        0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
        0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.]]]
     [[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.
        1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
        1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
     [[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
        0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.
        1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
     [[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
        0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.
        1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
     [[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.
        1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
        1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
     [[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
        0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.
        1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
     [[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
        0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
        1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
     [[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
        0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
        0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]], shape=(9, 1, 1, 64), dtype=float32)
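    • The 9 SMILES passed to Inference_Dataset are converted into the following data:
      • x is the integer encoding of each SMILES, using the str2num mapping defined in dataset; every row starts with the super node token (16)
      • adjoin_matrix is the adjacency matrix of the molecular graph in additive-bias form: -0.e+00 where two nodes are connected (bonded atoms, the diagonal, and the super node's whole row and column), -1.e+09 where they are not; padded positions are 0 because padded_batch fills with zeros
      • smiles is the input SMILES string
      • atom_list is the list of atoms of each SMILES (hydrogens included, since addH=True), including the added super node
      • mask is the padding flag: 1 at padded positions and 0 elsewhere. Each batch is padded to its longest molecule, here the last SMILES (64 tokens), so every other molecule is padded to that length; only the last row of mask contains no padding and is all zeros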
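    A small NumPy sketch of how the padding mask is built from x and how its (batch, 1, 1, seq_len) shape broadcasts against per-head attention logits. It assumes the mask is applied as mask * -1e9 before the softmax, as in the standard TensorFlow Transformer; scaled_dot_product_attention itself is not shown in this post.

```python
import numpy as np

# Two toy sequences of token ids; 0 is the padding id, as in x above
x = np.array([[16, 2, 4, 1, 0],
              [16, 2, 2, 4, 1]])
seq = (x == 0).astype(np.float32)          # same as tf.cast(tf.math.equal(x, 0), tf.float32)
mask = seq[:, np.newaxis, np.newaxis, :]   # (batch, 1, 1, seq_len)

# Broadcast against attention logits of shape (batch, num_heads, seq_len_q, seq_len_k)
num_heads = 8
logits = np.zeros((2, num_heads, 5, 5), dtype=np.float32)
masked = logits + mask * -1e9              # assumed: padding keys pushed to -1e9 before softmax
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(mask.shape)        # (2, 1, 1, 5)
print(weights[0, 0, 0])  # the padded key position gets weight 0
```

    The two singleton axes are what let a single mask row serve every head and every query position.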
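    model = PredictModel_test(num_layers=num_layers, d_model=d_model, dff=dff, num_heads=num_heads, vocab_size=vocab_size,dense_dropout=0.15)
    # a first call creates the variables, which a subclassed Keras model needs before trained weights can be loaded
    pred = model(x,mask=mask,training=True,adjoin_matrix=adjoin_matrix)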
    • Build the inference model and load the parameters trained on the regression task


    class PredictModel_test(tf.keras.Model):
        def __init__(self,num_layers = 6,d_model = 256,dff = 512,num_heads = 8,vocab_size =17,dropout_rate = 0.1,dense_dropout=0.5):
            super(PredictModel_test, self).__init__()
            self.encoder = Encoder_test(num_layers=num_layers,d_model=d_model,
                                        num_heads=num_heads,dff=dff,input_vocab_size=vocab_size,
                                        maximum_position_encoding=200,  # assumed value; unused while positional encoding is commented out
                                        rate=dropout_rate)
            self.fc1 = tf.keras.layers.Dense(256, activation=tf.keras.layers.LeakyReLU(0.1))
            self.dropout = tf.keras.layers.Dropout(dense_dropout)
            self.fc2 = tf.keras.layers.Dense(1)
        def call(self,x,adjoin_matrix,mask,training=False):
            x,att,xs = self.encoder(x,training=training,mask=mask,adjoin_matrix=adjoin_matrix)
            x = x[:, 0, :]  # read out the super node's representation
            x = self.fc1(x)
            x = self.dropout(x, training=training)
            x = self.fc2(x)
            return x,att,xs  # prediction, per-layer attention weights, per-layer hidden states
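    • The encoder is Encoder_test, which differs from the Encoder used earlier in what its call returns: besides the final x it also returns the per-layer attention weights and hidden states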
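    PredictModel_test reads the prediction off the super node: x[:, 0, :] keeps only position 0 of the encoder output, and fc1/fc2 map it to one logD value per molecule. A minimal NumPy sketch of that head, with random matrices standing in for the trained fc1/fc2 weights and dropout omitted (it is inactive at inference):

```python
import numpy as np

def leaky_relu(z, alpha=0.1):
    # same slope as tf.keras.layers.LeakyReLU(0.1) in fc1
    return np.where(z > 0, z, alpha * z)

rng = np.random.default_rng(0)
hidden = rng.normal(size=(9, 64, 256))                      # encoder output: (batch, seq_len, d_model)
w1, b1 = rng.normal(size=(256, 256)) * 0.05, np.zeros(256)  # random stand-ins for fc1
w2, b2 = rng.normal(size=(256, 1)) * 0.05, np.zeros(1)      # random stand-ins for fc2

pooled = hidden[:, 0, :]                         # super node (position 0): (9, 256)
pred = leaky_relu(pooled @ w1 + b1) @ w2 + b2    # fc1 -> fc2
print(pred.shape)                                # (9, 1)
```

    Because only position 0 is read, every other atom influences the prediction solely through the attention the super node pays to it, which is what the visualization below exploits.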


    class Encoder_test(tf.keras.Model):
        def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                     maximum_position_encoding, rate=0.1):
            super(Encoder_test, self).__init__()
            self.d_model = d_model
            self.num_layers = num_layers
            self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
            # self.pos_encoding = positional_encoding(maximum_position_encoding,
            #                                         self.d_model)
            self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
                               for _ in range(num_layers)]
            self.dropout = tf.keras.layers.Dropout(rate)
        def call(self, x, training, mask,adjoin_matrix):
            seq_len = tf.shape(x)[1]
            adjoin_matrix = adjoin_matrix[:,tf.newaxis,:,:]  # add a head axis: (batch, 1, seq_len, seq_len)
            # adding embedding and position encoding.
            x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
            x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
            # x += self.pos_encoding[:, :seq_len, :]
            x = self.dropout(x, training=training)
            attention_weights_list = []
            xs = []
            for i in range(self.num_layers):
                x,attention_weights = self.enc_layers[i](x, training, mask,adjoin_matrix)
                attention_weights_list.append(attention_weights)  # (batch, num_heads, seq_len, seq_len)
                xs.append(x)  # (batch, seq_len, d_model)
            return x,attention_weights_list,xs
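    • x first goes through the Embedding layer and then through the 6 EncoderLayers; after each EncoderLayer the output x and the attention weights are recorded
    • The shapes of the three outputs are: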
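    x, adjoin_matrix, smiles ,atom_list = next(iter(inference_dataset.take(1)))
    seq = tf.cast(tf.math.equal(x, 0), tf.float32)
    mask = seq[:, tf.newaxis, tf.newaxis, :]
    x,atts,xs= model(x,mask=mask,training=True,adjoin_matrix=adjoin_matrix)  # training=True keeps dropout active; training=False gives deterministic predictions
    print((tuple(x.shape), np.array(atts).shape, np.array(xs).shape))  # one way to print the shapes below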
    ((9, 1), (6, 9, 8, 64, 64), (6, 9, 64, 256))
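    • x is the output of the full model, i.e. the predicted logD of each molecule, with shape (9, 1). atts and xs collect, for each of the 6 EncoderLayers, the attention weights and the output x; each attention tensor has shape (9, 8, 64, 64) and each x has shape (9, 64, 256)
    • The inputs to model have shapes (9,64), (9,64,64) and (9,1,1,64). x holds 9 vectors of length 64 whose elements are atom-type indices; the Embedding layer turns it into (9,64,256), mapping each atom to a 256-dimensional vector
    • The inputs x, mask and adjoin_matrix of each EncoderLayer have shapes (9,64,256), (9,1,1,64) and (9,1,64,64); the extra axis of adjoin_matrix is added in Encoder_test so that it broadcasts over the 8 heads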
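    The shapes above can be reproduced with a NumPy sketch. A random embedding table and a placeholder loop body (hypothetical, standing in for EncoderLayer) are used; only the shapes mirror the post: the (9, 64) token ids become (9, 64, 256) after the embedding lookup and sqrt(d_model) scaling, and stacking the six per-layer outputs gives the (6, 9, 8, 64, 64) and (6, 9, 64, 256) shapes printed above.

```python
import numpy as np

batch, seq_len, vocab_size, d_model, num_heads, num_layers = 9, 64, 17, 256, 8, 6
rng = np.random.default_rng(0)

tokens = rng.integers(0, vocab_size, size=(batch, seq_len))      # stands in for x: (9, 64)
embedding = rng.normal(size=(vocab_size, d_model)).astype(np.float32)
h = embedding[tokens] * np.sqrt(d_model)                          # lookup, then scale: (9, 64, 256)

attention_weights_list, xs = [], []
for _ in range(num_layers):
    # placeholder for an EncoderLayer: keeps the hidden shape, emits one attention map per head
    att = np.full((batch, num_heads, seq_len, seq_len), 1.0 / seq_len, dtype=np.float32)
    attention_weights_list.append(att)
    xs.append(h)

print(h.shape)                                 # (9, 64, 256)
print(np.stack(attention_weights_list).shape)  # (6, 9, 8, 64, 64)
print(np.stack(xs).shape)                      # (6, 9, 64, 256)
```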

    3.2.atts & xs

    class EncoderLayer(tf.keras.layers.Layer):
        def __init__(self, d_model, num_heads, dff, rate=0.1):
            super(EncoderLayer, self).__init__()
            self.mha = MultiHeadAttention(d_model, num_heads)
            self.ffn = point_wise_feed_forward_network(d_model, dff)
            self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
            self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
            self.dropout1 = tf.keras.layers.Dropout(rate)
            self.dropout2 = tf.keras.layers.Dropout(rate)
        def call(self, x, training, mask,adjoin_matrix):
            attn_output, attention_weights = self.mha(x, x, x, mask,adjoin_matrix)  # (batch_size, input_seq_len, d_model)
            attn_output = self.dropout1(attn_output, training=training)
            out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)
            ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
            ffn_output = self.dropout2(ffn_output, training=training)
            out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)
            return out2,attention_weights
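    • The out2 and attention_weights returned here are the per-layer elements collected into xs and attention_weights_list
    • Inside the layer, x first goes through the multi-head attention layer: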
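    The EncoderLayer above is a post-LN Transformer block: residual addition followed by LayerNormalization, once around attention and once around the feed-forward network. A NumPy sketch of that data flow, with random matrices standing in for the attention sub-layer and for point_wise_feed_forward_network (its two-Dense-layer ReLU form here is an assumption), and LayerNormalization without its learned scale and offset:

```python
import numpy as np

def layer_norm(z, eps=1e-6):
    # LayerNormalization with gamma=1, beta=0 (the learned parameters are left out)
    mean = z.mean(-1, keepdims=True)
    var = z.var(-1, keepdims=True)
    return (z - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
d_model, dff = 256, 512
x = rng.normal(size=(9, 64, d_model))
w_a = rng.normal(size=(d_model, d_model)) * 0.05   # stand-in for the attention sub-layer
w1 = rng.normal(size=(d_model, dff)) * 0.05        # stand-ins for the feed-forward network
w2 = rng.normal(size=(dff, d_model)) * 0.05

attn_output = x @ w_a
out1 = layer_norm(x + attn_output)               # residual + LayerNorm
ffn_output = np.maximum(out1 @ w1, 0) @ w2       # Dense(dff) + ReLU, then Dense(d_model)
out2 = layer_norm(out1 + ffn_output)             # residual + LayerNorm
print(out2.shape)                                # (9, 64, 256)
```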


    class MultiHeadAttention(tf.keras.layers.Layer):
        def __init__(self, d_model, num_heads):
            super(MultiHeadAttention, self).__init__()
            self.num_heads = num_heads
            self.d_model = d_model
            assert d_model % self.num_heads == 0
            self.depth = d_model // self.num_heads
            self.wq = tf.keras.layers.Dense(d_model)
            self.wk = tf.keras.layers.Dense(d_model)
            self.wv = tf.keras.layers.Dense(d_model)
            self.dense = tf.keras.layers.Dense(d_model)
        def split_heads(self, x, batch_size):
            """Split the last dimension into (num_heads, depth).
            Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
            """
            x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
            return tf.transpose(x, perm=[0, 2, 1, 3])
        def call(self, v, k, q, mask,adjoin_matrix):
            batch_size = tf.shape(q)[0]
            q = self.wq(q)  # (batch_size, seq_len, d_model)
            k = self.wk(k)  # (batch_size, seq_len, d_model)
            v = self.wv(v)  # (batch_size, seq_len, d_model)
            q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
            k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
            v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
            # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
            # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
            scaled_attention, attention_weights = scaled_dot_product_attention(
                q, k, v, mask,adjoin_matrix)
            scaled_attention = tf.transpose(scaled_attention,
                                            perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
            concat_attention = tf.reshape(scaled_attention,
                                          (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)
            output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
            return output, attention_weights
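    • The (9,64,256) x serves as q, k and v. After the Dense projections wq, wk and wv the shape is still (9,64,256), and split_heads reshapes each to (9,8,64,32), i.e. 8 heads of depth 256/8 = 32. Attention then yields the (9,8,64,64) attention_weights and a (9,64,256) output; in EncoderLayer this output goes through dropout, a residual connection with layer normalization and the feed-forward network ffn before becoming the x of the next EncoderLayer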
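    scaled_dot_product_attention is called with both mask and adjoin_matrix, but its body is not shown in this post. The NumPy sketch below assumes it follows the standard TensorFlow Transformer version with one addition: the padding mask (times -1e9) and the adjoin_matrix bias are both added to the scaled logits before the softmax, so an atom can only attend to itself, its bonded neighbours and the super node. The Dense projections are skipped and the molecule is a toy example.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def split_heads(t, num_heads):
    # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth), like MultiHeadAttention.split_heads
    b, s, d = t.shape
    return t.reshape(b, s, num_heads, d // num_heads).transpose(0, 2, 1, 3)

def graph_attention(q, k, v, mask, adjoin_bias):
    """Scaled dot-product attention with additive padding and adjacency biases (assumed form)."""
    logits = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])  # (batch, heads, seq_q, seq_k)
    logits = logits + mask * -1e9     # hide padding keys
    logits = logits + adjoin_bias     # hide non-neighbours: bias is 0 or -1e9
    weights = softmax(logits, axis=-1)
    return weights @ v, weights

batch, seq_len, d_model, num_heads = 1, 4, 8, 2
rng = np.random.default_rng(0)
h = rng.normal(size=(batch, seq_len, d_model))
q = k = v = split_heads(h, num_heads)  # wq / wk / wv projections skipped in this sketch

# Super node + C-C-O chain (self-loops included), turned into a 0 / -1e9 bias with a head axis
adj = np.array([[1, 1, 1, 1],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
                [1, 0, 1, 1]], dtype=float)
adjoin_bias = ((1 - adj) * -1e9)[np.newaxis, np.newaxis]  # (1, 1, seq_len, seq_len)
mask = np.zeros((batch, 1, 1, seq_len))                   # no padding in this toy example

out, w = graph_attention(q, k, v, mask, adjoin_bias)
merged = out.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)  # back to (batch, seq_len, d_model)
print(w.shape, merged.shape)  # (1, 2, 4, 4) (1, 4, 8)
print(w[0, 0, 1, 3])          # 0.0: the first C (token 1) cannot attend to the unbonded O (token 3)
```

    Stacking several such layers still lets information travel beyond direct neighbours, both hop by hop and through the super node, which is connected to everything.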


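    from rdkit import Chem
    i = 0  # index of the molecule to visualize
    smiles_plot = smiles[i].numpy().tolist()[0].decode()
    mol = Chem.MolFromSmiles(smiles_plot)
    num_atoms = mol.GetNumAtoms()  # heavy atoms only: MolFromSmiles does not add explicit hydrogens
    # keep the super node and the heavy atoms; the explicit H atoms and the padding come after them
    attentions_plot = tf.concat([att[i:(i+1),:,:num_atoms+1,:num_atoms+1] for att in atts],axis=0)  # (num_layers, num_heads, num_atoms+1, num_atoms+1)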
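    • x is the predicted logD; smiles_plot is the i-th SMILES string

    • atts has shape (6, 9, 8, 64, 64), and each att is the attention-weight tensor of one EncoderLayer. att[i:(i+1),:,:num_atoms+1,:num_atoms+1] selects the attention weights of the i-th molecule (here molecule 0) for all 8 heads. Because num_atoms counts heavy atoms only and the tokens are ordered super node, heavy atoms, hydrogens, padding, slicing to num_atoms+1 drops both the padding and the explicit hydrogens. Concatenating the 6 per-layer tensors gives shape (6, 8, 20, 20), where the 20 positions are the super node plus the molecule's 19 heavy atoms

    • The output is: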

    [[ 0.00611259]
     [ 0.10352743]
     [ 0.29321525]
     [-0.2280695 ]
     [ 0.03325798]
     [ 0.24569517]
     [ 0.2692453 ]], shape=(9, 1), dtype=float32)
    [[0.02703399 0.03738527 0.0341253  0.04261673 0.03024014 0.046207
      0.04024169 0.03412233]
     [0.00220853 0.01803984 0.01086219 0.10811627 0.0314402  0.20126604
      0.03202052 0.01011109]
     [0.01053316 0.02646094 0.02683227 0.05374964 0.00403768 0.01397272
      0.03217852 0.00420805]
     [0.01578051 0.00886862 0.00608971 0.17516439 0.0039229  0.01038222
      0.0165774  0.04300734]
     [0.02164022 0.05754587 0.01020102 0.01972281 0.04598003 0.0355869
      0.0741061  0.03371757]
     [0.02877988 0.02249552 0.03075971 0.04478414 0.04234883 0.03901222
      0.05866967 0.02472056]]
    [['O13', '0.07'], ['C12', '0.05'], ['O14', '0.05'], ['C11', '0.05'], ['O5', '0.05'], ['O10', '0.04'], ['C3', '0.03'], ['C2', '0.03'], ['C9', '0.03'], ['Cl16', '0.03'], ['C0', '0.02'], ['C8', '0.02'], ['C4', '0.02'], ['C1', '0.02'], ['C7', '0.02'], ['Cl18', '0.02'], ['C15', '0.02'], ['C6', '0.02'], ['C17', '0.02']]
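    The i:(i+1) slicing used for attentions_plot can be checked with random stand-ins for atts (the numbers are random; only the shapes follow the post):

```python
import numpy as np

# Random stand-ins for atts: 6 layers, each (batch=9, heads=8, 64, 64)
rng = np.random.default_rng(0)
atts = [rng.random((9, 8, 64, 64), dtype=np.float32) for _ in range(6)]

i = 0            # molecule index
num_atoms = 19   # heavy-atom count of molecule 0, what mol.GetNumAtoms() returns for it
attentions_plot = np.concatenate(
    [att[i:(i + 1), :, :num_atoms + 1, :num_atoms + 1] for att in atts], axis=0)
print(attentions_plot.shape)  # (6, 8, 20, 20)

# i:(i+1) keeps a length-1 batch axis, which the concatenation turns into the layer axis;
# stacking first and then indexing the molecule gives the same tensor
alt = np.stack(atts)[:, i, :, :num_atoms + 1, :num_atoms + 1]
print(np.array_equal(attentions_plot, alt))  # True
```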


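    import tensorflow as tf
    from rdkit import Chem
    from rdkit.Chem.Draw import rdMolDraw2D
    from rdkit.Chem.Draw.MolDrawing import DrawingOptions
    def plot_weights(smiles,attention_plot,max=5):
        mol = Chem.MolFromSmiles(smiles)  # use the arguments, not the globals smiles_plot / attentions_plot
        mol = Chem.RemoveHs(mol)
    #     mol = Chem.AddHs(mol)
        num_atoms = mol.GetNumAtoms()
        atoms = []
        for i in range(num_atoms):
            atom = mol.GetAtomWithIdx(i)
            atoms.append(atom.GetSymbol() + str(i))  # labels such as 'O13', 'C12'
    #   att = tf.reduce_mean(tf.reduce_mean(attention_plot[:,:,0,:],axis=0),axis=0)[1:].numpy()
        # attention_plot: (num_layers, num_heads, num_atoms+1, num_atoms+1); row 0 is the super node's attention
        att = tf.reduce_mean(tf.reduce_mean(attention_plot[3:,:,0,:],axis=0),axis=0)[1:].numpy()
        indices = (-att).argsort()  # atom indices by attention, highest first
        highlight = indices.tolist()
        print([[atoms[indices[i]],('%.2f'%att[indices[i]])] for i in range(len(indices))])
        drawer = rdMolDraw2D.MolDraw2DSVG(800,600)
        opts = drawer.drawOptions()
        drawer.drawOptions().updateAtomPalette({k: (0, 0, 0) for k in DrawingOptions.elemDict.keys()})  # black atom labels
    #     for i in range(mol.GetNumAtoms()):
    #         opts.atomLabels[i] = mol.GetAtomWithIdx(i).GetSymbol()
        colors = {}
        for i,h in enumerate(highlight):
            # the source line is cut off after "(1,"; fading from red (top rank) toward white is an assumption
            colors[h] = (1, i/len(highlight), i/len(highlight))
        drawer.DrawMolecule(mol,highlightAtoms = highlight,highlightAtomColors=colors,highlightBonds=[])
        drawer.FinishDrawing()  # required before GetDrawingText()
        svg = drawer.GetDrawingText().replace('svg:','')
        return svg  # assumed return; e.g. IPython.display.SVG(svg) renders it in a notebook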
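    • tf.reduce_mean(attentions_plot[3:,:,0,:],axis=0) averages (it does not sum) over the last 3 of the 6 layers and keeps row 0, i.e. the attention the super node pays to every position, giving an (8,20) matrix for the 8 heads; the outer reduce_mean then averages over the heads, giving 20 values. [1:] drops the super node itself, leaving a (19,) vector with one attention value per heavy atom. The output is: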
    <tf.Tensor: shape=(19,), dtype=float32, numpy=
    array([0.02398169, 0.02298377, 0.03124303, 0.03180713, 0.0232711 ,
           0.04853497, 0.0192481 , 0.02276288, 0.02380962, 0.02621273,
           0.03749018, 0.04908869, 0.05493892, 0.06948597, 0.04930567,
           0.02019732, 0.0252701 , 0.01770476, 0.02110688], dtype=float32)>
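    The averaging can be traced step by step. The NumPy sketch below uses random row-normalized attention maps in place of attentions_plot (an assumption; the real values come from the model). In this sketch every row sums to 1, so the 19 atom values sum to 1 minus the super node's attention to itself; in the real attentions_plot the hydrogen columns have been sliced away, so its rows no longer sum exactly to 1.

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=(6, 8, 20, 20))
attentions_plot = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)  # softmax-like rows

row0 = attentions_plot[3:, :, 0, :]   # last 3 layers, all heads, super node's row: (3, 8, 20)
per_head = row0.mean(axis=0)          # average over layers: (8, 20)
full = per_head.mean(axis=0)          # average over heads: (20,)
att = full[1:]                        # drop the super node itself: (19,)
print(row0.shape, per_head.shape, att.shape)  # (3, 8, 20) (8, 20) (19,)

print(np.isclose(att.sum() + full[0], 1.0))   # True
```

    Since every layer has the same number of heads, the two nested means equal a single mean over both axes.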
    • Sorting the attention values ranks the atoms by their contribution (attention) to the molecular property; the output is:
    indices = (-att).argsort()
    highlight = indices.tolist()
    [13, 12, 14, 11, 5, 10, 3, 2, 9, 16, 0, 8, 4, 1, 7, 18, 15, 6, 17]
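    • This output means that the atom with the largest contribution (highest attention) to this molecule's predicted logD has index 13, the second largest has index 12, and so on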
    print([[atoms[indices[i]],('%.2f'%att[indices[i]])] for i in range(len(indices))])
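    • For i = 0, the top-ranked entry, indices[i] gives the atom index, atoms[indices[i]] gives its label, and ('%.2f'%att[indices[i]]) gives its attention value. Here that is ['O13', '0.07']: atom 13, an O, contributes 0.07 to the logD prediction. Drawing the molecule and highlighting the atoms by this contribution (attention) produces the image shown in the original post
    • The remaining examples run essentially the same way as the first molecule above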
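    The ranking and labelling can be reproduced from the 19 values printed above. In this sketch the atom symbols are read off atom_list for molecule 0 (skipping the super node), and labels are built as symbol plus index, as in plot_weights:

```python
import numpy as np

# The 19 averaged attention values printed above for molecule 0
att = np.array([0.02398169, 0.02298377, 0.03124303, 0.03180713, 0.0232711,
                0.04853497, 0.0192481, 0.02276288, 0.02380962, 0.02621273,
                0.03749018, 0.04908869, 0.05493892, 0.06948597, 0.04930567,
                0.02019732, 0.0252701, 0.01770476, 0.02110688])
# Heavy-atom symbols of molecule 0 in index order, taken from atom_list
symbols = ["C", "C", "C", "C", "C", "O", "C", "C", "C", "C",
           "O", "C", "C", "O", "O", "C", "Cl", "C", "Cl"]
atoms = [s + str(i) for i, s in enumerate(symbols)]  # 'C0', 'C1', ..., 'Cl18'

indices = (-att).argsort()                        # descending by attention
ranking = [[atoms[j], '%.2f' % att[j]] for j in indices]
print(indices.tolist()[:5])  # [13, 12, 14, 11, 5]
print(ranking[:3])           # [['O13', '0.07'], ['C12', '0.05'], ['O14', '0.05']]
```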
  • Original post: https://blog.csdn.net/weixin_52812620/article/details/126372513