https://github.com/huynguyenbao/Practical-Deep-Raw-Image-Denoising-on-Mobile-Devices
中的网络模型处理的是 raw 图的 4 个通道，这里将其修改为处理 RGB 图的 3 个通道。如果想加载作者提供的预训练权重（pretrained weights），
则需要修改不匹配的层的参数。
def adapt_state_dict_channels(states, in_channels=3, out_channels=3,
                              in_key='conv0.conv.weight',
                              out_keys=('out1.conv.weight', 'out1.conv.bias')):
    """Trim a pretrained state dict so it fits a network with fewer I/O channels.

    The upstream checkpoint targets 4-channel raw input/output. This keeps only
    the leading ``in_channels`` slices along dim 1 of ``in_key`` (the first
    convolution's input channels). It also keeps the leading ``out_channels``
    slices along dim 0 of every ``out_keys`` entry (the output convolution's
    output channels). Keys that are absent are left alone.

    NOTE(review): keeping the *first* raw channels is a plain truncation, not a
    colour-space conversion. Which raw channels these are depends on the
    upstream packing order; confirm the mapping suits RGB input.

    Args:
        states: Mapping of parameter name -> tensor. Updated in place.
        in_channels: Number of input channels to keep on ``in_key``.
        out_channels: Number of output channels to keep on each ``out_keys`` entry.
        in_key: Name of the first convolution's weight.
        out_keys: Names of the output convolution's weight and bias.

    Returns:
        The same ``states`` object, for convenience.

    Raises:
        ValueError: If a tensor has fewer channels than requested.
    """
    # Update in place rather than building a new dict, so that any
    # ``_metadata`` attribute attached by ``state_dict()`` is preserved for
    # ``load_state_dict``.
    if in_key in states:
        weight = states[in_key]
        if weight.shape[1] < in_channels:
            raise ValueError(
                f"{in_key} has {weight.shape[1]} input channels; "
                f"cannot keep {in_channels}"
            )
        # clone() so the trimmed tensor does not pin the full original storage.
        states[in_key] = weight[:, :in_channels].clone()
    for key in out_keys:
        if key not in states:
            continue
        value = states[key]
        if value.shape[0] < out_channels:
            raise ValueError(
                f"{key} has {value.shape[0]} output channels; "
                f"cannot keep {out_channels}"
            )
        states[key] = value[:out_channels].clone()
    return states


if __name__ == "__main__":
    net = Network(3, 3)
    # Smoke test: check that the 3-in/3-out network accepts an RGB-shaped batch.
    # no_grad avoids building an autograd graph for this throwaway pass.
    with torch.no_grad():
        net(torch.randn(1, 3, 800, 800, device=torch.device('cpu'), dtype=torch.float32))

    # Adapt the pretrained weights to the layers whose channel counts changed.
    model_path = r'D:\code\denoise\PMRID-main\models\torch_pretrained.ckp'
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    states = torch.load(model_path, map_location='cpu')
    original_shapes = {name: tuple(value.shape) for name, value in states.items()}
    adapt_state_dict_channels(states)
    for name, value in states.items():
        print(name, original_shapes[name], tuple(value.shape))

    # strict (default) loading surfaces any remaining shape or key mismatch.
    net.load_state_dict(states)
    net.eval()
    for name, value in net.named_parameters():
        print(name, value.shape)
    torch.save(net.state_dict(), 'pmrid_rgb.pth')