• TorchDrug教程--预训练的分子表示


    TorchDrug教程–预训练的分子表示

    教程来源TorchDrug开源

    目录

    在许多药物发现任务中,收集标记数据在时间和金钱上都是昂贵的。作为一种解决方案,引入了自监督预训练来从大量未标记的数据中学习分子表示。

    在本教程中,我们将演示如何在分子上预训练图神经网络,以及如何在下游任务上微调模型。

    自我监督预训练

    预训练是在图神经网络中进行图级属性预测的一种有效的迁移学习方法。在这里,我们专注于通过不同的自我监督策略预训练GNNs。这些方法通常基于分子的结构信息构建无监督损失函数。

    为了说明原因,我们在本教程中只使用ClinTox数据集,它比标准的预训练数据集要小得多。

    Infograph

    InfoGraph (IG)建议最大化图级和节点级表示之间的互信息。它通过区分节点图对是来自单个图还是来自两个不同的图来学习模型。下图展示了InfoGraph的高级概念。

    我们使用GIN作为我们的图形表示模型,并用InfoGraph包装它。

    import torch
    from torch import nn
    from torch.utils import data as torch_data
    
    from torchdrug import core, datasets, tasks, models
    
    dataset = datasets.ClinTox("~/molecule-datasets/", atom_feature="pretrain",
                               bond_feature="pretrain")
    
    gin_model = models.GIN(input_dim=dataset.node_feature_dim,
                           hidden_dims=[300, 300, 300, 300, 300],
                           edge_input_dim=dataset.edge_feature_dim,
                           batch_norm=True, readout="mean")
    model = models.InfoGraph(gin_model, separate_model=False)
    
    task = tasks.Unsupervised(model)
    optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, dataset, None, None, optimizer, gpus=[0], batch_size=256)
    
    solver.train(num_epoch=100)
    solver.save("clintox_gin_infograph.pth")
    

    经过训练,表示的相互信息可能接近

    average graph-node mutual information: 1.30658
    

    Attribute Masking

    属性masking的目的是通过学习分布在图结构上的节点/边属性的规律来获取领域知识。高层次的思想是通过随机掩盖的节点特征来预测分子图中的原子类型。

    同样,我们使用GIN作为我们的图表示模型。

    import torch
    from torch import nn, optim
    from torch.utils import data as torch_data
    
    from torchdrug import core, datasets, tasks, models
    
    dataset = datasets.ClinTox("~/molecule-datasets/", atom_feature="pretrain",
                               bond_feature="pretrain")
    
    model = models.GIN(input_dim=dataset.node_feature_dim,
                       hidden_dims=[300, 300, 300, 300, 300],
                       edge_input_dim=dataset.edge_feature_dim,
                       batch_norm=True, readout="mean")
    task = tasks.AttributeMasking(model, mask_rate=0.15)
    
    optimizer = optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, dataset, None, None, optimizer, gpus=[0], batch_size=256)
    
    solver.train(num_epoch=100)
    solver.save("clintox_gin_attributemasking.pth")
    

    通常,训练精度和交叉熵看起来如下所示。

    average accuracy: 0.920366
    average cross entropy: 0.22998
    

    除了InfoGraph和Attribute Masking, gnn的预训练还有一些其他的策略。有关详细信息,请参阅下面的文档。

    InfoGraph, AttributeMasking, EdgePrediction, ContextPrediction

    关于标记数据集的Finetune

    当GNN预训练完成后,我们可以在下游任务上对预训练的GNN模型进行微调。这里我们使用BACE数据集进行说明,该数据集包含1513个具有结合亲和力的人β-分泌酶1(BACE-1)抑制剂分子。

    首先,我们下载BACE数据集,并将其分为训练集、验证集和测试集。注意,我们需要将数据集中的节点和边缘特征设置为预训练,以使其与预训练的模型兼容。

    dataset = datasets.BACE("~/molecule-datasets/",
                            atom_feature="pretrain", bond_feature="pretrain")
    lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
    lengths += [len(dataset) - sum(lengths)]
    train_set, valid_set, test_set = data.ordered_scaffold_split(dataset, lengths)
    

    然后,我们定义与预训练阶段相同的模型,并为我们的下游任务设置优化器和求解器。这里唯一的区别是我们使用PropertyPrediction任务来支持监督学习。

    model = models.GIN(input_dim=dataset.node_feature_dim,
                    hidden_dims=[300, 300, 300, 300, 300],
                    edge_input_dim=dataset.edge_feature_dim,
                    batch_norm=True, readout="mean")
    task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                    criterion="bce", metric=("auprc", "auroc"))
    
    optimizer = optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                         gpus=[0], batch_size=256)
    

    现在我们可以加载预训练的模型,并在下游数据集上对其进行微调。

    checkpoint = torch.load("clintox_gin_attributemasking.pth")["model"]
    task.load_state_dict(checkpoint, strict=False)
    
    solver.train(num_epoch=100)
    solver.evaluate("valid")
    

    一旦模型训练好了,我们就在验证集上评估它。结果可能类似于下面的情况。

    auprc [Class]: 0.921956
    auroc [Class]: 0.663004
    
  • 相关阅读:
    『现学现忘』Docker相关概念 — 2、云计算的服务模式
    Fliki AI:让视频创作更简单、更高效
    OB_GINS_day3
    OAK相机:自动或手动设置相机参数
    论文回顾:Playful Palette: An Interactive Parametric Color Mixer for Artists
    红帽认证笔记2
    apt安装yum
    2023-09-21 LeetCode每日一题(收集树中金币)
    app使用
    C++ Qt开发:Charts绘制各类图表详解
  • 原文地址:https://blog.csdn.net/weixin_42486623/article/details/127039059