• torch.hub 记录


    Facebook官方博客表示,PyTorch Hub是一个简易API和工作流程,为复现研究提供了基本构建模块,包含预训练模型库。并且,PyTorch Hub还支持Colab,能与论文代码结合网站Papers With Code集成,用于更广泛的研究。发布首日已有18个模型“入驻”,获得英伟达官方力挺。而且Facebook还鼓励论文发布者把自己的模型发布到这里来,让PyTorch Hub越来越强大。

    PyTorch Hub中提供的模型也支持Colab。进入每个模型的介绍页面后,你不仅可以看到GitHub代码页的入口,甚至可以一键进入Colab运行模型Demo。为了调用各种经典机器学习模型,今后你不必重复造轮子了。刚刚,Facebook宣布推出PyTorch Hub,一个包含计算机视觉、自然语言处理领域的诸多经典模型的聚合中心, 让你调用起来更方便。有多方便?图灵奖得主Yann LeCun强烈推荐,无论是ResNet、BERT、GPT、VGG、PGAN还是MobileNet等经典模型,只需输入一行代码,就能实现一键调用。
    PyTorch Hub的使用简单到不能再简单,不需要下载模型,只用了一个torch.hub.load()就完成了对图像分类模型AlexNet的调用。

    import torch
    model = torch.hub.load('pytorch/vision', 'alexnet', pretrained=True)
    model.eval()
    
    • 1
    • 2
    • 3

    PyTorch Hub允许用户对已发布的模型执行以下操作:
    1、查询可用的模型;
    2、加载模型;
    3、查询模型中可用的方法。

    下面让我们来看看每个应用的实例。
    1、查询可用的模型
    用户可以使用torch.hub.list()这个API列出repo中所有可用的入口点。
    比如你想知道PyTorch Hub中有哪些可用的计算机视觉模型:

    torch.hub.list('pytorch/vision')
    
    ['alexnet',
    'deeplabv3_resnet101',
    'densenet121',
    ...
    'vgg16',
    'vgg16_bn',
    'vgg19',
     'vgg19_bn']
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    2、加载模型
    在上一步中能看到所有可用的计算机视觉模型,如果想调用其中的一个,也不必安装,只需一句话就能加载模型。

    model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
    
    • 1

    至于如何获得此模型的详细帮助信息,可以使用下面的API:

    print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))
    
    • 1

    如果模型的发布者后续加入错误修复和性能改进,用户也可以非常简单地获取更新,确保自己用到的是最新版本:

    model = torch.hub.load(..., force_reload=True)
    
    • 1

    对于另外一部分用户来说,稳定性更加重要,他们有时候需要调用特定分支的代码。例如pytorch_GAN_zoo的hub分支:

    model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True)
    
    • 1

    3、查看模型可用方法
    从PyTorch Hub加载模型后,你可以用dir(model)查看模型的所有可用方法。以bertForMaskedLM模型为例:

    dir(model)
    
    ['forward'
    ...
    'to'
    'state_dict',
    ]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    如果你对forward方法感兴趣,使用help(model.forward) 了解运行运行该方法所需的参数:

    help(model.forward)
    
    Help on method forward in module pytorch_pretrained_bert.modeling:
    forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None)
    
    • 1
    • 2
    • 3
    • 4
  • 相关阅读:
    MySQL常用函数(聚合函数)
    智慧政务、数字化优先与数字机器人,政务领域正在开启“政务新视界”
    基于片段的分子生成网络 (FLAG)使用方法及案例测评
    如何制作精美的图片
    java计算机毕业设计HTML5互动游戏新闻网站设计与实现源码+mysql数据库+系统+lw文档+部署
    Java面试时,该如何准备亮点?
    4、创建第一个鸿蒙应用
    gtest从一无所知到熟练使用(4)如何用gtest写单元测试
    理论STL——vector篇(小Z 著)
    Texax Instruments 处理器资料导航(TI AM64x)
  • 原文地址:https://blog.csdn.net/zkp_987/article/details/126517116