A BERT model on molecular graphs. Paper: MG-BERT: leveraging unsupervised atomic representation learning for molecular property prediction. In short: MG-BERT applies BERT to molecular graphs, combining a graph-structured (GNN-like) attention constraint with unsupervised masked-atom pretraining. Code: Molecular-graph-BERT. Building on the previous two posts in this series, this one walks through the attention_visualize part.
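The snippets below omit their imports; a plausible set, assuming the usual MG-BERT dependencies (the exact module where str2num, num2str and smiles2adjoin live in the repo may differ):

import numpy as np
import pandas as pd
import tensorflow as tf
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.Draw.MolDrawing import DrawingOptions
from IPython.display import SVG, display
# assumption: the vocab dicts and the SMILES-to-graph helper come from the repo's utils
from utils import str2num, num2str, smiles2adjoin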
medium = {'name': 'Medium', 'num_layers': 6, 'num_heads': 8, 'd_model': 256, 'path': 'medium_weights', 'addH': True}
arch = medium  # small: 3 layers / 4 heads / d_model 128; medium: 6 / 8 / 256; large: 12 / 8 / 512
trained_epoch = 8
num_layers = arch['num_layers']
num_heads = arch['num_heads']
d_model = arch['d_model']
addH = arch['addH']
dff = d_model * 2
vocab_size = 17
dropout_rate = 0.1
seed = 7
np.random.seed(seed=seed)
tf.random.set_seed(seed=seed)
task = 'logD'
df = pd.read_csv('data/reg/logD.txt', sep='\t')
sml_list = df['SMILES'].tolist()  # loaded from the logD table, though the demo below scores a hand-picked SMILES list
inference_dataset = Inference_Dataset(['C=C(CC)C(=O)c1ccc(OCC(=O)[O-])c(Cl)c1Cl',
'CN(Cc1c(C(=O)NC2CCCC(O)C2)noc1-c1ccc(C(F)(F)F)cc1)C1CCOC1',
'CC(=O)Nc1ccc(O)cc1',
'CC(=O)CC(c1ccc([N+](=O)[O-])cc1)c1c(O)c2ccccc2oc1=O',
'Cc1cccc(NC(=S)Oc2ccc3ccccc3c2)c1',
'CC(=O)Nc1nnc(S(N)(=O)=O)s1',
'CC(=O)c1ccc(S(=O)(=O)NC(=O)NC2CCCCC2)cc1',
'N#Cc1c(Cl)c2ccccc2n2c1nc1ccccc12',
'CN(C(=O)C(Cc1ccccc1Cl)C[NH+]1CCC2(CC1)OCCc1cc(F)sc12)C1CC1'],addH=addH).get_data()
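The Inference_Dataset class below relies on smiles2adjoin from the repo's utils, which this post never shows. A minimal reconstruction of what it plausibly does, assuming self-loops on the diagonal (which the supernode wiring below relies on); a sketch, not the verbatim source:

def smiles2adjoin(smiles, explicit_hydrogens=True):
    """Sketch: SMILES -> (list of atom symbols, adjacency matrix with self-loops)."""
    mol = Chem.MolFromSmiles(smiles)
    mol = Chem.AddHs(mol) if explicit_hydrogens else Chem.RemoveHs(mol)
    num_atoms = mol.GetNumAtoms()
    atoms_list = [mol.GetAtomWithIdx(i).GetSymbol() for i in range(num_atoms)]
    adjoin_matrix = np.eye(num_atoms)                    # self-loops: each atom may attend to itself
    for bond in mol.GetBonds():
        u, v = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        adjoin_matrix[u, v] = adjoin_matrix[v, u] = 1.0  # undirected bond
    return atoms_list, adjoin_matrix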
class Inference_Dataset(object):
    def __init__(self, sml_list, max_len=100, addH=True):
        self.vocab = str2num
        self.devocab = num2str
        self.sml_list = [i for i in sml_list if len(i) < max_len]  # drop over-long SMILES
        self.addH = addH

    def get_data(self):
        self.dataset = tf.data.Dataset.from_tensor_slices((self.sml_list,))
        self.dataset = self.dataset.map(self.tf_numerical_smiles).padded_batch(64, padded_shapes=(
            tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1]), tf.TensorShape([None]))).cache().prefetch(20)
        return self.dataset

    def numerical_smiles(self, smiles):
        smiles_origin = smiles
        smiles = smiles.numpy().decode()
        atoms_list, adjoin_matrix = smiles2adjoin(smiles, explicit_hydrogens=self.addH)
        atoms_list = ['<global>'] + atoms_list                      # prepend the supernode token
        nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
        temp = np.ones((len(nums_list), len(nums_list)))            # supernode row/column: connected to every atom
        temp[1:, 1:] = adjoin_matrix
        adjoin_matrix = (1 - temp) * (-1e9)                         # additive mask: 0 on edges, -1e9 on non-edges
        x = np.array(nums_list).astype('int64')
        return x, adjoin_matrix, [smiles], atoms_list

    def tf_numerical_smiles(self, smiles):
        x, adjoin_matrix, smiles, atom_list = tf.py_function(
            self.numerical_smiles, [smiles], [tf.int64, tf.float32, tf.string, tf.string])
        x.set_shape([None])
        adjoin_matrix.set_shape([None, None])
        smiles.set_shape([1])
        atom_list.set_shape([None])
        return x, adjoin_matrix, smiles, atom_list
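To make the mask construction concrete, here is a toy run with a hypothetical 3-atom graph in which atoms 0 and 2 are not bonded:

adj = np.array([[1., 1., 0.],
                [1., 1., 1.],
                [0., 1., 1.]])    # self-loops on the diagonal; atoms 0 and 2 not bonded
temp = np.ones((4, 4))            # index 0 is the supernode, connected to everything
temp[1:, 1:] = adj
print((1 - temp) * (-1e9))
# row/column 0 stay 0 (the supernode attends everywhere); entries (1, 3) and (3, 1)
# become -1e9, which zeroes that attention edge after the softmax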
x, adjoin_matrix, smiles, atom_list = next(iter(inference_dataset.take(1)))
seq = tf.cast(tf.math.equal(x, 0), tf.float32)   # 1.0 at '<pad>' positions (token id 0), 0.0 elsewhere
mask = seq[:, tf.newaxis, tf.newaxis, :]         # broadcastable to (batch, num_heads, seq_len_q, seq_len_k)
print(x, adjoin_matrix, smiles, atom_list, mask)
tf.Tensor(
[[16 2 2 2 2 2 4 2 2 2 2 4 2 2 4 4 2 7 2 7 1 1 1 1
1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 2 3 2 2 2 2 4 3 2 2 2 2 2 4 2 3 4 2 2 2 2 2 2
5 5 5 2 2 2 2 2 4 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0]
[16 2 2 4 3 2 2 2 2 4 2 2 1 1 1 1 1 1 1 1 1 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 2 2 4 2 2 2 2 2 2 3 4 4 2 2 2 2 4 2 2 2 2 2 2
4 2 4 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 2 2 2 2 2 2 3 2 6 4 2 2 2 2 2 2 2 2 2 2 2 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 2 2 4 3 2 3 3 2 6 3 4 4 6 1 1 1 1 1 1 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 2 2 4 2 2 2 2 6 4 4 3 2 4 3 2 2 2 2 2 2 2 2 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 3 2 2 2 7 2 2 2 2 2 2 3 2 3 2 2 2 2 2 2 1 1 1
1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 2 3 2 4 2 2 2 2 2 2 2 2 7 2 3 2 2 2 2 2 4 2 2
2 2 2 5 6 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]], shape=(9, 64), dtype=int64)
tf.Tensor(
[[[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
...
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]]
[[-0.e+00 -0.e+00 -0.e+00 ... -0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... -1.e+09 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... -1.e+09 0.e+00 0.e+00]
...
[-0.e+00 -1.e+09 -1.e+09 ... -0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]]
[[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
...
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]]
...
[[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
...
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]]
[[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
...
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]]
[[-0.e+00 -0.e+00 -0.e+00 ... -0.e+00 -0.e+00 -0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... -1.e+09 -1.e+09 -1.e+09]
[-0.e+00 -0.e+00 -0.e+00 ... -1.e+09 -1.e+09 -1.e+09]
...
[-0.e+00 -1.e+09 -1.e+09 ... -0.e+00 -1.e+09 -1.e+09]
[-0.e+00 -1.e+09 -1.e+09 ... -1.e+09 -0.e+00 -1.e+09]
[-0.e+00 -1.e+09 -1.e+09 ... -1.e+09 -1.e+09 -0.e+00]]], shape=(9, 64, 64), dtype=float32)
tf.Tensor(
[[b'C=C(CC)C(=O)c1ccc(OCC(=O)[O-])c(Cl)c1Cl']
[b'CN(Cc1c(C(=O)NC2CCCC(O)C2)noc1-c1ccc(C(F)(F)F)cc1)C1CCOC1']
[b'CC(=O)Nc1ccc(O)cc1']
[b'CC(=O)CC(c1ccc([N+](=O)[O-])cc1)c1c(O)c2ccccc2oc1=O']
[b'Cc1cccc(NC(=S)Oc2ccc3ccccc3c2)c1']
[b'CC(=O)Nc1nnc(S(N)(=O)=O)s1']
[b'CC(=O)c1ccc(S(=O)(=O)NC(=O)NC2CCCCC2)cc1']
[b'N#Cc1c(Cl)c2ccccc2n2c1nc1ccccc12']
[b'CN(C(=O)C(Cc1ccccc1Cl)C[NH+]1CCC2(CC1)OCCc1cc(F)sc12)C1CC1']], shape=(9, 1), dtype=string)
tf.Tensor(
[[b'<global>' b'C' b'C' b'C' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'O' b'C'
b'C' b'O' b'O' b'C' b'Cl' b'C' b'Cl' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'']
[b'<global>' b'C' b'N' b'C' b'C' b'C' b'C' b'O' b'N' b'C' b'C' b'C' b'C'
b'C' b'O' b'C' b'N' b'O' b'C' b'C' b'C' b'C' b'C' b'C' b'F' b'F' b'F'
b'C' b'C' b'C' b'C' b'C' b'O' b'C' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'' b'']
[b'<global>' b'C' b'C' b'O' b'N' b'C' b'C' b'C' b'C' b'O' b'C' b'C' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'']
[b'<global>' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'C' b'C' b'N' b'O' b'O'
b'C' b'C' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'C' b'C' b'O' b'C' b'O'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'']
[b'<global>' b'C' b'C' b'C' b'C' b'C' b'C' b'N' b'C' b'S' b'O' b'C' b'C'
b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'']
[b'<global>' b'C' b'C' b'O' b'N' b'C' b'N' b'N' b'C' b'S' b'N' b'O' b'O'
b'S' b'H' b'H' b'H' b'H' b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'']
[b'<global>' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'S' b'O' b'O' b'N' b'C'
b'O' b'N' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'']
[b'<global>' b'N' b'C' b'C' b'C' b'Cl' b'C' b'C' b'C' b'C' b'C' b'C'
b'N' b'C' b'N' b'C' b'C' b'C' b'C' b'C' b'C' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'']
[b'<global>' b'C' b'N' b'C' b'O' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'C'
b'Cl' b'C' b'N' b'C' b'C' b'C' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'C'
b'F' b'S' b'C' b'C' b'C' b'C' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H']], shape=(9, 64), dtype=string)
tf.Tensor(
[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]], shape=(9, 1, 1, 64), dtype=float32)
model = PredictModel_test(num_layers=num_layers, d_model=d_model, dff=dff, num_heads=num_heads,
                          vocab_size=vocab_size, dense_dropout=0.15)
pred = model(x, mask=mask, training=True, adjoin_matrix=adjoin_matrix)  # one forward pass builds the variables
model.load_weights('regression_weights/logD.h5')                        # so the fine-tuned logD weights can be restored
class PredictModel_test(tf.keras.Model):
def __init__(self,num_layers = 6,d_model = 256,dff = 512,num_heads = 8,vocab_size =17,dropout_rate = 0.1,dense_dropout=0.5):
super(PredictModel_test, self).__init__()
self.encoder = Encoder_test(num_layers=num_layers,d_model=d_model,
num_heads=num_heads,dff=dff,input_vocab_size=vocab_size,maximum_position_encoding=200,rate=dropout_rate)
self.fc1 = tf.keras.layers.Dense(256, activation=tf.keras.layers.LeakyReLU(0.1))
self.dropout = tf.keras.layers.Dropout(dense_dropout)
self.fc2 = tf.keras.layers.Dense(1)
def call(self,x,adjoin_matrix,mask,training=False):
x,att,xs = self.encoder(x,training=training,mask=mask,adjoin_matrix=adjoin_matrix)
        x = x[:, 0, :]  # readout: the supernode (position 0) embedding represents the whole molecule
x = self.fc1(x)
x = self.dropout(x, training=training)
x = self.fc2(x)
return x,att,xs
class Encoder_test(tf.keras.Model):
def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
maximum_position_encoding, rate=0.1):
super(Encoder_test, self).__init__()
self.d_model = d_model
self.num_layers = num_layers
self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
# self.pos_encoding = positional_encoding(maximum_position_encoding,
# self.d_model)
self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
for _ in range(num_layers)]
self.dropout = tf.keras.layers.Dropout(rate)
def call(self, x, training, mask,adjoin_matrix):
seq_len = tf.shape(x)[1]
        adjoin_matrix = adjoin_matrix[:, tf.newaxis, :, :]  # add a heads axis so the mask broadcasts over all heads
# adding embedding and position encoding.
x = self.embedding(x) # (batch_size, input_seq_len, d_model)
x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
# x += self.pos_encoding[:, :seq_len, :]
x = self.dropout(x, training=training)
attention_weights_list = []
xs = []
for i in range(self.num_layers):
x,attention_weights = self.enc_layers[i](x, training, mask,adjoin_matrix)
attention_weights_list.append(attention_weights)
xs.append(x)
return x,attention_weights_list,xs
x, adjoin_matrix, smiles, atom_list = next(iter(inference_dataset.take(1)))
seq = tf.cast(tf.math.equal(x, 0), tf.float32)
mask = seq[:, tf.newaxis, tf.newaxis, :]
# note: training=True keeps dropout active at inference; training=False would give deterministic outputs
x, atts, xs = model(x, mask=mask, training=True, adjoin_matrix=adjoin_matrix)
np.asarray(x).shape,np.asarray(atts).shape,np.asarray(xs).shape
((9, 1), (6, 9, 8, 64, 64), (6, 9, 64, 256))
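Because atts is a plain Python list of six per-layer tensors, np.asarray stacks it into (num_layers, batch, num_heads, seq_len, seq_len). A couple of illustrative slices under that layout (variable names are mine):

atts_arr = np.asarray(atts)                # (6, 9, 8, 64, 64)
layer0_mol0 = atts_arr[0, 0]               # (8, 64, 64): all heads for molecule 0 in layer 0
supernode_rows = atts_arr[:, 0, :, 0, :]   # (6, 8, 64): molecule 0's supernode attention, per layer and head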
class EncoderLayer(tf.keras.layers.Layer):
def __init__(self, d_model, num_heads, dff, rate=0.1):
super(EncoderLayer, self).__init__()
self.mha = MultiHeadAttention(d_model, num_heads)
self.ffn = point_wise_feed_forward_network(d_model, dff)
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
self.dropout1 = tf.keras.layers.Dropout(rate)
self.dropout2 = tf.keras.layers.Dropout(rate)
def call(self, x, training, mask,adjoin_matrix):
attn_output, attention_weights = self.mha(x, x, x, mask,adjoin_matrix) # (batch_size, input_seq_len, d_model)
attn_output = self.dropout1(attn_output, training=training)
out1 = self.layernorm1(x + attn_output) # (batch_size, input_seq_len, d_model)
ffn_output = self.ffn(out1) # (batch_size, input_seq_len, d_model)
ffn_output = self.dropout2(ffn_output, training=training)
out2 = self.layernorm2(out1 + ffn_output) # (batch_size, input_seq_len, d_model)
return out2,attention_weights
class MultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.d_model = d_model
assert d_model % self.num_heads == 0
self.depth = d_model // self.num_heads
self.wq = tf.keras.layers.Dense(d_model)
self.wk = tf.keras.layers.Dense(d_model)
self.wv = tf.keras.layers.Dense(d_model)
self.dense = tf.keras.layers.Dense(d_model)
def split_heads(self, x, batch_size):
"""Split the last dimension into (num_heads, depth).
Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
"""
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, v, k, q, mask,adjoin_matrix):
batch_size = tf.shape(q)[0]
q = self.wq(q) # (batch_size, seq_len, d_model)
k = self.wk(k) # (batch_size, seq_len, d_model)
v = self.wv(v) # (batch_size, seq_len, d_model)
q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth)
k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth)
v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth)
# scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
# attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
scaled_attention, attention_weights = scaled_dot_product_attention(
q, k, v, mask,adjoin_matrix)
scaled_attention = tf.transpose(scaled_attention,
perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth)
concat_attention = tf.reshape(scaled_attention,
(batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model)
output = self.dense(concat_attention) # (batch_size, seq_len_q, d_model)
return output, attention_weights
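scaled_dot_product_attention is called above but never shown in the post. A sketch consistent with the TensorFlow transformer tutorial this code builds on, with adjoin_matrix added to the logits alongside the padding mask (assumed, not verified against the repo's version):

def scaled_dot_product_attention(q, k, v, mask, adjoin_matrix):
    matmul_qk = tf.matmul(q, k, transpose_b=True)         # (..., seq_len_q, seq_len_k)
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)          # hide padding positions
    if adjoin_matrix is not None:
        scaled_attention_logits += adjoin_matrix          # -1e9 on non-adjacent pairs, 0 on edges
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # rows sum to 1
    output = tf.matmul(attention_weights, v)              # (..., seq_len_q, depth)
    return output, attention_weights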
i = 0
print(x)  # x now holds the model's logD predictions, shape (9, 1)
smiles_plot = smiles[i].numpy().tolist()[0].decode()
mol = Chem.MolFromSmiles(smiles_plot)
num_atoms = mol.GetNumAtoms()  # heavy atoms only: MolFromSmiles gives a molecule without explicit hydrogens
attentions_plot = tf.concat([att[i:(i+1), :, :num_atoms+1, :num_atoms+1] for att in atts], axis=0)
plot_weights(smiles_plot, attentions_plot)
x is the predicted logD, and smiles_plot is the i-th SMILES string.
atts has shape (6, 9, 8, 64, 64): each att is the attention-weight tensor produced by one EncoderLayer. att[i:(i+1), :, :num_atoms+1, :num_atoms+1] selects the attention weights of the i-th molecule and truncates at num_atoms+1, which drops the hydrogen and padding positions (AddHs appends hydrogens after the heavy atoms, so the first num_atoms+1 positions are the supernode plus the heavy atoms). For molecule 0 this keeps the attention of all 8 heads; concatenating the 6 per-layer slices gives a final shape of (6, 8, 20, 20), where the 20 positions are the 19 heavy atoms plus the supernode.
The output:
tf.Tensor(
[[ 0.00611259]
[ 0.10352743]
[-0.09410169]
[-0.01148836]
[ 0.29321525]
[-0.2280695 ]
[ 0.03325798]
[ 0.24569517]
[ 0.2692453 ]], shape=(9, 1), dtype=float32)
[[0.02703399 0.03738527 0.0341253 0.04261673 0.03024014 0.046207
0.04024169 0.03412233]
[0.00220853 0.01803984 0.01086219 0.10811627 0.0314402 0.20126604
0.03202052 0.01011109]
[0.01053316 0.02646094 0.02683227 0.05374964 0.00403768 0.01397272
0.03217852 0.00420805]
[0.01578051 0.00886862 0.00608971 0.17516439 0.0039229 0.01038222
0.0165774 0.04300734]
[0.02164022 0.05754587 0.01020102 0.01972281 0.04598003 0.0355869
0.0741061 0.03371757]
[0.02877988 0.02249552 0.03075971 0.04478414 0.04234883 0.03901222
0.05866967 0.02472056]]
[['O13', '0.07'], ['C12', '0.05'], ['O14', '0.05'], ['C11', '0.05'], ['O5', '0.05'], ['O10', '0.04'], ['C3', '0.03'], ['C2', '0.03'], ['C9', '0.03'], ['Cl16', '0.03'], ['C0', '0.02'], ['C8', '0.02'], ['C4', '0.02'], ['C1', '0.02'], ['C7', '0.02'], ['Cl18', '0.02'], ['C15', '0.02'], ['C6', '0.02'], ['C17', '0.02']]
def plot_weights(smiles, attentions_plot):
    mol = Chem.MolFromSmiles(smiles)
    mol = Chem.RemoveHs(mol)
    num_atoms = mol.GetNumAtoms()
    atoms = []
    for i in range(num_atoms):
        atom = mol.GetAtomWithIdx(i)
        atoms.append(atom.GetSymbol() + str(i))  # e.g. 'C0', 'O13'
    # attentions_plot: (num_layers, num_heads, num_atoms+1, num_atoms+1);
    # average the supernode's attention row over the last 3 layers and all heads,
    # then drop its self-attention at index 0
    # (alternative: attentions_plot[:, :, 0, :] would average over all 6 layers)
    att = tf.reduce_mean(tf.reduce_mean(attentions_plot[3:, :, 0, :], axis=0), axis=0)[1:].numpy()
    print(attentions_plot[:, :, 0, 0].numpy())   # supernode self-attention per layer/head, shape (6, 8)
    indices = (-att).argsort()                   # atom indices sorted by descending attention
    highlight = indices.tolist()
    print([[atoms[indices[i]], ('%.2f' % att[indices[i]])] for i in range(len(indices))])
    drawer = rdMolDraw2D.MolDraw2DSVG(800, 600)
    drawer.drawOptions().updateAtomPalette({k: (0, 0, 0) for k in DrawingOptions.elemDict.keys()})
    # linear color ramp: top-ranked atom -> red (1, 0, 0), bottom-ranked -> white (1, 1, 1)
    colors = {}
    for i, h in enumerate(highlight):
        t = (att[h] - att[highlight[-1]]) / (att[highlight[0]] - att[highlight[-1]])
        colors[h] = (1, 1 - t, 1 - t)
    drawer.DrawMolecule(mol, highlightAtoms=highlight, highlightAtomColors=colors, highlightBonds=[])
    drawer.FinishDrawing()
    svg = drawer.GetDrawingText().replace('svg:', '')
    display(SVG(svg))
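The same driver as above can be re-pointed at any other molecule in the batch, reusing the tensors already in scope, e.g.:

i = 4  # any index into the batch of 9
smiles_plot = smiles[i].numpy().tolist()[0].decode()
mol = Chem.MolFromSmiles(smiles_plot)
num_atoms = mol.GetNumAtoms()
attentions_plot = tf.concat([att[i:(i+1), :, :num_atoms+1, :num_atoms+1] for att in atts], axis=0)
plot_weights(smiles_plot, attentions_plot)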
Stepping through the key line of plot_weights for molecule 0, the layer- and head-averaged attention the supernode pays to the 19 heavy atoms is:
tf.reduce_mean(tf.reduce_mean(attentions_plot[3:,:,0,:],axis=0),axis=0)[1:]
<tf.Tensor: shape=(19,), dtype=float32, numpy=
array([0.02398169, 0.02298377, 0.03124303, 0.03180713, 0.0232711 ,
0.04853497, 0.0192481 , 0.02276288, 0.02380962, 0.02621273,
0.03749018, 0.04908869, 0.05493892, 0.06948597, 0.04930567,
0.02019732, 0.0252701 , 0.01770476, 0.02110688], dtype=float32)>
indices = (-att).argsort()
highlight = indices.tolist()
highlight
[13, 12, 14, 11, 5, 10, 3, 2, 9, 16, 0, 8, 4, 1, 7, 18, 15, 6, 17]
print([[atoms[indices[i]],('%.2f'%att[indices[i]])] for i in range(len(indices))])