教程来源TorchDrug开源
反合成是药物发现的一项基本任务。给定一个目标分子,反合成的目标是确定一组可以产生目标的反应物。
在这个例子中,我们将展示如何使用G2Gs框架预测逆合成。G2Gs首先识别反应中心,即产物中产生的键。根据反应中心,产物被分解成几个合成子,每个合成子被转化为一个反应物。
我们使用标准USPTO50k数据集。该数据集包含50k分子及其合成途径。首先,让我们下载并加载数据集。这可能需要一段时间。有两种模式来加载数据集。reaction模式将数据集加载为(reactants
, product
)对,用于中心识别。synthon模式将数据集作为(reactant
,synthon
)对加载,用于synthon
完成。
from torchdrug import data, datasets, utils
reaction_dataset = datasets.USPTO50k("~/molecule-datasets/",
atom_feature="center_identification",
kekulize=True)
synthon_dataset = datasets.USPTO50k("~/molecule-datasets/", as_synthon=True,
atom_feature="synthon_completion",
kekulize=True)
然后我们将数据集中的一些样本可视化。对于反应数据集,我们可以使用connected components()将反应物图和生成物图拆分为单个分子。注意USPTO50k忽略了所有非目标产品,所以右边只有一个产品。
from torchdrug.utils import plot
for i in range(2):
sample = reaction_dataset[i]
reactant, product = sample["graph"]
reactants = reactant.connected_components()[0]
products = product.connected_components()[0]
plot.reaction(reactants, products)
下面是synthon数据集中对应的示例。
for i in range(3):
sample = synthon_dataset[i]
reactant, synthon = sample["graph"]
plot.reaction([reactant], [synthon])
为了确保两个数据集使用相同的split,我们可以在调用split()
之前设置随机种子。
import torch
torch.manual_seed(1)
reaction_train, reaction_valid, reaction_test = reaction_dataset.split()
torch.manual_seed(1)
synthon_train, synthon_valid, synthon_test = synthon_dataset.split()
现在我们定义我们的模型。我们使用一个关系图卷积网络(RGCN)作为我们的表示模型,并包装它来完成中心识别任务。注意,这里也可以使用其他图表示学习模型。
from torchdrug import core, models, tasks
reaction_model = models.RGCN(input_dim=reaction_dataset.node_feature_dim,
hidden_dims=[256, 256, 256, 256, 256, 256],
num_relation=reaction_dataset.num_bond_type,
concat_hidden=True)
reaction_task = tasks.CenterIdentification(reaction_model,
feature=("graph", "atom", "bond"))
reaction_optimizer = torch.optim.Adam(reaction_task.parameters(), lr=1e-3)
reaction_solver = core.Engine(reaction_task, reaction_train, reaction_valid,
reaction_test, reaction_optimizer,
gpus=[0], batch_size=128)
reaction_solver.train(num_epoch=50)
reaction_solver.evaluate("valid")
reaction_solver.save("g2gs_reaction_model.pth")
验证集上的计算结果可能如下所示
accuracy: 0.836367
我们可以从我们的模型中展示一些预测。为了多样性,我们收集了4种不同反应类型的样品。
batch = []
reaction_set = set()
for sample in reaction_valid:
if sample["reaction"] not in reaction_set:
reaction_set.add(sample["reaction"])
batch.append(sample)
if len(batch) == 4:
break
batch = data.graph_collate(batch)
batch = utils.cuda(batch)
result = reaction_task.predict_synthon(batch)
下面的代码可视化了基本事实以及我们对样本的预测。我们用蓝色代表基本事实,红色代表错误的预测,紫色代表正确的预测。
def atoms_and_bonds(molecule, reaction_center):
is_reaction_atom = (molecule.atom_map > 0) & \
(molecule.atom_map.unsqueeze(-1) == \
reaction_center.unsqueeze(0)).any(dim=-1)
node_in, node_out = molecule.edge_list.t()[:2]
edge_map = molecule.atom_map[molecule.edge_list[:, :2]]
is_reaction_bond = (edge_map > 0).all(dim=-1) & \
(edge_map == reaction_center.unsqueeze(0)).all(dim=-1)
atoms = is_reaction_atom.nonzero().flatten().tolist()
bonds = is_reaction_bond[node_in < node_out].nonzero().flatten().tolist()
return atoms, bonds
products = batch["graph"][1]
reaction_centers = result["reaction_center"]
for i, product in enumerate(products):
true_atoms, true_bonds = atoms_and_bonds(product, product.reaction_center)
true_atoms, true_bonds = set(true_atoms), set(true_bonds)
pred_atoms, pred_bonds = atoms_and_bonds(product, reaction_centers[i])
pred_atoms, pred_bonds = set(pred_atoms), set(pred_bonds)
overlap_atoms = true_atoms.intersection(pred_atoms)
overlap_bonds = true_bonds.intersection(pred_bonds)
atoms = true_atoms.union(pred_atoms)
bonds = true_bonds.union(pred_bonds)
red = (1, 0.5, 0.5)
blue = (0.5, 0.5, 1)
purple = (1, 0.5, 1)
atom_colors = {}
bond_colors = {}
for atom in atoms:
if atom in overlap_atoms:
atom_colors[atom] = purple
elif atom in pred_atoms:
atom_colors[atom] = red
else:
atom_colors[atom] = blue
for bond in bonds:
if bond in overlap_bonds:
bond_colors[bond] = purple
elif bond in pred_bonds:
bond_colors[bond] = red
else:
bond_colors[bond] = blue
plot.highlight(product, atoms, bonds, atom_colors, bond_colors)
类似地,我们在synthon数据集上训练synthon完成模型。
synthon_model = models.RGCN(input_dim=synthon_dataset.node_feature_dim,
hidden_dims=[256, 256, 256, 256, 256, 256],
num_relation=synthon_dataset.num_bond_type,
concat_hidden=True)
synthon_task = tasks.SynthonCompletion(synthon_model, feature=("graph",))
synthon_optimizer = torch.optim.Adam(synthon_task.parameters(), lr=1e-3)
synthon_solver = core.Engine(synthon_task, synthon_train, synthon_valid,
synthon_test, synthon_optimizer,
gpus=[0], batch_size=128)
synthon_solver.train(num_epoch=10)
synthon_solver.evaluate("valid")
synthon_solver.save("g2gs_synthon_model.pth")
我们可以得到一些结果
bond accuracy: 0.983013
node in accuracy: 0.967535
node out accuracy: 0.892999
stop accuracy: 0.929348
total accuracy: 0.844374
然后,我们执行束搜索,以产生候选反应物。
batch = []
reaction_set = set()
for sample in synthon_valid:
if sample["reaction"] not in reaction_set:
reaction_set.add(sample["reaction"])
batch.append(sample)
if len(batch) == 4:
break
batch = data.graph_collate(batch)
batch = utils.cuda(batch)
reactants, synthons = batch["graph"]
reactants = reactants.ion_to_molecule()
predictions = synthon_task.predict_reactant(batch, num_beam=10, max_prediction=5)
synthon_id = -1
i = 0
titles = []
graphs = []
for prediction in predictions:
if synthon_id != prediction.synthon_id:
synthon_id = prediction.synthon_id.item()
i = 0
graphs.append(reactants[synthon_id])
titles.append("Truth %d" % synthon_id)
i += 1
graphs.append(prediction)
if reactants[synthon_id] == prediction:
titles.append("Prediction %d-%d, Correct!" % (synthon_id, i))
else:
titles.append("Prediction %d-%d" % (synthon_id, i))
# reset attributes so that pack can work properly
mols = [graph.to_molecule() for graph in graphs]
graphs = data.PackedMolecule.from_molecule(mols)
graphs.visualize(titles, save_file="uspto50k_synthon_valid.png", num_col=6)
给定训练过的模型,我们可以将它们组合成一个端点管道进行逆向合成。这是通过将两个子任务包裹在一个逆合成任务中来完成的。
注意,如果您从未声明reaction_task
和synthon_task
的求解器,那么在将它们组合到管道中之前,您需要手动调用它们的preprocess()
方法。
# reaction_task.preprocess(reaction_train, None, None)
# synthon_task.preprocess(synthon_train, None, None)
task = tasks.Retrosynthesis(reaction_task, synthon_task, center_topk=2,
num_synthon_beam=5, max_prediction=10)
管道将对来自两个子任务的预测之间的所有可能组合执行波束搜索。为了演示,我们使用一个较小的光束尺寸,并且只对验证集的子集进行评估。注意,如果我们给光束搜索更多的预算,结果会更好。
from torch.utils import data as torch_data
lengths = [len(reaction_valid) // 10,
len(reaction_valid) - len(reaction_valid) // 10]
reaction_valid_small = torch_data.random_split(reaction_valid, lengths)[0]
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, reaction_train, reaction_valid_small, reaction_test,
optimizer, gpus=[0], batch_size=32)
要加载两个子任务的参数,我们只需load_optimizer
。注意负载优化器应该设置为False
以避免冲突。
solver.load("g2gs_reaction_model.pth", load_optimizer=False)
solver.load("g2gs_synthon_model.pth", load_optimizer=False)
solver.evaluate("valid")
反合成的准确性可能接近于以下
top-1 accuracy: 0.47541
top-3 accuracy: 0.741803
top-5 accuracy: 0.827869
top-10 accuracy: 0.879098
以下是验证集中样本的前1个预测
batch = []
reaction_set = set()
for sample in reaction_valid:
if sample["reaction"] not in reaction_set:
reaction_set.add(sample["reaction"])
batch.append(sample)
if len(batch) == 4:
break
batch = data.graph_collate(batch)
batch = utils.cuda(batch)
predictions, num_prediction = task.predict(batch)
products = batch["graph"][1]
top1_index = num_prediction.cumsum(0) - num_prediction
for i in range(len(products)):
reactant = predictions[top1_index[i]].connected_components()[0]
product = products[i].connected_components()[0]
plot.reaction(reactant, product)