• Torch生成类激活图CAM


    1. import torch
    2. from torch.nn import functional as F
    3. from torchvision import models, transforms
    4. from PIL import Image
    5. import os
    6. os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
    7. # 加载经过训练的 ResNet 模型
    8. model = models.resnet50(pretrained=True)
    9. model.eval()
    10. # 载入图像并进行预处理
    11. image_path = 'airline.png'
    12. image = Image.open(image_path).convert('RGB')
    13. preprocess = transforms.Compose([
    14. transforms.Resize((224, 224)),
    15. transforms.ToTensor(),
    16. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    17. ])
    18. input_tensor = preprocess(image).unsqueeze(0)
    19. # 前向传播获取特征图
    20. with torch.no_grad():
    21. features = model.conv1(input_tensor)
    22. features = model.layer1(features)
    23. features = model.layer2(features)
    24. features = model.layer3(features)
    25. features = model.layer4(features)
    26. # 获取模型的权重
    27. weight = model.fc.weight
    28. print(1)
    29. # 假设 cam 和 resized_tensor 是 PyTorch 张量
    30. # 将它们转换为 NumPy 数组
    31. import cv2
    32. bz, nc, h, w = features.shape
    33. beforeDot = features.reshape((nc, h*w))
    34. cam = torch.matmul(weight[1], beforeDot)#404
    35. cam = cam.reshape(h, w)
    36. size_upsample = (256, 256)
    37. cam = cam - torch.min(cam)
    38. cam_img = cam / torch.max(cam)
    39. # cam_img = torch.uint8(255 * cam_img)
    40. # import torch
    41. import torch.nn.functional as F
    42. # 使用 interpolate 函数将其调整为 [224, 224]
    43. resized_tensor = F.interpolate(cam_img.unsqueeze(0).unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False)
    44. # 现在 resized_tensor 是一个大小为 [1, 1, 224, 224] 的 PyTorch 张量
    45. # 如果需要,你可以使用 .squeeze() 方法来移除不必要的维度
    46. output_cam = resized_tensor.squeeze()
    47. import numpy as np
    48. cam_np = output_cam.detach().numpy()
    49. # 假设 image 是你的图像数据
    50. # cam_np = cam_np.astype(np.uint8)
    51. resized_tensor_np = input_tensor.detach().numpy()
    52. # 将 image 的形状调整为 (3, 224, 224)
    53. image = resized_tensor_np.squeeze()
    54. # 转换图像通道顺序,从 (3, 224, 224) 调整为 (224, 224, 3)
    55. image = np.transpose(image, (1, 2, 0))
    56. import matplotlib.pyplot as plt
    57. # 创建一个新的图形
    58. plt.figure(figsize=(8, 8))
    59. # 绘制原始图像
    60. plt.subplot(1, 2, 1)
    61. plt.imshow(image)#, cmap='gray')
    62. plt.title('Original Image')
    63. # 绘制 CAM
    64. plt.subplot(1, 2, 2)
    65. plt.imshow(cam_np, cmap='jet') # 使用 'jet' 颜色映射以突出 CAM
    66. plt.title('Class Activation Map (CAM)')
    67. # 显示图形
    68. plt.show()

  • 相关阅读:
    从0到1实现五子棋游戏!!
    leetcode - 42. Trapping Rain Water
    vue网页浏览器刷新404问题解决
    如何进行前端单元测试?
    no main manifest attribute, in xxx.jar
    信息系统漏洞与风险管理制度
    网络编程套接字(2)——简单的TCP网络程序
    Attention机制学习记录(四)之Transformer
    Vue组件的渲染更新原理知识大连串
    C++模板初阶 —— 函数模板、类模板、模板的声明和定义分离(多文件使用的注意事项)
  • 原文地址:https://blog.csdn.net/qq_34069180/article/details/133816213