教程来源TorchDrug开源

在许多药物发现任务中,收集标记数据在时间和金钱上都是昂贵的。作为一种解决方案,引入了自监督预训练来从大量未标记的数据中学习分子表示。
在本教程中,我们将演示如何在分子上预训练图神经网络,以及如何在下游任务上微调模型。
预训练是在图神经网络中进行图级属性预测的一种有效的迁移学习方法。在这里,我们专注于通过不同的自我监督策略预训练GNNs。这些方法通常基于分子的结构信息构建无监督损失函数。
为了说明原因,我们在本教程中只使用ClinTox数据集,它比标准的预训练数据集要小得多。
InfoGraph (IG)建议最大化图级和节点级表示之间的互信息。它通过区分节点图对是来自单个图还是来自两个不同的图来学习模型。下图展示了InfoGraph的高级概念。

我们使用GIN作为我们的图形表示模型,并用InfoGraph包装它。
import torch
from torch import nn
from torch.utils import data as torch_data
from torchdrug import core, datasets, tasks, models
dataset = datasets.ClinTox("~/molecule-datasets/", atom_feature="pretrain",
bond_feature="pretrain")
gin_model = models.GIN(input_dim=dataset.node_feature_dim,
hidden_dims=[300, 300, 300, 300, 300],
edge_input_dim=dataset.edge_feature_dim,
batch_norm=True, readout="mean")
model = models.InfoGraph(gin_model, separate_model=False)
task = tasks.Unsupervised(model)
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, dataset, None, None, optimizer, gpus=[0], batch_size=256)
solver.train(num_epoch=100)
solver.save("clintox_gin_infograph.pth")
经过训练,表示的相互信息可能接近
average graph-node mutual information: 1.30658
属性masking的目的是通过学习分布在图结构上的节点/边属性的规律来获取领域知识。高层次的思想是通过随机掩盖的节点特征来预测分子图中的原子类型。

同样,我们使用GIN作为我们的图表示模型。
import torch
from torch import nn, optim
from torch.utils import data as torch_data
from torchdrug import core, datasets, tasks, models
dataset = datasets.ClinTox("~/molecule-datasets/", atom_feature="pretrain",
bond_feature="pretrain")
model = models.GIN(input_dim=dataset.node_feature_dim,
hidden_dims=[300, 300, 300, 300, 300],
edge_input_dim=dataset.edge_feature_dim,
batch_norm=True, readout="mean")
task = tasks.AttributeMasking(model, mask_rate=0.15)
optimizer = optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, dataset, None, None, optimizer, gpus=[0], batch_size=256)
solver.train(num_epoch=100)
solver.save("clintox_gin_attributemasking.pth")
通常,训练精度和交叉熵看起来如下所示。
average accuracy: 0.920366
average cross entropy: 0.22998
除了InfoGraph和Attribute Masking, gnn的预训练还有一些其他的策略。有关详细信息,请参阅下面的文档。
InfoGraph, AttributeMasking, EdgePrediction, ContextPrediction
当GNN预训练完成后,我们可以在下游任务上对预训练的GNN模型进行微调。这里我们使用BACE数据集进行说明,该数据集包含1513个具有结合亲和力的人β-分泌酶1(BACE-1)抑制剂分子。
首先,我们下载BACE数据集,并将其分为训练集、验证集和测试集。注意,我们需要将数据集中的节点和边缘特征设置为预训练,以使其与预训练的模型兼容。
dataset = datasets.BACE("~/molecule-datasets/",
atom_feature="pretrain", bond_feature="pretrain")
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = data.ordered_scaffold_split(dataset, lengths)
然后,我们定义与预训练阶段相同的模型,并为我们的下游任务设置优化器和求解器。这里唯一的区别是我们使用PropertyPrediction任务来支持监督学习。
model = models.GIN(input_dim=dataset.node_feature_dim,
hidden_dims=[300, 300, 300, 300, 300],
edge_input_dim=dataset.edge_feature_dim,
batch_norm=True, readout="mean")
task = tasks.PropertyPrediction(model, task=dataset.tasks,
criterion="bce", metric=("auprc", "auroc"))
optimizer = optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
gpus=[0], batch_size=256)
现在我们可以加载预训练的模型,并在下游数据集上对其进行微调。
checkpoint = torch.load("clintox_gin_attributemasking.pth")["model"]
task.load_state_dict(checkpoint, strict=False)
solver.train(num_epoch=100)
solver.evaluate("valid")
一旦模型训练好了,我们就在验证集上评估它。结果可能类似于下面的情况。
auprc [Class]: 0.921956
auroc [Class]: 0.663004