[Generative Networks] Primer (Part 4): CycleGAN Code and Results
Founder
2024-03-04 05:10:34

CycleGAN is a landmark work that pioneered unpaired style transfer; the zebra-to-horse results are still striking.
For the underlying theory, see https://zhuanlan.zhihu.com/p/402819206

As usual, code first, then explanations that follow the code.
The code is adapted from https://github.com/aitorzip/PyTorch-CycleGAN, which is fairly concise; I made a few small modifications.

import os
# os.chdir(os.path.dirname(__file__))
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision import datasets
from torchvision import models
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
import argparse
from glob import glob
import random
import itertools

# from https://github.com/aitorzip/PyTorch-CycleGAN
sample_dir = 'samples_cycle_gan'
os.makedirs(sample_dir, exist_ok=True)
writer = SummaryWriter(sample_dir)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
np.random.seed(0)
torch.manual_seed(0)


class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms=None, unaligned=False, mode='train'):
        self.transforms = transforms
        self.unaligned = unaligned
        self.files_A = sorted(glob(os.path.join(root, mode, 'A', '*.*')))
        self.files_B = sorted(glob(os.path.join(root, mode, 'B', '*.*')))

    def __getitem__(self, idx):
        img = Image.open(self.files_A[idx % len(self.files_A)]).convert('RGB')
        itemA = self.transforms(img)
        if self.unaligned:
            # pick a random B image so A/B pairs are not aligned
            rand_idx = random.randint(0, len(self.files_B) - 1)
            img = Image.open(self.files_B[rand_idx]).convert('RGB')
            itemB = self.transforms(img)
        else:
            img = Image.open(self.files_B[idx % len(self.files_B)]).convert('RGB')
            itemB = self.transforms(img)
        return {'A': itemA, 'B': itemB}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))


class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features))

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_res_blocks=9):
        super(Generator, self).__init__()
        # initial conv block
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, 64, 7),
                 nn.InstanceNorm2d(64),
                 nn.ReLU(inplace=True)]
        # downsampling
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features * 2
        # residual blocks
        for _ in range(n_res_blocks):
            model += [ResidualBlock(in_features)]
        # upsampling
        out_features = in_features // 2
        for _ in range(2):
            model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2,
                                         padding=1, output_padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2
        # output layer
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(64, output_nc, 7),
                  nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()
        # a stack of strided convolutions, one after another
        self.model = nn.Sequential(
            nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, padding=1))

    def forward(self, x):
        x = self.model(x)
        # average-pool the patch scores and flatten to (batch, 1)
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)


class ReplayBuffer():
    def __init__(self, max_size=50):
        assert max_size > 0, 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                # with probability 0.5, swap in a stored fake from an earlier batch
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return torch.cat(to_return)


class LambdaLR():
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # constant lr until decay_start_epoch, then linear decay to 0
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


def denorm(x):
    # map [-1, 1] back to [0, 1] for visualization
    out = (x + 1) / 2
    return out.clamp(0, 1)


# Hyperparameters
input_nc = 3
output_nc = 3
learning_rate = 0.0002
n_epochs = 200
decay_epoch = 100
start_epoch = 0
batch_size = 16
input_size = 256
dataroot = 'data/cycle_gan/datasets/horse2zebra'

# Networks
netG_A2B = Generator(input_nc, output_nc).to(device)
netG_B2A = Generator(output_nc, input_nc).to(device)
netD_A = Discriminator(input_nc).to(device)
netD_B = Discriminator(output_nc).to(device)

netG_A2B.apply(weights_init_normal)
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Optimizers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# LR schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(n_epochs, start_epoch, decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(n_epochs, start_epoch, decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(n_epochs, start_epoch, decay_epoch).step)

# Inputs & targets memory allocation
# shaped (batch_size, 1) to match the discriminator output
target_real = torch.ones(batch_size, 1, requires_grad=False).to(device)
target_fake = torch.zeros(batch_size, 1, requires_grad=False).to(device)

# Dataset loader
transforms_data = transforms.Compose([
    transforms.Resize(int(input_size * 1.12), Image.BICUBIC),
    transforms.RandomCrop(input_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset = ImageDataset(dataroot, transforms=transforms_data, unaligned=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=16, drop_last=True)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

###### Training ######
cnt = 0
log_step = 10
for epoch in range(start_epoch, n_epochs):
    for i, batch in enumerate(dataloader):
        # set model input
        real_A = batch['A'].to(device)
        real_B = batch['B'].to(device)

        ###### Generators A2B and B2A ######
        optimizer_G.zero_grad()

        ### Identity loss
        # G_A2B(B) should equal B if real B is fed
        same_B = netG_A2B(real_B)
        loss_identity_B = criterion_identity(same_B, real_B) * 5.0
        # G_B2A(A) should equal A if real A is fed
        same_A = netG_B2A(real_A)
        loss_identity_A = criterion_identity(same_A, real_A) * 5.0

        ### GAN loss
        fake_B = netG_A2B(real_A)
        pred_fake = netD_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

        fake_A = netG_B2A(real_B)
        pred_fake = netD_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

        ### Cycle loss
        recovered_A = netG_B2A(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

        recovered_B = netG_A2B(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

        # Total loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()
        optimizer_G.step()

        ###### Discriminator A ######
        optimizer_D_A.zero_grad()

        # real loss
        pred_real = netD_A(real_A)
        loss_D_real = criterion_GAN(pred_real, target_real)
        # fake loss (fakes sampled from the replay buffer, already detached)
        fake_A = fake_A_buffer.push_and_pop(fake_A)
        pred_fake = netD_A(fake_A)
        loss_D_fake = criterion_GAN(pred_fake, target_fake)
        # total loss
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        ###### Discriminator B ######
        optimizer_D_B.zero_grad()

        # real loss
        pred_real = netD_B(real_B)
        loss_D_real = criterion_GAN(pred_real, target_real)
        # fake loss (fakes sampled from the replay buffer, already detached)
        fake_B = fake_B_buffer.push_and_pop(fake_B)
        pred_fake = netD_B(fake_B)
        loss_D_fake = criterion_GAN(pred_fake, target_fake)
        # total loss
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()

        cnt += 1
        if cnt % log_step == 0:
            print('Epoch [{}/{}], Step [{}], LossG: {:.4f}, loss_D_A: {:.4f}, loss_D_B: {:.4f}'.format(
                epoch, n_epochs, cnt, loss_G.item(), loss_D_A.item(), loss_D_B.item()))
            writer.add_scalar('LossG', loss_G.item(), global_step=cnt)
            writer.add_scalar('loss_D_A', loss_D_A.item(), global_step=cnt)
            writer.add_scalar('loss_D_B', loss_D_B.item(), global_step=cnt)

        if cnt % 100 == 0:
            writer.add_images('real_A', denorm(real_A), global_step=cnt)
            writer.add_images('fake_A', denorm(fake_A), global_step=cnt)
            writer.add_images('recovered_A', denorm(recovered_A), global_step=cnt)
            writer.add_images('real_B', denorm(real_B), global_step=cnt)
            writer.add_images('fake_B', denorm(fake_B), global_step=cnt)
            writer.add_images('recovered_B', denorm(recovered_B), global_step=cnt)

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    # Save model checkpoints
    torch.save(netG_A2B.state_dict(), sample_dir + '/netG_A2B.pth')
    torch.save(netG_B2A.state_dict(), sample_dir + '/netG_B2A.pth')
    torch.save(netD_A.state_dict(), sample_dir + '/netD_A.pth')
    torch.save(netD_B.state_dict(), sample_dir + '/netD_B.pth')
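Losses and sample images are written to TensorBoard through the SummaryWriter created at the top, so while training runs you can watch progress with the standard TensorBoard CLI:

tensorboard --logdir samples_cycle_gan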

Let's walk through the code. First, each dataset sample contains two images, one from domain A and one from domain B, called real_A and real_B.
Two generators are defined, netG_A2B and netG_B2A, along with two discriminators, netD_A and netD_B.
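Since ImageDataset globs root/mode/'A' and root/mode/'B', the horse2zebra data is expected to be laid out as follows (a sketch inferred from the glob patterns in the code; the official horse2zebra download uses trainA/trainB folders, so they would need renaming to match):

data/cycle_gan/datasets/horse2zebra/
├── train/
│   ├── A/   # horse photos (*.jpg)
│   └── B/   # zebra photos (*.jpg)
└── test/
    ├── A/
    └── B/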

Let's look at the Generators section first:

###### Generators A2B and B2A ######
optimizer_G.zero_grad()

### Identity loss
# G_A2B(B) should equal B if real B is fed
same_B = netG_A2B(real_B)
loss_identity_B = criterion_identity(same_B, real_B) * 5.0
# G_B2A(A) should equal A if real A is fed
same_A = netG_B2A(real_A)
loss_identity_A = criterion_identity(same_A, real_A) * 5.0

### GAN loss
fake_B = netG_A2B(real_A)
pred_fake = netD_B(fake_B)
loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

fake_A = netG_B2A(real_B)
pred_fake = netD_A(fake_A)
loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

### Cycle loss
recovered_A = netG_B2A(fake_B)
loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

recovered_B = netG_A2B(fake_A)
loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

# Total loss
loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
loss_G.backward()
optimizer_G.step()

The generator objective consists of three losses:

  • Identity loss. netG_A2B converts A-style images to B-style, so if we feed it an image that is already B-style, it should come out essentially unchanged. This constraint is the identity loss; the same constraint applies to netG_B2A.
  • GAN loss. The usual generator adversarial loss: each generator should produce fakes that its discriminator misjudges as having the "real" label.
  • Cycle loss. Feeding A through netG_A2B to get a B-style image, then pushing that through netG_B2A, should in principle recover the original A-style image. This constraint is the cycle-consistency loss, and the same holds in the B→A→B direction. (The combined objective is written out below.)
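Putting the three parts together with the weights used in this code (10.0 on the cycle terms, 5.0 on the identity terms), the total generator objective is:

$$
\mathcal{L}_G = \mathcal{L}_{\mathrm{GAN}}^{A2B} + \mathcal{L}_{\mathrm{GAN}}^{B2A} + 10\left(\mathcal{L}_{\mathrm{cyc}}^{ABA} + \mathcal{L}_{\mathrm{cyc}}^{BAB}\right) + 5\left(\mathcal{L}_{\mathrm{id}}^{A} + \mathcal{L}_{\mathrm{id}}^{B}\right)
$$

This matches the CycleGAN paper's defaults of $\lambda = 10$ for cycle consistency and $0.5\lambda$ for the identity term.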

Next, the Discriminator A section (Discriminator B is symmetric).
This is the standard GAN discriminator loss: real images should be classified as real, fakes as fake.

optimizer_D_A.zero_grad()

# real loss
pred_real = netD_A(real_A)
loss_D_real = criterion_GAN(pred_real, target_real)
# fake loss (fakes sampled from the replay buffer)
fake_A = fake_A_buffer.push_and_pop(fake_A)
pred_fake = netD_A(fake_A)
loss_D_fake = criterion_GAN(pred_fake, target_fake)
# total loss
loss_D_A = (loss_D_real + loss_D_fake) * 0.5
loss_D_A.backward()
optimizer_D_A.step()
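Note that criterion_GAN is an MSELoss rather than binary cross-entropy, so this is the least-squares GAN (LSGAN) objective that the CycleGAN paper adopts. For discriminator A, the loss being minimized is:

$$
\mathcal{L}_{D_A} = \tfrac{1}{2}\,\mathbb{E}_{a \sim A}\big[(D_A(a)-1)^2\big] + \tfrac{1}{2}\,\mathbb{E}_{b \sim B}\big[D_A(G_{B2A}(b))^2\big]
$$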

Now look back at the Generator architecture.
The overall structure is very similar to fast style transfer: downsample first, then residual blocks, then upsample, also using ReflectionPad2d.
Note that the code uses nn.InstanceNorm2d rather than batch norm.
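As a quick sanity check of the architecture (a sketch that assumes the Generator class defined above is in scope), the generator maps a 256×256 image back to the same resolution:

import torch

# assumes the Generator class from the code above is in scope
G = Generator(3, 3)
x = torch.randn(1, 3, 256, 256)   # one random 256x256 RGB image
print(G(x).shape)                 # torch.Size([1, 3, 256, 256])
# internally: 256 -> 128 -> 64 via two stride-2 convs, nine residual
# blocks at 64x64, then 64 -> 128 -> 256 via two transposed convs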

There isn't much to say about the Discriminator: a few conv layers reduce the input to a batch_size × 1 × h × w tensor, and a final avg_pool2d produces the batch_size × 1 classification output; no fully connected layer is used.
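To make the "no fully connected layer" point concrete (again a sketch assuming the Discriminator class above is in scope): before pooling, the output is a 30×30 grid of patch scores, i.e. a PatchGAN-style discriminator, and the average pool collapses it to one score per image:

import torch

D = Discriminator(3)
x = torch.randn(1, 3, 256, 256)
print(D.model(x).shape)   # torch.Size([1, 1, 30, 30]) -- per-patch real/fake scores
print(D(x).shape)         # torch.Size([1, 1])         -- averaged to one score per image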

One more thing worth pointing out is the ReplayBuffer mechanism. My understanding is that during the discriminator steps, fake_A and fake_B are pushed into a buffer, and with 50% probability a fake stored earlier is returned in their place, so the discriminator is also scored on fakes from other batches. My guess is that this keeps the discriminator from concentrating its capacity on telling apart the current paired samples, exposing it instead to a wider pool of positive and negative examples. (This is the 50-image history of generated images described in the CycleGAN paper.)

A side effect of this mechanism is that the real and fake images printed during my training runs are not one-to-one, which makes it awkward to eyeball results. It's easy to fix (a sketch follows), but I was lazy.
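A minimal sketch of the fix, using the variable names from the training loop above: snapshot the fakes before the buffer gets a chance to swap them for older ones, and log the snapshots.

# keep copies that are still paired with this batch's real images
fake_A_vis = fake_A.detach()
fake_B_vis = fake_B.detach()

# the buffer may return fakes from earlier batches
fake_A = fake_A_buffer.push_and_pop(fake_A)
fake_B = fake_B_buffer.push_and_pop(fake_B)

# ...then log the paired snapshots instead:
# writer.add_images('fake_A', denorm(fake_A_vis), global_step=cnt)
# writer.add_images('fake_B', denorm(fake_B_vis), global_step=cnt)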

Now let's look at the results. Domain A is regular horses, domain B is zebras.
After conversion, here is the zebra-to-horse output:
[figure: zebra-to-horse results]

And here is the horse-to-zebra output:

[figure: horse-to-zebra results]
Not particularly good, and far worse than the paper's results; there is clearly plenty left to tune. If you want paper-quality results, try the official code at https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
