
microsoft/SimMIM 原始碼:
https://github.com/microsoft/SimMIM/tree/main
microsoft/Swin-Transformer:
https://github.com/microsoft/Swin-Transformer/tree/main
上面兩個 repo 中的 SimMIM 模型應該是相同的,皆使用 models/simmim.py。由於兩者都沒有提供視覺化程式,我參考 MAE 的視覺化代碼進行改寫;另外,Xiang Li 等人的 UM-MAE 也有對應的代碼。
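(config 請根據你的模型自行建立一個 dict;下方程式碼會讀取 config['MODEL']['patch_size']、config['MODEL']['norm_target']、config['MODEL']['norm_patch_size']、config['DATA']['input_size'] 與 config['DATA']['mask_patch_size'] 等鍵。)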
(build_simmim(config) 取自 models/simmim.py;MaskGenerator 則取自同一 repo data/ 目錄下的 SimMIM 資料前處理程式。)
SimMIM(改寫 forward 的輸出)
class SimMIM(nn.Module):
    """SimMIM model: encoder + 1x1-conv/PixelShuffle decoder, masked L1 loss.

    Modified from models/simmim.py so that ``forward`` also returns the
    reconstruction, which the visualization code needs.

    Args:
        config: nested dict; ``forward`` reads ``config['MODEL']['norm_target']``
            and ``config['MODEL']['norm_patch_size']``.
        encoder: module exposing ``num_features`` and called as
            ``encoder(x, mask)``. Assumed to return a feature map of shape
            (B, num_features, H / encoder_stride, W / encoder_stride) --
            TODO confirm for your encoder.
        encoder_stride: total downsampling factor of the encoder; the decoder
            upsamples by this factor with PixelShuffle.
        in_chans: number of image channels (3 for RGB).
        patch_size: model patch size; the mask is given at this granularity.
    """

    def __init__(self, config, encoder, encoder_stride, in_chans, patch_size):
        super().__init__()
        self.config = config
        self.encoder = encoder
        self.encoder_stride = encoder_stride
        self.in_chans = in_chans
        self.patch_size = patch_size

        # Predict encoder_stride**2 values per output channel at every feature
        # location; PixelShuffle rearranges them back to full resolution.
        # The channel count follows in_chans (it used to be hard-coded to 3) so
        # the reconstruction matches the input the loss compares it against.
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=self.encoder.num_features,
                out_channels=self.encoder_stride ** 2 * self.in_chans,
                kernel_size=1,
            ),
            nn.PixelShuffle(self.encoder_stride),
        )

    def forward(self, x, mask):
        """Reconstruct ``x`` and compute the masked L1 loss.

        Args:
            x: input images, expected (B, in_chans, H, W).
            mask: patch-level mask, expected (B, H / patch_size, W / patch_size),
                1 marking a masked patch.

        Returns:
            ``(x_rec, loss)``: the decoder output and the scalar loss averaged
            over masked pixels and channels. Upstream SimMIM returns only
            ``loss``; returning ``x_rec`` as well is what enables visualization.
        """
        z = self.encoder(x, mask)
        x_rec = self.decoder(z)

        # Upsample the patch-level mask to pixel level and add a channel axis
        # so it broadcasts against (B, C, H, W).
        mask = (
            mask.repeat_interleave(self.patch_size, 1)
            .repeat_interleave(self.patch_size, 2)
            .unsqueeze(1)
            .contiguous()
        )

        # Optionally regress per-patch normalized targets instead of raw
        # pixels; in that case x_rec is in the normalized-target space.
        if self.config['MODEL']['norm_target']:
            x = norm_targets(x, self.config['MODEL']['norm_patch_size'])

        loss_recon = F.l1_loss(x, x_rec, reduction='none')
        # Average only over masked pixels; 1e-5 guards against an all-zero mask.
        loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans
        return x_rec, loss
Utilities
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
-
# Per-channel normalization statistics (ImageNet values by default);
# replace them with the mean/std of your own dataset.
image_mean = np.asarray((0.485, 0.456, 0.406), dtype=np.float64)
image_std = np.asarray((0.229, 0.224, 0.225), dtype=np.float64)
-
class MyTransform:
    """Convert an image to a tensor and pair it with a mask from MaskGenerator.

    Args:
        config: nested dict; reads ``config['MODEL']['patch_size']`` and
            ``config['DATA']['input_size']`` / ``config['DATA']['mask_patch_size']``.
        mask_ratio: forwarded unchanged to ``MaskGenerator``.
    """

    def __init__(self, config, mask_ratio):
        self.transform_img = T.ToTensor()
        patch = config['MODEL']['patch_size']
        data_cfg = config['DATA']
        self.mask_generator = MaskGenerator(
            input_size=data_cfg['input_size'],
            mask_patch_size=data_cfg['mask_patch_size'],
            model_patch_size=patch,
            mask_ratio=mask_ratio,
        )

    def __call__(self, img):
        """Return ``(image_tensor, mask_tensor)``, calling the mask generator anew each time."""
        to_tensor = self.transform_img
        image_tensor = to_tensor(img)
        # Assuming MaskGenerator returns a 2-D (h, w) array, ToTensor adds a
        # leading axis of size 1; run_one_image relies on it as the batch axis.
        mask_tensor = to_tensor(self.mask_generator())
        return image_tensor, mask_tensor
-
def show_image(image, title=''):
    """Plot one normalized image on the current matplotlib axes.

    ``image`` is an (H, W, 3) tensor in normalized space. It is mapped back
    using the module-level ``image_mean`` / ``image_std``, scaled to 0-255,
    clipped, and drawn as integers. The axes are hidden and ``title`` is shown
    above the image.
    """
    # Only channels-last, 3-channel images are supported.
    assert image.shape[2] == 3
    denormalized = image * image_std + image_mean
    pixels = torch.clip(denormalized * 255, 0, 255).int()
    plt.imshow(pixels)
    plt.title(title, fontsize=16)
    plt.axis('off')
-
def prepare_model(chkpt_dir, cfg=None):
    """Build a SimMIM model and load pretrained weights into it.

    Args:
        chkpt_dir: path to the checkpoint *file* (despite the name).
        cfg: config dict passed to ``build_simmim``. Defaults to the
            module-level ``config`` so existing calls keep working.

    Returns:
        The model, switched to eval mode.
    """
    if cfg is None:
        cfg = config
    model = build_simmim(cfg)

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    # Training checkpoints store the weights under 'model'; a bare state dict
    # is accepted too.
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

    # Rename legacy 'rpe_mlp' keys to 'cpb_mlp' so they match the current
    # model's parameter names (presumably an older naming -- confirm against
    # your model definition). Keys are collected first because the dict is
    # mutated in the loop.
    for old_key in [k for k in state_dict if 'rpe_mlp' in k]:
        state_dict[old_key.replace('rpe_mlp', 'cpb_mlp')] = state_dict.pop(old_key)

    # strict=False tolerates key mismatches; print them so they stay visible.
    msg = model.load_state_dict(state_dict, strict=False)
    print(msg)
    del checkpoint, state_dict
    model.eval()
    return model
-
def run_one_image(img, model, mask_ratio=0.65, cfg=None):
    """Mask one image, reconstruct it with SimMIM, and plot four panels.

    The panels are: original, masked input, raw reconstruction, and
    reconstruction with the visible patches pasted back.

    Args:
        img: (H, W, 3) image normalized with ``image_mean`` / ``image_std``;
            its size must match ``cfg['DATA']['input_size']``.
        model: the SimMIM model (must expose ``patch_size``) returning
            ``(x_rec, loss)``.
        mask_ratio: fraction of patches to mask, forwarded to ``MyTransform``.
        cfg: config dict for ``MyTransform``; defaults to the module-level
            ``config``.

    NOTE(review): if the model was trained with ``norm_target`` enabled, its
    output is in per-patch normalized space, so the reconstruction panels will
    not look like the image -- confirm for your checkpoint.
    """
    if cfg is None:
        cfg = config

    # Convert to tensors and sample a mask for this call.
    transform = MyTransform(cfg, mask_ratio)
    x, mask = transform(img)

    # Inference only: skip autograd bookkeeping. Inputs go to the model's
    # device; the reconstruction comes back to CPU for plotting.
    device = next(model.parameters()).device
    with torch.no_grad():
        y, _ = model(x.unsqueeze(dim=0).float().to(device), mask.to(device))
    y = y.squeeze(0).cpu()

    # Upsample the patch-level mask to pixel resolution; its leading size-1
    # axis broadcasts over the image channels.
    mask = (
        mask.repeat_interleave(model.patch_size, 1)
        .repeat_interleave(model.patch_size, 2)
        .contiguous()
    )
    im_masked = x * (1 - mask)
    # Reconstruction pasted with the visible patches of the original.
    im_paste = x * (1 - mask) + y * mask

    # Start a fresh (large) figure so the panels never land on an open figure.
    plt.rcParams['figure.figsize'] = [24, 24]
    plt.figure()
    panels = (
        (x, "original"),
        (im_masked, "masked"),
        (y, "reconstruction"),
        (im_paste, "reconstruction + visible"),
    )
    for position, (tensor, title) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        show_image(torch.einsum('chw->hwc', tensor), title)

    plt.show()
# Replace with the path to your own checkpoint.
model = prepare_model('../out_dir/pretrain/simMIM_pt_base_192_w6-45.pth')

# Seed both RNGs *before* any random operation: RandomResizedCrop draws from
# torch's RNG, and SimMIM's MaskGenerator samples with NumPy's global RNG.
# Seeding only torch, and only after the crop, left both crop and mask
# non-reproducible.
torch.manual_seed(123456)
np.random.seed(123456)

# Prepare the image: force RGB, then crop to the model's input size. The
# size is taken from the config so it always matches the mask generator.
input_size = config['DATA']['input_size']
transform = T.Compose([
    T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
    # T.RandomCrop((input_size, input_size)),
    T.RandomResizedCrop((input_size, input_size)),
])

# Replace with the image you want to test. The context manager closes the file.
with Image.open('../Fabrics/Quixel/001/oi2uhyp_2K_Roughness.jpg') as raw_img:
    img = np.array(transform(raw_img)) / 255.

assert img.shape == (input_size, input_size, 3)

# Normalize with the per-channel mean/std (ImageNet values by default).
img = img - image_mean
img = img / image_std

plt.rcParams['figure.figsize'] = [3, 3]
show_image(torch.tensor(img))
plt.show()

# Visualize
print('simMIM with pixel reconstruction:')
run_one_image(img, model, mask_ratio=0.65)

轉載請標記出處。
另外,我也將程式碼改寫成可在 Jupyter 上以單張 GPU 執行的 SimMIM 版本,如有需要歡迎詢問。