BargainNet is a project from bcmi; see the GitHub link at the end for the full project description. I needed to use BargainNet for various reasons, but I am not fond of launching training from the command line, so I pulled out the default model and parameters it uses and boiled everything down to two simple files: one that loads the data and one that trains the model.
The training data is organized as follows (the watermark on the screenshot annoys me too, but I can't remove it):
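Roughly, the layout that HarmonyDataset.py below expects looks like this (reconstructed from the path logic in the code; the folder names com/, mask/, gt/ and the file IHD_train.txt are the ones the code relies on, the rest is illustrative):

/app/data/
    IHD_train.txt         # list of composite-image file names
    com/                  # composite images, e.g. xxxx_1.jpg (dataset_root points here)
    mask/                 # foreground masks, e.g. xxxx.png
    gt/                   # ground-truth (real) images, e.g. xxxx.png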

IHD_train.txt itself is very simple: it is just a list of file names:
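Each line is the file name of one composite image inside com/, for example (made-up names):

xxxx_1.jpg
yyyy_2.jpg
zzzz_1.jpg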

Other than that, put the data-loading code and the model code in the same folder and change the dataset path inside the data-loading code.
The data-loading file is named HarmonyDataset.py so that the model script can import it directly.
import os.path
import random
from abc import ABC
import cv2
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from albumentations import HorizontalFlip, RandomResizedCrop, Compose, DualTransform, ToGray
class HCompose(Compose):
def __init__(self, transforms, *args, additional_targets=None, no_nearest_for_masks=True, **kwargs):
if additional_targets is None:
additional_targets = {
'real': 'image',
'mask': 'mask'
}
self.additional_targets = additional_targets
super().__init__(transforms, *args, additional_targets=additional_targets, **kwargs)
if no_nearest_for_masks:
for t in transforms:
if isinstance(t, DualTransform):
t._additional_targets['mask'] = 'image'
def get_transform(params=None, no_flip=True, grayscale=False):
transform_list = []
if grayscale:
transform_list.append(ToGray())
if params is None:
transform_list.append(RandomResizedCrop(512, 512, scale=(0.5, 1.0)))
if not no_flip:
if params is None:
transform_list.append(HorizontalFlip())
return HCompose(transform_list)
class Iharmony4Dataset(data.Dataset, ABC):
def __init__(self, dataset_root,):
self.image_paths = []
print('loading training file: ')
self.keep_background_prob = 0.05
self.file = dataset_root.replace("com", "") + 'IHD_train.txt'
with open(self.file, 'r') as f:
for line in f.readlines():
self.image_paths.append(os.path.join(dataset_root, line.rstrip()))
self.transform = get_transform()
self.input_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
def __getitem__(self, index):
sample = self.get_sample(index)
self.check_sample_types(sample)
sample = self.augment_sample(sample)
comp = self.input_transform(sample['image'])
real = self.input_transform(sample['real'])
mask = sample['mask'].astype(np.float32)
mask = mask[np.newaxis, ...].astype(np.float32)
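        # the unsqueeze(0) calls below add a batch dimension of 1 on purpose: the training script iterates the dataset directly (no DataLoader), so each sample is already a batch of 1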
output = {
'comp': comp.unsqueeze(0),
'mask': torch.from_numpy(mask).unsqueeze(0),
'real': real.unsqueeze(0),
'img_path': sample['img_path']
}
return output
def check_sample_types(self, sample):
assert sample['comp'].dtype == 'uint8'
if 'real' in sample:
assert sample['real'].dtype == 'uint8'
def augment_sample(self, sample):
if self.transform is None:
return sample
additional_targets = {target_name: sample[target_name]
for target_name in self.transform.additional_targets.keys()}
valid_augmentation = False
while not valid_augmentation:
aug_output = self.transform(image=sample['comp'], **additional_targets)
valid_augmentation = self.check_augmented_sample(aug_output)
for target_name, transformed_target in aug_output.items():
sample[target_name] = transformed_target
return sample
def check_augmented_sample(self, aug_output):
if self.keep_background_prob < 0.0 or random.random() < self.keep_background_prob:
return True
return aug_output['mask'].sum() > 1.0
def get_sample(self, index):
path = self.image_paths[index]
name_parts = path.split('_')
mask_path = self.image_paths[index].replace('com', 'mask')
mask_path = mask_path.replace(('_' + name_parts[-1]), '.png')
target_path = self.image_paths[index].replace('com', 'gt')
target_path = target_path.replace(('_' + name_parts[-1]), '.png')
comp = cv2.imread(path)
comp = cv2.cvtColor(comp, cv2.COLOR_BGR2RGB)
real = cv2.imread(target_path)
real = cv2.cvtColor(real, cv2.COLOR_BGR2RGB)
mask = cv2.imread(mask_path)
mask = mask[:, :, 0].astype(np.float32) / 255.
mask = mask.astype(np.uint8)
return {'comp': comp, 'mask': mask, 'real': real, 'img_path': path}
def __len__(self):
return len(self.image_paths)
comp is the composite image, mask is the mask of the composited (foreground) region, and real is the ground truth.

The goal, naturally, is to map comp -> real.
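A minimal usage sketch for sanity-checking the dataset (dataset_root is the same path used in the training script below; the printed shapes follow from __getitem__ above):

from HarmonyDataset import Iharmony4Dataset

dataset = Iharmony4Dataset(dataset_root='/app/data/com/')
sample = dataset[0]
print(sample['comp'].shape)   # torch.Size([1, 3, 512, 512])  composite
print(sample['mask'].shape)   # torch.Size([1, 1, 512, 512])  foreground mask
print(sample['real'].shape)   # torch.Size([1, 3, 512, 512])  ground truth
print(sample['img_path'])     # path of the composite image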
The model/training file below can be named whatever you like:
import functools
import torch
import torch.nn.functional as F
import tqdm
from torch import nn
from torch.nn import init
from torch.optim import lr_scheduler
class UnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
use_attention=False):
super(UnetGenerator, self).__init__()
# construct unet structure
weight = torch.FloatTensor([0.1])
self.weight = torch.nn.Parameter(weight, requires_grad=True)
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
innermost=True) # add the innermost layer
for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer, use_dropout=use_dropout)
# gradually reduce the number of filters from ngf * 8 to ngf
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer, use_attention=use_attention)
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
norm_layer=norm_layer, use_attention=use_attention)
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer,
use_attention=use_attention)
self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
norm_layer=norm_layer) # add the outermost layer
    def forward(self, inputs):
        # channels 0-3 are the composite RGB + mask; channels 4 onward are the spatially broadcast background style code
        ori_code_map = inputs[:, 4:, :, :]
        code_map_input = ori_code_map * torch.clamp(self.weight, min=0.001)
        new_inputs = torch.cat([inputs[:, :4, :, :], code_map_input], 1)
        return self.model(new_inputs)
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc, input_nc=None,
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,
use_attention=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
if input_nc is None:
input_nc = outer_nc
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = down + up
else:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
self.use_attention = use_attention
if use_attention:
attention_conv = nn.Conv2d(outer_nc + input_nc, outer_nc + input_nc, kernel_size=1)
attention_sigmoid = nn.Sigmoid()
self.attention = nn.Sequential(*[attention_conv, attention_sigmoid])
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost:
return self.model(x)
else:
ret = torch.cat([x, self.model(x)], 1)
return self.attention(ret) * ret if self.use_attention else ret
class PartialConv2d(nn.Conv2d):
def __init__(self, *args, **kwargs):
# whether the mask is multi-channel or not
if 'multi_channel' in kwargs:
self.multi_channel = kwargs['multi_channel']
kwargs.pop('multi_channel')
else:
self.multi_channel = False
self.return_mask = True
super(PartialConv2d, self).__init__(*args, **kwargs)
if self.multi_channel:
self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0],
self.kernel_size[1])
else:
self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * \
self.weight_maskUpdater.shape[3]
self.last_size = (None, None, None, None)
self.update_mask, self.mask_ratio = None, None
def forward(self, input, mask_in=None):
assert len(input.shape) == 4
if mask_in is not None or self.last_size != tuple(input.shape):
self.last_size = tuple(input.shape)
with torch.no_grad():
if self.weight_maskUpdater.type() != input.type():
self.weight_maskUpdater = self.weight_maskUpdater.to(input)
if mask_in is None:
# if mask is not provided, create a mask
if self.multi_channel:
mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2],
input.data.shape[3]).to(input)
else:
mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
else:
mask = mask_in
self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
padding=self.padding, dilation=self.dilation, groups=1)
self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
self.update_mask = torch.clamp(self.update_mask, 0, 1)
self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)
if self.bias is not None:
bias_view = self.bias.view(1, self.out_channels, 1, 1)
output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
output = torch.mul(output, self.update_mask)
else:
output = torch.mul(raw_out, self.mask_ratio)
if self.return_mask:
return output, self.update_mask
else:
return output
class StyleEncoder(nn.Module):
def __init__(self, style_dim, norm_layer=nn.BatchNorm2d):
super(StyleEncoder, self).__init__()
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
ndf = 64
kw = 3
padw = 0
self.conv1f = PartialConv2d(3, ndf, kernel_size=kw, stride=2, padding=padw)
self.relu1 = nn.ReLU(True)
nf_mult = 1
n = 1
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
self.conv2f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,
bias=use_bias)
self.norm2f = norm_layer(ndf * nf_mult)
self.relu2 = nn.ReLU(True)
n = 2
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
self.conv3f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,
bias=use_bias)
self.norm3f = norm_layer(ndf * nf_mult)
self.relu3 = nn.ReLU(True)
n = 3
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
self.conv4f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,
bias=use_bias)
self.norm4f = norm_layer(ndf * nf_mult)
self.relu4 = nn.ReLU(True)
n = 4
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
self.conv5f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,
bias=use_bias)
self.avg_pooling = nn.AdaptiveAvgPool2d(1)
self.convs = nn.Conv2d(ndf * nf_mult, style_dim, kernel_size=1, stride=1)
def forward(self, input, mask):
"""Standard forward."""
xb = input
mb = mask
xb, mb = self.conv1f(xb, mb)
xb = self.relu1(xb)
xb, mb = self.conv2f(xb, mb)
xb = self.norm2f(xb)
xb = self.relu2(xb)
xb, mb = self.conv3f(xb, mb)
xb = self.norm3f(xb)
xb = self.relu3(xb)
xb, mb = self.conv4f(xb, mb)
xb = self.norm4f(xb)
xb = self.relu4(xb)
xb, mb = self.conv5f(xb, mb)
xb = self.avg_pooling(xb)
s = self.convs(xb)
return s
def init_weights(net, init_type='normal', init_gain=0.02):
"""Initialize network weights.
Parameters:
net (network) -- network to be initialized
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
work better for some applications. Feel free to try yourself.
"""
def init_func(m): # define the initialization function
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, init_gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=init_gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=init_gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find(
'BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
init.normal_(m.weight.data, 1.0, init_gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func) # apply the initialization function
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
Parameters:
net (network) -- the network to be initialized
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
Return an initialized network.
"""
if len(gpu_ids) > 0:
assert (torch.cuda.is_available())
net.to(gpu_ids[0])
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
init_weights(net, init_type, init_gain=init_gain)
return net
class BargainNetModel:
def __init__(self, netE, netG, style_dim=16, img_size=512, init_type='normal', init_gain=0.02, gpu_ids=[]):
self.gpu_ids = gpu_ids
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
self.lambda_tri = 0.01
self.lambda_f2b = 1.0
self.lambda_ff2 = 1.0
self.loss_names = ['L1', 'tri']
self.optimizers = []
self.lr = 0.0002
self.e_lr_ratio = 1.0
self.g_lr_ratio = 1.0
self.beta1 = 0.5
self.style_dim = style_dim
self.image_size = img_size
self.netE = init_net(netE, init_type, init_gain, self.gpu_ids)
self.netG = init_net(netG, init_type, init_gain, self.gpu_ids)
self.relu = nn.ReLU()
self.margin = 0.1
self.tripletLoss = nn.TripletMarginLoss(margin=self.margin, p=2)
self.criterionL1 = torch.nn.L1Loss()
self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=self.lr * self.e_lr_ratio,
betas=(self.beta1, 0.999))
self.optimizers.append(self.optimizer_E)
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=self.lr * self.g_lr_ratio,
betas=(self.beta1, 0.999))
self.optimizers.append(self.optimizer_G)
self.schedulers = [
lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0) for optimizer in self.optimizers
]
def set_input(self, input):
self.comp = input['comp'].to(self.device)
self.real = input['real'].to(self.device)
self.mask = input['mask'].to(self.device)
self.inputs = torch.cat([self.comp, self.mask], 1).to(self.device)
self.bg = 1.0 - self.mask
self.real_f = self.real * self.mask
def forward(self):
self.bg_sty_vector = self.netE(self.real, self.bg)
self.real_fg_sty_vector = self.netE(self.real, self.mask)
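        # note: the expand() below hard-codes batch size 1, which matches the per-sample training loop at the bottom of this file (the dataset is iterated without a DataLoader)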
self.bg_sty_map = self.bg_sty_vector.expand([1, self.style_dim, self.image_size, self.image_size])
self.inputs_c2r = torch.cat([self.inputs, self.bg_sty_map], 1)
self.harm = self.netG(self.inputs_c2r)
self.harm_fg_sty_vector = self.netE(self.harm, self.mask)
self.comp_fg_sty_vector = self.netE(self.comp, self.mask)
self.fake_f = self.harm * self.mask
def backward(self):
self.loss_L1 = self.criterionL1(self.harm, self.real)
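        # triplet 1 (lambda_ff2): anchor = real-foreground style, positive = harmonized-foreground style, negative = composite-foreground style
        # triplet 2 (lambda_f2b): anchor = harmonized-foreground style, positive = background style, negative = composite-foreground style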
self.loss_tri = (self.tripletLoss(self.real_fg_sty_vector, self.harm_fg_sty_vector,
self.comp_fg_sty_vector) * self.lambda_ff2
+ self.tripletLoss(self.harm_fg_sty_vector, self.bg_sty_vector,
self.comp_fg_sty_vector) * self.lambda_f2b) * self.lambda_tri
self.loss = self.loss_L1 + self.loss_tri
self.loss.backward(retain_graph=True)
def optimize_parameters(self):
self.forward()
self.optimizer_E.zero_grad()
self.optimizer_G.zero_grad()
self.backward()
self.optimizer_E.step()
self.optimizer_G.step()
The main change is here: I added a tqdm progress bar to the original training procedure, so you can watch l1_loss, tri_loss, and l1_loss + tri_loss evolve directly on the progress bar, which is more intuitive.
from HarmonyDataset import Iharmony4Dataset simply imports the data-loading file above; as long as the file and class names match, it works.
# The parameters follow the official default invocation; the official way to launch training is:
"""
python train.py --name --model bargainnet --dataset_mode iharmony4 --is_train 1 --norm batch --preprocess resize_and_crop --gpu_ids 0 --save_epoch_freq 1 --input_nc 20 --lr 1e-4 --beta1 0.9 --lr_policy step --lr_decay_iters 6574200 --netG s2ad
"""
G_net = UnetGenerator(20, 3, 8, 64, nn.BatchNorm2d, False, use_attention=True)
E_net = StyleEncoder(16, norm_layer=nn.BatchNorm2d)
if __name__ == "__main__":
from HarmonyDataset import Iharmony4Dataset
harmony_dataset = Iharmony4Dataset(dataset_root='/app/data/com/')
datalen = len(harmony_dataset)
model = BargainNetModel(E_net, G_net, gpu_ids=[])
EPOCH = 20
    best_loss = 0.3  # starting threshold: weights are only saved once the average epoch loss drops below the best value so far
for epoch in range(EPOCH):
tqdm_bar = tqdm.tqdm(enumerate(harmony_dataset), total=datalen, desc='Epoch {}/{}'.format(epoch + 1, EPOCH))
epoch_l1, epoch_tri = 0, 0
for i, data in tqdm_bar:
model.set_input(data) # unpack data from a dataset and apply preprocessing
model.optimize_parameters() # calculate loss functions, get gradients, update network weights
epoch_l1 += model.loss_L1.item()
epoch_tri += model.loss_tri.item()
tqdm_bar.set_postfix(L1=epoch_l1 / (i + 1), tri=epoch_tri / (i + 1),
total=(epoch_l1 + epoch_tri) / (i + 1), best_loss=best_loss)
if best_loss > (epoch_l1 + epoch_tri) / datalen: # cache our latest model every iterations
print('the best model improve loss from {0} to {1}'.format(best_loss, (epoch_l1 + epoch_tri) / datalen))
best_loss = (epoch_l1 + epoch_tri) / datalen
# model save weights
torch.save(model.netG.state_dict(), 'best_netG.pth')
torch.save(model.netE.state_dict(), 'best_netE.pth')
# update learning rates at the end of every epoch.
for scheduler in model.schedulers:
scheduler.step()
    # (optional, commented out) visualize the complete netG graph with hiddenlayer
    # x = torch.zeros(1, 20, 512, 512, dtype=torch.float, requires_grad=False)
    # import hiddenlayer as h
    # myNetGraph = h.build_graph(G_net, x)  # build the network graph
    # myNetGraph.save(path='./demoModel-G', format='pdf')  # save the network graph; png, pdf, etc. are supported
else:
G_net.load_state_dict(torch.load('/app/checkpoints/best_net_G.pth'))
E_net.load_state_dict(torch.load('/app/checkpoints/best_net_E.pth'))
    print('model weights loaded successfully')
The prediction result looks like this: I grabbed a random medicine image from Baidu, cut out its background, and pasted it onto another photo. At inference time, simply feed comp wherever real is used (there is no ground truth at test time).
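For reference, a minimal inference sketch of my own (not from the official repo). It assumes G_net / E_net are the modules defined above and that the weights are the best_netG.pth / best_netE.pth files saved by the training loop with gpu_ids=[] (so there is no DataParallel prefix in the state-dict keys); the image file names are placeholders.

import cv2
import numpy as np
import torch

G_net.load_state_dict(torch.load('best_netG.pth', map_location='cpu'))
E_net.load_state_dict(torch.load('best_netE.pth', map_location='cpu'))
G_net.eval()
E_net.eval()

def harmonize(comp_path, mask_path, size=512):
    # plain resize to 512 plus the same (0.5, 0.5) normalization used during training
    comp = cv2.cvtColor(cv2.imread(comp_path), cv2.COLOR_BGR2RGB)
    comp = cv2.resize(comp, (size, size)).astype(np.float32) / 255.
    comp = torch.from_numpy(comp).permute(2, 0, 1).unsqueeze(0) * 2 - 1
    mask = cv2.imread(mask_path)[:, :, 0]
    mask = cv2.resize(mask, (size, size)).astype(np.float32) / 255.
    mask = torch.from_numpy(mask)[None, None]                      # 1 x 1 x H x W
    with torch.no_grad():
        bg_sty = E_net(comp, 1.0 - mask)                           # background style taken from the composite itself
        sty_map = bg_sty.expand([1, 16, size, size])
        harm = G_net(torch.cat([comp, mask, sty_map], 1))          # 3 + 1 + 16 = 20 input channels
    out = ((harm[0].permute(1, 2, 0).numpy() + 1) / 2 * 255).clip(0, 255).astype(np.uint8)
    return cv2.cvtColor(out, cv2.COLOR_RGB2BGR)

cv2.imwrite('harmonized.jpg', harmonize('comp.jpg', 'comp_mask.png'))  # placeholder file names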

For full details, please see the paper's official GitHub implementation: https://github.com/bcmi/BargainNet-Image-Harmonization
That should be all.
If you have questions about the model itself, read the paper or ask the authors; I am only porting the code.
Done.