Theory reference: https://zhuanlan.zhihu.com/p/464673225
Code adapted from https://github.com/LibreCV/blog/blob/master/_notebooks/2021-02-13-Pix2Pix%20explained%20with%20code.ipynb
import os
# os.chdir(os.path.dirname(__file__))
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision import datasets
from torchvision import models
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
import argparse
from glob import glob
import random
import itertools
sample_dir = 'samples_pix2pix'
if not os.path.exists(sample_dir):
os.makedirs(sample_dir, exist_ok=True)
writer = SummaryWriter(sample_dir)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
np.random.seed(0)
torch.manual_seed(0)
class DownSampleConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True):
"""
Paper details:
- C64-C128-C256-C512-C512-C512-C512-C512
- All convolutions are 4×4 spatial filters applied with stride 2
- Convolutions in the encoder downsample by a factor of 2
"""
super().__init__()
self.activation = activation
self.batchnorm = batchnorm
self.conv = nn.Conv2d(in_channels, out_channels, kernel, strides, padding)
if batchnorm:
self.bn = nn.BatchNorm2d(out_channels)
if activation:
self.act = nn.LeakyReLU(0.2)
def forward(self, x):
x = self.conv(x)
if self.batchnorm:
x = self.bn(x)
if self.activation:
x = self.act(x)
return x
class UpSampleConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel=4,
strides=2,
padding=1,
activation=True,
batchnorm=True,
dropout=False
):
super().__init__()
self.activation = activation
self.batchnorm = batchnorm
self.dropout = dropout
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel, strides, padding)
if batchnorm:
self.bn = nn.BatchNorm2d(out_channels)
if activation:
self.act = nn.ReLU(True)
if dropout:
self.drop = nn.Dropout2d(0.5)
    def forward(self, x):
        x = self.deconv(x)
        if self.batchnorm:
            x = self.bn(x)
        if self.activation:
            x = self.act(x)   # decoder blocks use ReLU (per the paper details above)
        if self.dropout:
            x = self.drop(x)
        return x
class Generator(nn.Module):
def __init__(self, in_channels, out_channels):
"""
Paper details:
- Encoder: C64-C128-C256-C512-C512-C512-C512-C512
- All convolutions are 4×4 spatial filters applied with stride 2
- Convolutions in the encoder downsample by a factor of 2
        - Decoder: CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
"""
super().__init__()
        # encoder/downsample convs
self.encoders = [
DownSampleConv(in_channels, 64, batchnorm=False), # bs x 64 x 128 x 128
DownSampleConv(64, 128), # bs x 128 x 64 x 64
DownSampleConv(128, 256), # bs x 256 x 32 x 32
DownSampleConv(256, 512), # bs x 512 x 16 x 16
DownSampleConv(512, 512), # bs x 512 x 8 x 8
DownSampleConv(512, 512), # bs x 512 x 4 x 4
DownSampleConv(512, 512), # bs x 512 x 2 x 2
DownSampleConv(512, 512, batchnorm=False), # bs x 512 x 1 x 1
]
# decoder/upsample convs
self.decoders = [
UpSampleConv(512, 512, dropout=True), # bs x 512 x 2 x 2
UpSampleConv(1024, 512, dropout=True), # bs x 512 x 4 x 4
UpSampleConv(1024, 512, dropout=True), # bs x 512 x 8 x 8
UpSampleConv(1024, 512), # bs x 512 x 16 x 16
UpSampleConv(1024, 256), # bs x 256 x 32 x 32
UpSampleConv(512, 128), # bs x 128 x 64 x 64
UpSampleConv(256, 64), # bs x 64 x 128 x 128
]
self.decoder_channels = [512, 512, 512, 512, 256, 128, 64]
self.final_conv = nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1)
self.tanh = nn.Tanh()
self.encoders = nn.ModuleList(self.encoders)
self.decoders = nn.ModuleList(self.decoders)
def forward(self, x):
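        # U-Net forward: keep every encoder output, then concatenate each decoder
        # output with the matching encoder feature map (skip connections)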
skips_cons = []
for encoder in self.encoders:
x = encoder(x)
skips_cons.append(x)
skips_cons = list(reversed(skips_cons[:-1]))
decoders = self.decoders[:-1]
for decoder, skip in zip(decoders, skips_cons):
x = decoder(x)
# print(x.shape, skip.shape)
x = torch.cat((x, skip), axis=1)
x = self.decoders[-1](x)
# print(x.shape)
x = self.final_conv(x)
return self.tanh(x)
class PatchGAN(nn.Module):
def __init__(self, input_channels):
super().__init__()
self.d1 = DownSampleConv(input_channels, 64, batchnorm=False)
self.d2 = DownSampleConv(64, 128)
self.d3 = DownSampleConv(128, 256)
self.d4 = DownSampleConv(256, 512)
self.final = nn.Conv2d(512, 1, kernel_size=1)
def forward(self, x, y):
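        # x: image to score, y: conditioning image; the pair is concatenated along the
        # channel axis, and the output is a grid of patch logits (bs x 1 x 16 x 16 for
        # 256x256 inputs) rather than a single real/fake score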
x = torch.cat([x, y], axis=1)
x0 = self.d1(x)
x1 = self.d2(x0)
x2 = self.d3(x1)
x3 = self.d4(x2)
xn = self.final(x3)
return xn
def _weights_init(m):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        # BatchNorm affine weights are drawn around 1.0, as in the pix2pix reference implementation
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
torch.nn.init.constant_(m.bias, 0)
class ImageDataset(torch.utils.data.Dataset):
def __init__(self, root, transforms=None, unaligned=False, mode='train'):
self.transforms = transforms
self.unaligned = unaligned
self.files_A = sorted(glob(os.path.join(root, mode, 'A', '*.*')))
self.files_B = sorted(glob(os.path.join(root, mode, 'B', '*.*')))
def __getitem__(self, idx):
img = Image.open(self.files_A[idx % len(self.files_A)]).convert('RGB')
itemA = self.transforms(img)
if self.unaligned:
rand_idx = random.randint(0, len(self.files_B)-1)
img = Image.open(self.files_B[rand_idx]).convert('RGB')
itemB = self.transforms(img)
else:
img = Image.open(self.files_B[idx % len(self.files_B)]).convert('RGB')
itemB = self.transforms(img)
return {
'A' : itemA,
'B' : itemB
}
def __len__(self):
return max(len(self.files_A), len(self.files_B))
def denorm(x):
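    # map tensors normalized to [-1, 1] (tanh output / Normalize(0.5, 0.5)) back to [0, 1] for display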
out = (x+1)/2
return out.clamp(0, 1)
# Losses
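# pix2pix objective: G* = arg min_G max_D L_cGAN(G, D) + lambda * L_L1(G)
# BCEWithLogitsLoss gives the adversarial term, L1 the reconstruction term
# (the paper uses lambda = 100; this script uses lambda_recon = 200)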
adv_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()
lambda_recon = 200
n_epochs = 200
display_step = 100
batch_size = 4
lr = 0.0002
target_size = 256
input_size = 256
dataroot = 'data/cycle_gan/datasets/facades'
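# expects paired images under <dataroot>/train/A and <dataroot>/train/B (see ImageDataset above)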
input_nc = 3
output_nc = 3
G = Generator(input_nc, output_nc).to(device)
D = PatchGAN(input_nc + output_nc).to(device)
G.apply(_weights_init)
D.apply(_weights_init)
optimG = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))  # beta1 = 0.5 as in the paper
optimD = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
# Dataset loader
transforms_data = transforms.Compose([
transforms.Resize(int(input_size*1.12), Image.BICUBIC),
transforms.RandomCrop(input_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
])
dataset = ImageDataset(dataroot, transforms=transforms_data, unaligned=False)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True)
###### Training ######
cnt = 0
log_step = 10
for epoch in range(0, n_epochs):
for i, batch in enumerate(dataloader):
# set model input
real = batch['A'].to(device)
condition = batch['B'].to(device)
# discriminator
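        # detach the generated image so this backward pass only updates D, not G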
fake_images = G(condition).detach()
fake_logits = D(fake_images, condition)
real_logits = D(real, condition)
fake_loss = adv_criterion(fake_logits, torch.zeros_like(fake_logits))
real_loss = adv_criterion(real_logits, torch.ones_like(real_logits))
d_loss = (real_loss + fake_loss) / 2
optimD.zero_grad()
d_loss.backward()
optimD.step()
# generator
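        # re-run G without detach so gradients from both loss terms flow back into the generator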
fake_images = G(condition)
disc_logits = D(fake_images, condition)
adversarial_loss = adv_criterion(disc_logits, torch.ones_like(disc_logits))
# calculate reconstruction loss
recon_loss = recon_criterion(fake_images, real)
g_loss = adversarial_loss + lambda_recon * recon_loss
optimG.zero_grad()
g_loss.backward()
optimG.step()
cnt += 1
if cnt % log_step == 0:
print('Epoch [{}/{}], Step [{}], g_loss: {:.4f}, d_loss: {:.4f}'.\
format(epoch, n_epochs, cnt, g_loss.item(), d_loss.item()))
writer.add_scalar('g_loss', g_loss.item(), global_step=cnt)
writer.add_scalar('d_loss', d_loss.item(), global_step=cnt)
if cnt % 100 == 0:
writer.add_images('real', denorm(real), global_step=cnt)
writer.add_images('condition', denorm(condition), global_step=cnt)
writer.add_images('fake_images', denorm(fake_images), global_step=cnt)
The overall structure follows the Conditional GAN: one image of each pair serves as the condition and appears in both the generator and the discriminator (condition = batch['B'] in the code above, with batch['A'] as the target).
Two other points worth attention are the U-Net-style design of the generator and the PatchGAN-style discriminator; see the figure below for a detailed explanation.
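As a complement, here is a minimal sketch (not part of the training script above) that pushes random 3-channel 256x256 tensors, the input size used in this script, through the Generator and PatchGAN defined above to confirm the shapes:
# shape check for the U-Net generator and PatchGAN discriminator (CPU, random data)
g = Generator(3, 3)
d = PatchGAN(3 + 3)
condition = torch.randn(2, 3, 256, 256)    # stand-in for a conditioning image (e.g. a label map)
fake = g(condition)                        # the U-Net keeps the spatial size: 2 x 3 x 256 x 256
patch_logits = d(fake, condition)          # a grid of patch scores, not a single scalar
print(fake.shape, patch_logits.shape)      # torch.Size([2, 3, 256, 256]) torch.Size([2, 1, 16, 16])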
The experimental results are shown below (figures omitted):
real image
condition image
generated image. The generated results are quite poor, probably because training has not converged; further tuning is left for later.