pix2pixGAN is essentially a cGAN in which the image x serves as the condition and is fed to both G and D. G takes x (the image to be translated) as input and produces the generated image G(x); D must then distinguish the generated pair (x, G(x)) from the real pair (x, y), as sketched below.
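A minimal sketch of this conditioning (the shapes here are illustrative; the Discriminator class later in this post does exactly this): the condition and the image are concatenated along the channel axis before being fed to D.

```python
import torch

x = torch.randn(1, 3, 256, 256)  # condition image
y = torch.randn(1, 3, 256, 256)  # real image y, or generated image G(x)
pair = torch.cat([x, y], dim=1)  # (1, 6, 256, 256): D sees 6 input channels
```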

pix2pixGAN is mainly used for conversion between images, also known as image translation.
In an image translation task, the input and output share a great deal of information; contour information, for example, is shared. How do we preserve this shared structure? Part of the answer lies in the design of the loss function.
If we used a plain convolutional encoder-decoder, every layer would have to carry all of that information through the bottleneck, and the network would easily lose some of it.
For this reason, we use a U-Net as the generator: its skip connections let the shared low-level information bypass the bottleneck, as the sketch below shows.
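A minimal sketch of the skip-connection idea (shapes illustrative; the Generator class below is the actual implementation): an encoder feature map is concatenated onto the decoder feature map at the same resolution.

```python
import torch

up_feat = torch.randn(1, 512, 8, 8)    # decoder feature map
skip_feat = torch.randn(1, 512, 8, 8)  # encoder feature map at the same resolution
merged = torch.cat([up_feat, skip_feat], dim=1)  # (1, 1024, 8, 8)
```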

D takes paired images as input, just as in a cGAN: the discriminator should judge a real pair (x, y) as 1 and a generated pair (x, G(x)) as 0, while the generator tries to make D judge (x, G(x)) as 1.
In the paper, the D of pix2pixGAN is implemented as a patch discriminator (PatchGAN): however large the image is, it is judged in fixed-size patches. Equivalently, D is fully convolutional and outputs a grid of scores, one per patch.
The benefit of this design: D's effective input is small, the computation is cheap, and training is fast.
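A sketch of how a per-patch judgment enters the loss (the 60×60 grid matches what the Discriminator built below produces for 256×256 inputs):

```python
import torch
import torch.nn.functional as F

patch_scores = torch.sigmoid(torch.randn(1, 1, 60, 60))  # one score per patch
real_target = torch.ones_like(patch_scores)              # label grid for a real pair
loss = F.binary_cross_entropy(patch_scores, real_target)
```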
D's loss: a real pair should be judged as 1, and a (generated image, input) pair should be judged as 0.
G's loss: D should judge the (generated image, input) pair as 1.
Written out, the cGAN objective is:

$$\mathcal{L}_{cGAN}(G,D)=\mathbb{E}_{x,y}\big[\log D(x,y)\big]+\mathbb{E}_{x}\big[\log\big(1-D(x,G(x))\big)\big]$$

For an image translation task, G's input and output share a lot of information. To keep the output close to the ground truth, an L1 loss between them is also added:

$$\mathcal{L}_{L1}(G)=\mathbb{E}_{x,y}\big[\lVert y-G(x)\rVert_1\big]$$
Combining the two terms, the total objective is:

$$G^{*}=\arg\min_{G}\max_{D}\;\mathcal{L}_{cGAN}(G,D)+\lambda\,\mathcal{L}_{L1}(G)$$

The full PyTorch implementation follows, starting with data loading and preprocessing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision                   # image datasets/utilities
from torchvision import transforms   # image transforms

import numpy as np
import matplotlib.pyplot as plt      # plotting
import os
import glob
from PIL import Image

# sorted() keeps the .jpg images and .png annotations aligned by filename
imgs_path = sorted(glob.glob('base/*.jpg'))
annos_path = sorted(glob.glob('base/*.png'))

# preprocessing: to tensor, resize to 256x256, normalize to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.Normalize(mean=0.5, std=0.5)
])

# dataset definition
class CMP_dataset(data.Dataset):
    def __init__(self, imgs_path, annos_path):
        self.imgs_path = imgs_path
        self.annos_path = annos_path

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        anno_path = self.annos_path[index]
        pil_img = Image.open(img_path)      # read the image
        pil_img = transform(pil_img)        # transform the image
        anno_img = Image.open(anno_path)    # read the annotation
        anno_img = anno_img.convert("RGB")  # annotations are single-channel; convert to RGB
        pil_anno = transform(anno_img)      # transform the annotation
        return pil_anno, pil_img

    def __len__(self):
        return len(self.imgs_path)

# create the dataset
dataset = CMP_dataset(imgs_path, annos_path)
# wrap it in a DataLoader for easy iteration
BATCHSIZE = 32
dataloader = data.DataLoader(dataset,
                             batch_size=BATCHSIZE,
                             shuffle=True)
annos_batch, imgs_batch = next(iter(dataloader))
```
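A quick sanity check on one batch may help here (purely illustrative; it confirms the shapes and shows one denormalized pair):

```python
print(annos_batch.shape, imgs_batch.shape)  # torch.Size([32, 3, 256, 256]) twice

# undo Normalize(0.5, 0.5) for display: x * 0.5 + 0.5 maps [-1, 1] -> [0, 1]
plt.subplot(1, 2, 1)
plt.imshow(annos_batch[0].permute(1, 2, 0) * 0.5 + 0.5)
plt.subplot(1, 2, 2)
plt.imshow(imgs_batch[0].permute(1, 2, 0) * 0.5 + 0.5)
plt.show()
```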
```python
# downsampling block: a stride-2 convolution halves H and W; BatchNorm is optional
class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3,
                      stride=2,
                      padding=1),
            nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        x = self.conv_relu(x)
        if is_bn:
            x = self.bn(x)
        return x


# upsampling block: a transposed convolution doubles H and W
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),  # transposed conv: doubles the spatial size
            nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        x = self.upconv_relu(x)
        x = self.bn(x)
        if is_drop:
            # note: F.dropout2d defaults to training=True, so this applies in
            # eval mode too (the pix2pix paper also uses dropout at test time)
            x = F.dropout2d(x)
        return x
```
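A quick check of these two blocks (illustrative only) confirms the halving/doubling behavior:

```python
blk_in = torch.randn(1, 64, 128, 128)
print(Downsample(64, 128)(blk_in).shape)  # torch.Size([1, 128, 64, 64])
print(Upsample(64, 32)(blk_in).shape)     # torch.Size([1, 32, 256, 256])
```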
```python
# generator: 6 downsampling blocks, 5 upsampling blocks, and one output layer
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)     # 64, 128, 128
        self.down2 = Downsample(64, 128)   # 128, 64, 64
        self.down3 = Downsample(128, 256)  # 256, 32, 32
        self.down4 = Downsample(256, 512)  # 512, 16, 16
        self.down5 = Downsample(512, 512)  # 512, 8, 8
        self.down6 = Downsample(512, 512)  # 512, 4, 4

        self.up1 = Upsample(512, 512)      # 512, 8, 8
        self.up2 = Upsample(1024, 512)     # 512, 16, 16
        self.up3 = Upsample(1024, 256)     # 256, 32, 32
        self.up4 = Upsample(512, 128)      # 128, 64, 64
        self.up5 = Upsample(256, 64)       # 64, 128, 128

        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1)  # 3, 256, 256

    def forward(self, x):
        # encoder
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)
        x6 = self.down6(x5)

        # decoder with U-Net skip connections (concatenate encoder features)
        x6 = self.up1(x6, is_drop=True)
        x6 = torch.cat([x6, x5], dim=1)

        x6 = self.up2(x6, is_drop=True)
        x6 = torch.cat([x6, x4], dim=1)

        x6 = self.up3(x6, is_drop=True)
        x6 = torch.cat([x6, x3], dim=1)

        x6 = self.up4(x6)
        x6 = torch.cat([x6, x2], dim=1)

        x6 = self.up5(x6)
        x6 = torch.cat([x6, x1], dim=1)

        # tanh keeps outputs in [-1, 1], matching the input normalization
        x6 = torch.tanh(self.last(x6))
        return x6
```
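A shape check (illustrative) confirms the generator maps a 256×256 image to a 256×256 image:

```python
gen_test = Generator()
print(gen_test(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 3, 256, 256])
```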
```python
# discriminator: takes anno + img (generated or real), concatenated on channels
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.down1 = Downsample(6, 64)
        self.down2 = Downsample(64, 128)
        self.conv1 = nn.Conv2d(128, 256, 3)
        self.bn = nn.BatchNorm2d(256)
        self.last = nn.Conv2d(256, 1, 3)

    def forward(self, anno, img):
        x = torch.cat([anno, img], dim=1)  # 6-channel conditional input
        x = self.down1(x, is_bn=False)
        x = self.down2(x, is_bn=True)
        x = F.dropout2d(self.bn(F.leaky_relu(self.conv1(x))))
        x = torch.sigmoid(self.last(x))    # batch * 1 * 60 * 60 patch score map
        return x
```
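Again a quick shape check (illustrative): each of the 60×60 outputs scores one patch of the input pair.

```python
dis_test = Discriminator()
out = dis_test(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1, 60, 60])
```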
```python
device = "cuda" if torch.cuda.is_available() else "cpu"
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optimizer = torch.optim.Adam(dis.parameters(), lr=1e-3, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(gen.parameters(), lr=1e-3, betas=(0.5, 0.999))

# plotting helper: show the input annotation, the ground truth, and the output
def generate_images(model, test_anno, test_real):
    prediction = model(test_anno).permute(0, 2, 3, 1).detach().cpu().numpy()
    test_anno = test_anno.permute(0, 2, 3, 1).cpu().numpy()
    test_real = test_real.permute(0, 2, 3, 1).cpu().numpy()
    plt.figure(figsize=(10, 10))
    display_list = [test_anno[0], test_real[0], prediction[0]]
    title = ['Input', 'Ground Truth', 'Output']
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(title[i])
        # denormalize from [-1, 1] back to [0, 1] before display
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')  # hide the axes
    plt.show()

test_imgs_path = sorted(glob.glob('extended/*.jpg'))
test_annos_path = sorted(glob.glob('extended/*.png'))

test_dataset = CMP_dataset(test_imgs_path, test_annos_path)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCHSIZE)

# loss functions
# cGAN (adversarial) loss
loss_fn = torch.nn.BCELoss()
# the L1 loss is computed inline in the training loop below
```
```python
annos_batch, imgs_batch = annos_batch.to(device), imgs_batch.to(device)
LAMBDA = 7  # weight of the L1 loss

D_loss = []  # discriminator loss per epoch
G_loss = []  # generator loss per epoch

# training loop
for epoch in range(10):
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = len(dataloader)
    for step, (annos, imgs) in enumerate(dataloader):
        imgs = imgs.to(device)
        annos = annos.to(device)

        # discriminator: loss computation and optimization
        d_optimizer.zero_grad()
        disc_real_output = dis(annos, imgs)  # real pairs should be judged 1
        d_real_loss = loss_fn(disc_real_output,
                              torch.ones_like(disc_real_output, device=device))
        d_real_loss.backward()

        gen_output = gen(annos)
        # detach() blocks gradients from flowing into the generator here
        disc_gen_output = dis(annos, gen_output.detach())
        d_fake_loss = loss_fn(disc_gen_output,
                              torch.zeros_like(disc_gen_output, device=device))
        d_fake_loss.backward()

        disc_loss = d_real_loss + d_fake_loss  # total discriminator loss
        d_optimizer.step()

        # generator: loss computation and optimization
        g_optimizer.zero_grad()
        disc_gen_out = dis(annos, gen_output)  # generator wants these judged 1
        gen_loss_crossentropyloss = loss_fn(disc_gen_out,
                                            torch.ones_like(disc_gen_out,
                                                            device=device))
        gen_l1_loss = torch.mean(torch.abs(gen_output - imgs))  # L1 term
        gen_loss = gen_loss_crossentropyloss + LAMBDA * gen_l1_loss
        gen_loss.backward()  # backpropagate
        g_optimizer.step()   # update

        # accumulate the per-batch losses
        with torch.no_grad():
            D_epoch_loss += disc_loss.item()
            G_epoch_loss += gen_loss.item()

    # average the loss over the epoch
    with torch.no_grad():
        D_epoch_loss /= count
        G_epoch_loss /= count
        D_loss.append(D_epoch_loss)
        G_loss.append(G_epoch_loss)
        # after each epoch, print progress and plot generated images
        print("Epoch:", epoch)
        generate_images(gen, annos_batch, imgs_batch)
```
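The loop above records the per-epoch losses in D_loss and G_loss but never visualizes them; a short follow-up snippet (illustrative) can plot the curves:

```python
plt.plot(range(1, len(D_loss) + 1), D_loss, label='D_loss')
plt.plot(range(1, len(G_loss) + 1), G_loss, label='G_loss')
plt.xlabel('Epoch')
plt.legend()
plt.show()
```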