This is the classic paper xxx; I won't go over the theory here. Given a content image and a style image, the style of the style image can be transferred onto the content image.
One thing differs from the previous code: the variable being optimized here is the image itself, and VGG is only used to extract features, so it needs no backpropagation updates of its own. Concretely, all three images (the content image, the style image, and the target image we want to generate) are passed through VGG, and features are taken from a few selected intermediate layers. From these features we compute a content loss, which keeps the target image's content close to the content image, and a style loss, which pulls the target image's style toward the style image.
To speed up convergence, the target is usually initialized with the content image.
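For reference, the two losses the code below implements can be written out as follows (my own notation, following the standard Gatys et al. formulation rather than anything quoted from the original post). For each selected VGG layer $l$, let $F^l$, $P^l$, $S^l$ be the feature maps of the target, content and style images, reshaped to $C_l \times H_l W_l$:

$$L_{\mathrm{content}} = \sum_{l} \mathrm{mean}\big( (F^l - P^l)^2 \big)$$

$$G(A) = A A^{\top} \quad \text{(the Gram matrix)}$$

$$L_{\mathrm{style}} = \sum_{l} \frac{\mathrm{mean}\big( (G(F^l) - G(S^l))^2 \big)}{C_l H_l W_l}$$

$$L = L_{\mathrm{content}} + w_{\mathrm{style}} \cdot L_{\mathrm{style}}$$

Here $w_{\mathrm{style}}$ corresponds to config.style_weight in the code, and $L$ is what gets backpropagated into the target image.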
Without further ado, here is the code:
import os
os.chdir(os.path.dirname(__file__))
from torchvision import models
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
sample_dir = 'samples_style_transfer'
if not os.path.exists(sample_dir):
os.makedirs(sample_dir, exist_ok=True)
writer = SummaryWriter(sample_dir)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_image(image_path, transform=None, max_size=None, shape=None):
image = Image.open(image_path)
if max_size:
scale = max_size / max(image.size)
size = np.array(image.size) * scale
        image = image.resize(size.astype(int), Image.LANCZOS)  # ANTIALIAS was removed in newer Pillow; LANCZOS is the same filter
if shape:
image = image.resize(shape, Image.LANCZOS)
if transform:
image = transform(image).unsqueeze(0)
return image.to(device)
class VGGNet(nn.Module):
def __init__(self):
super(VGGNet, self).__init__()
        # layers 0, 5, 10, 19, 28 of vgg19.features are conv1_1, conv2_1, conv3_1, conv4_1, conv5_1
        self.select = ['0', '5', '10', '19', '28']
self.vgg = models.vgg19(pretrained=True).features
def forward(self, x):
features = []
for name, layer in self.vgg._modules.items():
x = layer(x)
if name in self.select:
features.append(x)
return features
def main(config):
T = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225))
])
content = load_image(config.content, T, max_size=config.max_size)
    style = load_image(config.style, T, shape=[content.size(3), content.size(2)])  # PIL resize expects (width, height)
target = content.clone().requires_grad_(True)
optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])
vgg = VGGNet().to(device).eval()
for epoch in range(config.total_step):
target_feature = vgg(target)
content_feature = vgg(content)
style_feature = vgg(style)
style_loss = 0
content_loss = 0
for f1, f2, f3 in zip(target_feature, content_feature, style_feature):
content_loss += torch.mean((f1-f2)**2)
_,c,h,w = f1.size()
f1 = f1.view(c, h*w)
f3 = f3.view(c, h*w)
# gram matrix
f1 = torch.mm(f1, f1.t())
f3 = torch.mm(f3, f3.t())
# style loss
style_loss += torch.mean((f1-f3)**2) / (c*h*w)
loss = content_loss + config.style_weight * style_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
writer.add_scalar('loss', loss.item(), global_step=epoch)
writer.add_scalar('content_loss', content_loss.item(), global_step=epoch)
writer.add_scalar('style_loss', style_loss.item(), global_step=epoch)
if (epoch+1) % config.log_step == 0:
            print('Epoch [{}/{}], Loss: {:.4f}, Content loss: {:.4f}, Style loss: {:.4f}'
                  .format(epoch + 1, config.total_step, loss.item(), content_loss.item(), style_loss.item()))
if (epoch+1) % config.sample_step == 0:
            # approximate inverse of the ImageNet normalization applied in the transform above
            denorm = transforms.Normalize(mean=(-2.12, -2.04, -1.80), std=(4.37, 4.46, 4.44))
img = target.clone().squeeze()
img = denorm(img).clamp_(0, 1)
writer.add_image('img', img, global_step=epoch)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--content', type=str, default='data/content.png')
parser.add_argument('--style', type=str, default='data/style.png')
parser.add_argument('--max_size', type=int, default=400)
parser.add_argument('--total_step', type=int, default=2000)
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--sample_step', type=int, default=500)
parser.add_argument('--style_weight', type=float, default=100)
parser.add_argument('--lr', type=float, default=0.003)
config = parser.parse_args()
    # hard-coded overrides for this run (they take precedence over the command-line defaults)
    config.total_step = 20000
    config.sample_step = 100
print(config)
main(config)
The content image:
The style image:
After roughly 4000 iterations:
After 20000 iterations, you can see the result getting closer and closer to the target style.
But this kind of style transfer has a drawback: the optimization has to be run from scratch for every single image. Can we instead feed in a new image and get the stylized result right away?
Of course we can: that is fast neural style transfer.
The underlying idea is explained at https://blog.csdn.net/qq_33590958/article/details/96122789
Essentially, a transform network is designed specifically for style conversion, while the pretrained VGG network is kept to extract the features used in the loss; the overall loss still follows the original paper. Once this transform network is trained, a new image only needs a single forward pass to produce the stylized output, which is very fast. The drawback is that every new style image requires retraining the network (that is my understanding; corrections welcome).
The structure of the transform network is shown below. It is a fully convolutional network with roughly three stages: downsampling -> residual blocks -> upsampling.
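To make the "single forward pass" point concrete, here is a minimal inference sketch of my own. It assumes a checkpoint saved by the save_checkpoint function in the training script below; the checkpoint and image paths are placeholders, and TransformerNet is the class defined in that script. The input follows the same 0-255 convention used during training; how to map the output back to a displayable image depends on the convention you trained with (see the helper at the end of the script).

import torch
from PIL import Image
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# placeholder paths, adjust to your own checkpoint / content image
checkpoint_path = 'samples_fast_style_transfer/epoch_9.model'
content_path = 'data/fast_neural_style/content-images/amber.jpg'

# TransformerNet is defined in the training script below and assumed to be in scope
model = TransformerNet().to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# same 0-255 input convention as in train()
to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255)),
])
x = to_tensor(Image.open(content_path).convert('RGB')).unsqueeze(0).to(device)

with torch.no_grad():
    y = model(x)  # one forward pass: downsample -> 5 residual blocks -> upsample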
The code is adapted from the official PyTorch examples, as follows:
import os
# os.chdir(os.path.dirname(__file__))
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision import datasets
from torchvision import models
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
import argparse
sample_dir = 'samples_fast_style_transfer'
if not os.path.exists(sample_dir):
os.makedirs(sample_dir, exist_ok=True)
writer = SummaryWriter(sample_dir)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
np.random.seed(0)
torch.manual_seed(0)
def load_image(filename, size=None, scale=None):
img = Image.open(filename).convert('RGB')
if size is not None:
        img = img.resize((size, size), Image.LANCZOS)  # ANTIALIAS removed in newer Pillow
elif scale is not None:
size = (int(img.size[0] / scale), int(img.size[1] / scale))
        img = img.resize(size, Image.LANCZOS)
return img
def save_image(filename, data):
    img = data.clone().clamp(0, 255).numpy()
img = img.transpose(1,2,0).astype('uint8')
img = Image.fromarray(img)
img.save(filename)
def gram_matrix(y):
b,ch,h,w = y.size()
features = y.view(b,ch,h*w)
features_t = features.transpose(1,2)
gram = features.bmm(features_t)/(ch*h*w)
return gram
def normalize_batch(batch):
# normalize using imagenet mean and std
mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
batch = batch.div_(255.0)
return (batch - mean) / std
def save_checkpoint(model, epochs):
    # switch to eval mode and move to CPU so the checkpoint is device-agnostic
    model.eval().cpu()
    save_model_filename = "epoch_" + str(epochs) + ".model"
    save_model_path = os.path.join(sample_dir, save_model_filename)
    torch.save(model.state_dict(), save_model_path)
    # move the model back so training can continue if this is called mid-training
    model.to(device).train()
    print("\nDone, trained model saved at", save_model_path)
class ConvLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride):
super(ConvLayer, self).__init__()
reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
def forward(self, x):
out = self.reflection_pad(x)
out = self.conv2d(out)
return out
class ResidualBlock(torch.nn.Module):
"""ResidualBlock
introduced in: https://arxiv.org/abs/1512.03385
recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
"""
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
self.relu = torch.nn.ReLU()
def forward(self, x):
residual = x
out = self.relu(self.in1(self.conv1(x)))
out = self.in2(self.conv2(out))
out = out + residual
return out
class UpsampleConvLayer(torch.nn.Module):
"""UpsampleConvLayer
Upsamples the input and then does a convolution. This method gives better results
compared to ConvTranspose2d.
ref: http://distill.pub/2016/deconv-checkerboard/
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
super(UpsampleConvLayer, self).__init__()
self.upsample = upsample
reflection_padding = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)  # note this padding scheme; see https://blog.csdn.net/LionZYT/article/details/120181586
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
def forward(self, x):
x_in = x
if self.upsample:
x_in = torch.nn.functional.interpolate(x_in, mode='nearest', scale_factor=self.upsample)
out = self.reflection_pad(x_in)
out = self.conv2d(out)
return out
class TransformerNet(torch.nn.Module):
def __init__(self):
super(TransformerNet, self).__init__()
# Initial convolution layers
self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
# Residual layers
self.res1 = ResidualBlock(128)
self.res2 = ResidualBlock(128)
self.res3 = ResidualBlock(128)
self.res4 = ResidualBlock(128)
self.res5 = ResidualBlock(128)
# Upsampling Layers
self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
# Non-linearities
self.relu = torch.nn.ReLU()
def forward(self, X):
y = self.relu(self.in1(self.conv1(X)))
y = self.relu(self.in2(self.conv2(y)))
y = self.relu(self.in3(self.conv3(y)))
y = self.res1(y)
y = self.res2(y)
y = self.res3(y)
y = self.res4(y)
y = self.res5(y)
y = self.relu(self.in4(self.deconv1(y)))
y = self.relu(self.in5(self.deconv2(y)))
y = self.deconv3(y)
return y
from collections import namedtuple
class Vgg16(nn.Module):
    def __init__(self, requires_grad=False):
super(Vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
for param in self.parameters():
param.requires_grad_(False)
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
return out
def train():
epochs = 10
batch_size = 4
image_size = 256
learning_rate = 1e-3
style_weight = 1e6
content_weight = 1e1
log_step = 10
dataset_path = 'data/fast_neural_style/dataset'
style_image = 'data/fast_neural_style/style-images/mosaic.jpg'
content_image = 'data/fast_neural_style/content-images/amber.jpg'
T = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Lambda(lambda x : x.mul(255))
])
train_dataset = datasets.ImageFolder(dataset_path, T)
train_loader = DataLoader(train_dataset, batch_size=batch_size, drop_last=True, shuffle=True)
transformer = TransformerNet().to(device)
optimizer = torch.optim.Adam(transformer.parameters(), lr=learning_rate)
mse_loss = torch.nn.MSELoss()
    vgg = Vgg16(requires_grad=False).to(device)
# load style image
style_T = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x :x.mul(255))
])
style = load_image(style_image, size = image_size)
style = style_T(style)
style = style.repeat(batch_size, 1,1,1).to(device)
features_style = vgg(normalize_batch(style))
gram_style = [gram_matrix(y) for y in features_style]
cnt = 0
# load content image
content_image = load_image(content_image)
content_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x.mul(255))
])
content_image = content_transform(content_image)
content_image = content_image.unsqueeze(0).to(device)
for epoch in range(epochs):
transformer.train()
for batchid, (x, _) in enumerate(train_loader):
x = x.to(device)
y = transformer(x)
            # normalize so the inputs match the mean/std distribution the pretrained VGG expects
x = normalize_batch(x)
y = normalize_batch(y.mul(255))
features_x = vgg(x)
features_y = vgg(y)
content_loss = mse_loss(features_y.relu2_2, features_x.relu2_2) * content_weight
style_loss = 0
for ft_y, gm_s in zip(features_y, gram_style):
gm_y = gram_matrix(ft_y)
style_loss += mse_loss(gm_y, gm_s)
style_loss = style_loss * style_weight
total_loss = content_loss + style_loss
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
cnt += 1
if cnt % log_step == 0:
print('Epoch [{}/{}], Step [{}], Loss: {:.4f}, Content loss: {:.4f}, Style loss: {:.4f}'.\
format(epoch, epochs, cnt, total_loss.item(), content_loss.item(), style_loss.item()))
writer.add_scalar('loss', total_loss.item(), global_step=cnt)
writer.add_scalar('content_loss', content_loss.item(), global_step=cnt)
writer.add_scalar('style_loss', style_loss.item(), global_step=cnt)
if cnt % 100 == 0:
                img = stylize(content_image, transformer)
writer.add_image('target_images', img, global_step=cnt, dataformats='CHW')
save_checkpoint(transformer, epoch)
def stylize(content_image, transformer):
    # run a single forward pass in eval mode, without tracking gradients
    transformer.eval()
    with torch.no_grad():
        output_image = transformer(content_image).cpu()
    transformer.train()
    # map the output back to [0, 1] by undoing the ImageNet normalization, so it can be logged to tensorboard
    mean = output_image.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = output_image.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    output_image = output_image * std + mean
    return output_image[0].clamp(0, 1)
if __name__ == '__main__':
train()
This part of the code is worth a closer look. Upsampling here is not done with ConvTranspose2d; instead, the input is first upsampled by nearest-neighbor interpolation, then reflection-padded, and finally passed through an ordinary convolution. According to the comment (and the distill.pub article it references), this gives better results than ConvTranspose2d. A quick shape check follows the snippet below.
class UpsampleConvLayer(torch.nn.Module):
"""UpsampleConvLayer
Upsamples the input and then does a convolution. This method gives better results
compared to ConvTranspose2d.
ref: http://distill.pub/2016/deconv-checkerboard/
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
super(UpsampleConvLayer, self).__init__()
self.upsample = upsample
reflection_padding = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)  # note this padding scheme; see https://blog.csdn.net/LionZYT/article/details/120181586
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
def forward(self, x):
x_in = x
if self.upsample:
x_in = torch.nn.functional.interpolate(x_in, mode='nearest', scale_factor=self.upsample)
out = self.reflection_pad(x_in)
out = self.conv2d(out)
return out
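As a quick sanity check of the shapes (my own snippet, assuming the UpsampleConvLayer class above is in scope): with upsample=2 the nearest-neighbor interpolation doubles the spatial size, and the reflection padding of kernel_size // 2 compensates for the 3x3 convolution, so only the channel count changes on top of the 2x upsampling.

import torch

layer = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
x = torch.randn(1, 128, 32, 32)
y = layer(x)
print(y.shape)  # expected: torch.Size([1, 64, 64, 64])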
For ReflectionPad, see https://blog.csdn.net/LionZYT/article/details/120181586
What it does: it pads the input on all four sides by mirroring it about its outermost pixels.
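A tiny example of my own to illustrate (the expected output below is worked out by hand):

import torch
import torch.nn as nn

x = torch.arange(1., 10.).view(1, 1, 3, 3)  # [[1,2,3],[4,5,6],[7,8,9]]
pad = nn.ReflectionPad2d(1)
print(pad(x).squeeze())
# the outermost pixels act as the mirror axis and are not duplicated:
# tensor([[5., 4., 5., 6., 5.],
#         [2., 1., 2., 3., 2.],
#         [5., 4., 5., 6., 5.],
#         [8., 7., 8., 9., 8.],
#         [5., 4., 5., 6., 5.]])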
The results of the fast style transfer look like this:
Personally I don't think the style match is all that close, but it is getting there, and the output does become clearer and clearer over the course of training.
For comparison, here is the output from the official example:
In this code the default style_weight and content_weight are extremely large, and I'm not sure why; I adjusted them somewhat, which may have affected the final training result, so feel free to tune them yourself.