• microsoft/simMIM-visualize 視覺化(原創)


    microsoft/simMIM原碼:
    https://github.com/microsoft/SimMIM/tree/main

    microsoft/Swin-Transformer:

    https://github.com/microsoft/Swin-Transformer/tree/main

    上面兩個repo的模型應該是一樣的,模型使用models/simmim.py,因為都沒有提供視覺化的程式,我使用MAE的代碼改寫,另外Xiang Li等人的UM-MAE也有對應的代碼。

    ( config請根據模型自己建立一個dict )

    ( build_simmim(config) 從 models/simmim.py 拿)

    simMIM(改寫輸出)

    class SimMIM(nn.Module):
        """SimMIM model whose forward returns the reconstruction as well as the loss.

        Modified from ``models/simmim.py``: ``forward`` returns ``(x_rec, loss)``
        instead of only the loss so the reconstruction can be visualized.

        Args:
            config: nested dict; ``config['MODEL']['norm_target']`` and
                ``config['MODEL']['norm_patch_size']`` are read in ``forward``.
            encoder: module called as ``encoder(x, mask)``; must expose
                ``num_features`` (channel count of its output feature map).
            encoder_stride: spatial upsampling factor applied by the decoder's
                PixelShuffle.
            in_chans: number of image channels; the decoder predicts this many.
            patch_size: factor used to upsample the patch-level mask to pixels.
        """

        def __init__(self, config, encoder, encoder_stride, in_chans, patch_size):
            super().__init__()
            self.config = config
            self.encoder = encoder
            self.encoder_stride = encoder_stride
            # 1x1 conv + PixelShuffle maps encoder features back to image space.
            # Predict in_chans channels (was hard-coded to 3) so x_rec matches x
            # in the L1 loss below; identical to the original when in_chans == 3.
            self.decoder = nn.Sequential(
                nn.Conv2d(
                    in_channels=self.encoder.num_features,
                    out_channels=self.encoder_stride ** 2 * in_chans,
                    kernel_size=1,
                ),
                nn.PixelShuffle(self.encoder_stride),
            )
            self.in_chans = in_chans
            self.patch_size = patch_size

        def forward(self, x, mask):
            """Reconstruct ``x`` and compute the masked L1 reconstruction loss.

            Args:
                x: input image batch (assumed (B, C, H, W) — TODO confirm).
                mask: patch-level mask; upsampled along dims 1 and 2, so it is
                    presumably (B, h, w) with 1 marking masked patches.

            Returns:
                (x_rec, loss): the decoder output, and the L1 loss summed over
                masked pixels and divided by the masked-pixel count and in_chans.
            """
            z = self.encoder(x, mask)
            x_rec = self.decoder(z)
            # Upsample the mask to pixel resolution and add a channel axis.
            mask = (
                mask.repeat_interleave(self.patch_size, 1)
                .repeat_interleave(self.patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            # norm target as prompted
            if self.config['MODEL']['norm_target']:
                x = norm_targets(x, self.config['MODEL']['norm_patch_size'])
            loss_recon = F.l1_loss(x, x_rec, reduction='none')
            # Only masked pixels contribute; 1e-5 avoids division by zero.
            loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans
            # Unlike upstream SimMIM, also return x_rec so it can be visualized.
            return x_rec, loss

    Utilities

    from PIL import Image
    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    import torchvision.transforms as T

    # Replace with the per-channel mean/std of your own dataset.
    # (np and torch were used below but never imported in this listing.)
    image_mean = np.array([0.485, 0.456, 0.406])
    image_std = np.array([0.229, 0.224, 0.225])
    class MyTransform:
        """Convert an image to a tensor and pair it with a freshly drawn mask.

        Args:
            config: nested dict providing ``config['MODEL']['patch_size']`` and
                ``config['DATA']['input_size' | 'mask_patch_size']``.
            mask_ratio: forwarded unchanged to ``MaskGenerator``.
        """

        def __init__(self, config, mask_ratio):
            self.transform_img = T.ToTensor()
            patch = config['MODEL']['patch_size']
            data_cfg = config['DATA']
            self.mask_generator = MaskGenerator(
                input_size=data_cfg['input_size'],
                mask_patch_size=data_cfg['mask_patch_size'],
                model_patch_size=patch,
                mask_ratio=mask_ratio,
            )

        def __call__(self, img):
            """Return ``(image_tensor, mask_tensor)``; both pass through ToTensor."""
            to_tensor = self.transform_img
            return to_tensor(img), to_tensor(self.mask_generator())
    def show_image(image, title=''):
        """Plot a normalized, channels-last image tensor on the current axes.

        The mean/std normalization is undone with ``image_mean`` / ``image_std``,
        values are scaled to 0-255, clipped, cast to int, then shown with the
        given title and no axis ticks.
        """
        assert image.shape[2] == 3  # expects [H, W, 3]
        pixels = (image * image_std + image_mean) * 255
        plt.imshow(torch.clip(pixels, 0, 255).int())
        plt.title(title, fontsize=16)
        plt.axis('off')
    def prepare_model(chkpt_dir):
        """Build SimMIM from the global ``config`` and load checkpoint weights.

        Args:
            chkpt_dir: path to a checkpoint file whose ``'model'`` entry is a
                state dict.

        Returns:
            The model switched to eval mode. Loading is non-strict; the result
            (missing / unexpected keys) is printed.
        """
        model = build_simmim(config)
        checkpoint = torch.load(chkpt_dir, map_location='cpu')
        state_dict = checkpoint['model']
        # Rename any 'rpe_mlp' keys to 'cpb_mlp' so they match the model's
        # parameter names (presumably a naming difference between versions).
        for old_key in [key for key in state_dict.keys() if "rpe_mlp" in key]:
            state_dict[old_key.replace('rpe_mlp', 'cpb_mlp')] = state_dict.pop(old_key)
        print(model.load_state_dict(state_dict, strict=False))
        del checkpoint, state_dict
        model.eval()
        return model
    def run_one_image(img, model, mask_ratio=0.65):
        """Mask one image, reconstruct it with SimMIM and plot the results.

        Args:
            img: normalized (H, W, 3) image array, as produced by the
                "Prepare Image" step; converted to a tensor by ``MyTransform``.
            model: SimMIM model whose forward returns ``(x_rec, loss)`` and
                which exposes ``patch_size``.
            mask_ratio: masking ratio forwarded to ``MaskGenerator``.

        Shows a 1x4 figure: original, masked input, reconstruction, and the
        reconstruction pasted into the visible patches. Reads the global
        ``config``.
        """
        # Transform and mask.
        transform = MyTransform(config, mask_ratio)
        x, mask = transform(img)
        # Inference only: model.eval() does not disable gradient tracking, so
        # use no_grad to avoid building an autograd graph.
        with torch.no_grad():
            y, _ = model(x.unsqueeze(dim=0).float(), mask)
        y = y.squeeze(0)
        print(y.shape)
        # Upsample the patch-level mask to pixel resolution for visualization.
        mask = mask.repeat_interleave(model.patch_size, 1).repeat_interleave(model.patch_size, 2).contiguous()
        im_masked = x * (1 - mask)
        # Reconstruction pasted with visible patches.
        im_paste = x * (1 - mask) + y * mask
        # Per-figure size instead of mutating the global plt.rcParams.
        plt.figure(figsize=(24, 24))
        plt.subplot(1, 4, 1)
        show_image(torch.einsum('chw->hwc', x), "original")
        plt.subplot(1, 4, 2)
        show_image(torch.einsum('chw->hwc', im_masked), "masked")
        plt.subplot(1, 4, 3)
        show_image(torch.einsum('chw->hwc', y), "reconstruction")
        plt.subplot(1, 4, 4)
        show_image(torch.einsum('chw->hwc', im_paste), "reconstruction + visible")
        plt.show()
    # Replace with the path to your own checkpoint, if you have one.
    checkpoint_path = '../out_dir/pretrain/simMIM_pt_base_192_w6-45.pth'
    model = prepare_model(checkpoint_path)
    # Prepare image.
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        # T.RandomCrop((192, 192)),
        T.RandomResizedCrop((192, 192)),
    ])
    # Replace with the image you want to test.
    # Open in a context manager so the file handle is released; the crop/resize
    # in `transform` produces a new image, so the result outlives the file.
    with Image.open('../Fabrics/Quixel/001/oi2uhyp_2K_Roughness.jpg') as raw_img:
        img = transform(raw_img)
    img = np.array(img) / 255.
    assert img.shape == (192, 192, 3)
    # Normalize by ImageNet mean and std.
    img = img - image_mean
    img = img / image_std
    plt.rcParams['figure.figsize'] = [3, 3]
    show_image(torch.tensor(img))
    # Visualize.
    # Seed both RNGs: the random mask comes from MaskGenerator, which in the
    # upstream SimMIM code draws with NumPy's RNG, so torch.manual_seed alone
    # does not make the mask reproducible.
    torch.manual_seed(123456)
    np.random.seed(123456)
    print('simMIM with pixel reconstruction:')
    run_one_image(img, model, mask_ratio=0.65)

    轉載請標記出處。

    另外我有將程式碼改寫成單卡可在jupyter上跑的simMIM,如果有需要可以詢問。

  • 相关阅读:
    Java开发学习(二十一)----Spring事务简介与事务角色解析
    C语言中的结构体和联合体有什么区别?
    Codeforces Round 908 (Div. 2)视频详解
    Unity Shader Graph 风格化熔岩
    实例分割最全综述(入坑一载半,退坑止于此)
    在listener.ora配置文件中配置listener 1527的监听并且使用tnsnames连接测试
    网页vue3导出pdf
    【构建并发程序】1-线程池-Executor-ExecutionContext
    2.5python 循环_python量化实用版教程(初级)
    洛谷P2680 树上路径,差分,二分答案
  • 原文地址:https://blog.csdn.net/weixin_44228592/article/details/132624765