torch_vision (Part 2): the model and pre-trained weights module, torchvision.models


    A brief introduction to torchvision.models

    Introduction

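    The torchvision.models module provides many model architectures together with their pre-trained weights.
    Compared with older versions, the newer weights API (introduced in torchvision 0.13) offers the following features: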

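    1. A single model architecture can load several different sets of weights.
    2. The preprocessing transforms that match each set of weights can be retrieved. This matters because even small differences in these transforms (e.g. interpolation mode, resize/crop size) can sharply reduce accuracy or make the model unusable.
    3. Metadata is provided, including the category labels and metrics such as accuracy.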

    Taking a classification model as an example:

    from torchvision.io import read_image
    from torchvision.models import resnet50, ResNet50_Weights
    
    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
    
    # Step 1: Initialize model with the best available weights
    # ResNet50_Weights.IMAGENET1K_V1 and ResNet50_Weights.IMAGENET1K_V2 are the available versions; DEFAULT points to the best one
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()
    
    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()
    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)
    
    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")
    

    Object detection

    from torchvision.io.image import read_image
    from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
    from torchvision.utils import draw_bounding_boxes
    from torchvision.transforms.functional import to_pil_image
    
    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
    
    # Step 1: Initialize model with the best available weights
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()
    
    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()
    
    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]
    
    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)[0]
    labels = [weights.meta["categories"][i] for i in prediction["labels"]]
    box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                              labels=labels,
                              colors="red",
                              width=4, font_size=30)
    im = to_pil_image(box.detach())
    im.show()
    

    Semantic segmentation

    from torchvision.io.image import read_image
    from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
    from torchvision.transforms.functional import to_pil_image
    
    img = read_image("gallery/assets/dog1.jpg")
    
    # Step 1: Initialize model with the best available weights
    weights = FCN_ResNet50_Weights.DEFAULT
    model = fcn_resnet50(weights=weights)
    model.eval()
    
    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()
    
    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)
    
    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)["out"]
    normalized_masks = prediction.softmax(dim=1)
    class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
    mask = normalized_masks[0, class_to_idx["dog"]]
    to_pil_image(mask).show()
    

    Available models and weights

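    torchvision.models contains many models and pre-trained weights covering a range of tasks: image classification, semantic segmentation, object detection, keypoint detection, video classification, optical flow estimation, and more. From the official documentation: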

    The torchvision.models subpackage contains definitions of models for addressing different tasks, 
    including: image classification, pixelwise semantic segmentation, object detection, instance segmentation, 
    person keypoint detection, video classification, and optical flow.
    

    For the full list of models available for each task in torchvision.models, see the documentation page

    MODELS AND PRE-TRAINED WEIGHTS

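    In practice the coverage is limited: tasks such as super-resolution, generative modeling, image restoration/inpainting, and image enhancement have no corresponding models in torchvision.models.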

  • Original article: https://blog.csdn.net/tywwwww/article/details/127414132