• 【深度学习】PyTorch 快速得到 MobileNetV2 的 pth 和 ONNX 模型


    在 Linux 下执行以下程序：

    """Export a pretrained MobileNetV2 to PyTorch (.pth) and ONNX (.onnx) files.

    The script downloads an example image and the ImageNet class labels,
    classifies the image with the PyTorch model, saves the weights, exports the
    model to ONNX, and then re-runs the same input through onnxruntime to check
    that the exported model reproduces the PyTorch logits.
    """
    import os
    import urllib.request

    import numpy as np
    import torch
    import torch.onnx
    from torchvision import transforms, models
    from PIL import Image


    def download(url, filename):
        """Download ``url`` to ``filename`` unless that file already exists.

        A failed download raises (e.g. ``urllib.error.URLError``) instead of
        being silently ignored the way ``os.system("wget ...")`` failures are.
        Data is written to a ``.part`` file first and renamed only when complete,
        so an interrupted download is not mistaken for a finished one next run.

        Returns:
            ``filename``, for convenient chaining.
        """
        if not os.path.exists(filename):
            tmp_path = filename + ".part"
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, filename)
        return filename


    def print_top5(probabilities, categories):
        """Print the 5 highest-probability class names for a 1-D probability tensor."""
        top5_prob, top5_catid = torch.topk(probabilities, 5)
        for prob, catid in zip(top5_prob, top5_catid):
            print(categories[int(catid)], prob.item())


    # Load MobileNetV2 with ImageNet weights. torchvision >= 0.13 deprecates
    # `pretrained=True` in favour of the `weights` enum; fall back on older versions.
    if hasattr(models, "MobileNet_V2_Weights"):
        model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
    else:
        model = models.mobilenet_v2(pretrained=True)
    model.eval()

    # Download an example image from the PyTorch hub repository
    filename = download("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")

    # Preprocess the input image: resize, 224x224 center crop, tensor, normalize
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # convert("RGB") so grayscale / RGBA / palette images still give 3 channels,
    # matching the 3-element mean/std used by Normalize.
    with Image.open(filename) as img:
        input_image = img.convert("RGB")
    input_tensor = preprocess(input_image).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)

    # Perform inference on CPU
    with torch.no_grad():
        output = model(input_tensor)

    # Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
    print(output[0])

    # The output has unnormalized scores. To get probabilities, run a softmax on it.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    print(probabilities)

    # Download and read the ImageNet class labels (one class name per line)
    labels_file = download(
        "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
        "imagenet_classes.txt",
    )
    with open(labels_file, "r", encoding="utf-8") as f:
        categories = [line.strip() for line in f]

    # Show top categories for the PyTorch model
    print_top5(probabilities, categories)

    # Save the PyTorch weights (state_dict only: rebuild the architecture to load it)
    pth_path = "mobilenet_v2.pth"
    torch.save(model.state_dict(), pth_path)

    # Convert the PyTorch model to ONNX with named input and output tensors.
    # The export traces the model with the fixed (1, 3, 224, 224) dummy input.
    dummy_input = torch.randn(1, 3, 224, 224)
    onnx_path = "mobilenet_v2.onnx"
    input_names = ['input']
    output_names = ['output']
    torch.onnx.export(model, dummy_input, onnx_path, input_names=input_names, output_names=output_names)

    print(f"PyTorch model saved to '{pth_path}'")
    print(f"ONNX model saved to '{onnx_path}'")

    # Verify the exported model. These imports stay after the save/export steps
    # so that a missing onnx/onnxruntime install only breaks the verification.
    import onnx
    import onnxruntime

    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)  # raises if the exported graph is invalid

    # Explicit providers: onnxruntime >= 1.9 raises ValueError on builds with GPU
    # providers when `providers` is omitted. CPU matches the PyTorch run above.
    onnx_session = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    # onnxruntime consumes numpy arrays keyed by the input name chosen at export
    input_tensor_onnx = input_tensor.numpy()
    onnx_output = onnx_session.run(output_names, {input_names[0]: input_tensor_onnx})

    # The exported model should reproduce the PyTorch logits up to float rounding
    np.testing.assert_allclose(output.numpy(), onnx_output[0], rtol=1e-3, atol=1e-5)
    print("\nONNX Runtime output matches PyTorch")

    # Show top categories for the ONNX model (first and only image in the batch)
    onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0][0]), dim=0)
    print("\nTop categories for ONNX:")
    print_top5(onnx_probabilities, categories)
    
    
    

    运行后得到如下输出（原文此处为运行结果截图，图片已缺失）。

    使用本地 pth 权重进行推理：

    """Classify dog.jpg with MobileNetV2 weights loaded from a local mobilenet_v2.pth.

    Expects mobilenet_v2.pth, dog.jpg and imagenet_classes.txt in the current
    working directory (the files written by the export step).
    """
    import torch
    from torchvision import transforms, models
    from PIL import Image

    # Build the bare MobileNetV2 architecture, then load the saved state_dict.
    # map_location remaps any saved device tensors onto the CPU.
    model = models.mobilenet_v2()
    model.load_state_dict(torch.load("mobilenet_v2.pth", map_location=torch.device('cpu')))
    model.eval()

    # Preprocess the input image: resize, 224x224 center crop, tensor, normalize
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load the example image. convert("RGB") so grayscale / RGBA / palette images
    # still yield the 3 channels that Normalize's mean/std expect.
    with Image.open("dog.jpg") as img:
        input_image = img.convert("RGB")
    input_tensor = preprocess(input_image).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)

    # Perform inference on CPU
    with torch.no_grad():
        output = model(input_tensor)

    # output[0] holds the unnormalized class scores; softmax turns them into probabilities
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # Load ImageNet labels (one class name per line)
    with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
        categories = [line.strip() for line in f]

    # Show top categories per image
    top5_prob, top5_catid = torch.topk(probabilities, 5)
    for prob, catid in zip(top5_prob, top5_catid):
        print(categories[int(catid)], prob.item())
    
    

    使用 onnxruntime 加载 ONNX 模型进行推理：

    """Classify dog.jpg with the exported mobilenet_v2.onnx using onnxruntime.

    Expects mobilenet_v2.onnx, dog.jpg and imagenet_classes.txt in the current
    working directory.
    """
    import torch
    import onnxruntime
    from torchvision import transforms
    from PIL import Image

    # Load the ONNX model. Providers are passed explicitly: onnxruntime >= 1.9
    # raises ValueError on builds with GPU providers when they are omitted, and
    # the CPU provider is available in every build.
    onnx_path = "mobilenet_v2.onnx"
    onnx_session = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    # Read the input name from the model rather than hard-coding 'input', so the
    # script also works for ONNX files exported with a different input name.
    input_name = onnx_session.get_inputs()[0].name

    # Preprocess the input image: resize, 224x224 center crop, tensor, normalize
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load the example image. convert("RGB") so grayscale / RGBA / palette images
    # still yield the 3 channels that Normalize's mean/std expect.
    with Image.open("dog.jpg") as img:
        input_image = img.convert("RGB")
    input_tensor = preprocess(input_image).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)

    # onnxruntime consumes numpy arrays, not torch tensors
    input_tensor_onnx = input_tensor.numpy()

    # Perform inference on ONNX; `None` requests every model output.
    # onnx_output[0] is the first output, indexed below as (batch, class).
    onnx_output = onnx_session.run(None, {input_name: input_tensor_onnx})
    onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0]), dim=1)

    # Load ImageNet labels (one class name per line)
    with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
        categories = [line.strip() for line in f]

    # Show top categories for the first (only) image in the batch
    onnx_top5_prob, onnx_top5_catid = torch.topk(onnx_probabilities, 5)
    print("Top categories for ONNX:")
    for prob, catid in zip(onnx_top5_prob[0], onnx_top5_catid[0]):
        print(categories[int(catid)], prob.item())
    
    
  • 相关阅读:
    10. 一文快速学懂常用工具——网络工具(上)
    Prompt 驱动架构设计:探索复杂 AIGC 应用的设计之道?
    轻量级实时跟踪算法NanoTrack在瑞芯微RK3588上的部署以及使用
    WebAPI中使用WebService后发布注意要点
    昇腾CANN 7.0 黑科技:大模型训练性能优化之道
    金山办公推出WPS AI,开放应用于智能文档
    lazarus:数据集快速导出为excel、csv、sql及其他多种格式
    【微机原理笔记】第 3 章 - 8086/8088的指令系统
    修改git文件
    Javascript异步编程深入浅出
  • 原文地址:https://blog.csdn.net/x1131230123/article/details/134462062