资讯专栏INFORMATION COLUMN

SinGAN一张照片即可生成同样的照片(附简化版代码)

you_De / 1394人阅读

摘要:摘要本文主要讲解一张照片即可生成一模一样的照片附简化版代码主要思路先由一个输入到的生成器得到生成图像这一步是单纯由噪声生成,其他生成器的输入都是由随机噪声图像和上一层生成的上采样到当前生成器尺寸组成。

1、摘要

本文主要讲解:SinGAN-一张照片即可生成一模一样的照片(附简化版代码)
主要思路:

  1. 先由一个Z_N输入到G_N的生成器得到生成图像(这一步是单纯由噪声生成,其他生成器的输入都是由随机噪声图像z_n和上一层生成的 上采样到当前生成器尺寸组成)。
  2. 接着利用生成图像的图像块(每一层图像块的大小不一样,按照由粗糙到精细、由大到小)和当前层的图像块(由训练数据下采样得到)放入判别器中进行判断,直到两者不能被判别器区分。
  3. 通过这种一层一层、由下往上的训练过程,得到最终的结果。

2、相关技术

SinGAN架构
一种基于层级的patch-GAN模型(Markovian discriminator)。如下图所示,模型的每个部分负责输入图像的不同尺度捕获图像块分布。这种层级GAN模型感受野小和有限的功能,可以防止网络记住整图的信息。虽然类似的网络结构被应用过,但这是首次应用在一张图像的内部学习上。

模型是由金字塔形式大小的生成器 组成,训练数据 也是金字塔形式大小组成,训练数据是由一个 因子控制,一些r>0。根据每层 的图像块分布,相应层的生成器 产生真实的图像实例。然后通过对抗学习,判别器 通过对生成器 产生的图像块(生成图像的某一部分)进行判别,达到相对较好的状态(以目前来说达不到最终的纳什均衡点),最后完成训练过程。

从刚刚的图中我们可以看到,每个尺度注入噪声后,先由粗糙的尺度开始生成图像,然后按照相应的顺序传递到相对应的生成器,最终生成精细的尺度;某一层的所有生成器和判别器有着相同的感受野,随着由下往上的生成过程,因此可以捕获尺度减小的结构信息。

3、完整代码和步骤

算法训练的效果如此视频:

SinGAN训练过程

主运行程序入口

import osfrom SinGAN.run_train import functionsfrom SinGAN.run_train.manipulate import SinGAN_generatefrom SinGAN.run_train.training import trainfrom SinGAN.run_train.config import get_argumentsif __name__ == "__main__":    parser = get_arguments()    parser.add_argument("--input_dir", help="input image dir", default="../Input/Images")    parser.add_argument("--input_name", help="input image name", default="food.jpg")    parser.add_argument("--mode", help="task to be done", default="train")    opt = parser.parse_args()    #    opt = functions.post_config(opt)    Gs = []    Zs = []    reals = []    NoiseAmp = []    dir2save = functions.generate_dir2save(opt)    if (os.path.exists(dir2save)):        print("trained model already exist")    else:        try:            os.makedirs(dir2save)        except OSError:            pass        # 将图片读取成torch版的数据        real = functions.read_image(opt)        # 将图片适配尺寸        functions.adjust_scales2image(real, opt)        # 开始训练模型 opt 手动输入的参数        train(opt, Gs, Zs, reals, NoiseAmp)        # 根据模型生成图片  生成具有任意大小和比例的新图像        SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)

training.py

	import osimport torch.nn as nnimport torch.optim as optimimport torch.utils.dataimport mathimport matplotlib.pyplot as pltfrom SinGAN.run_train import functions, modelsfrom SinGAN.run_train.imresize import imresizedevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")def train(opt, Gs, Zs, reals, NoiseAmp):    real_ = functions.read_image(opt)    in_s = 0    scale_num = 0    # 计算局部权重 调整大小    real = imresize(real_, opt.scale1, opt)    # 创造真实图片的锥体    reals = functions.creat_reals_pyramid(real, reals, opt)    nfc_prev = 0    # 全卷积GANs组成的金字塔    while scale_num < opt.stop_scale + 1:        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)        opt.out_ = functions.generate_dir2save(opt)        opt.outf = "%s/%d" % (opt.out_, scale_num)        try:            os.makedirs(opt.outf)        except OSError:            pass        plt.imsave("%s/in.png" % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)        plt.imsave("%s/original.png" % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)        plt.imsave("%s/real_scale.png" % (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)        D_curr, G_curr = init_models(opt)        if (nfc_prev == opt.nfc):            # 加载训练好的模型            G_curr.load_state_dict(torch.load("%s/%d/netG.pth" % (opt.out_, scale_num - 1)))            D_curr.load_state_dict(torch.load("%s/%d/netD.pth" % (opt.out_, scale_num - 1)))        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)        # 是否固定部分参数进行网络训练        G_curr = functions.reset_grads(G_curr, False)        G_curr.eval()        D_curr = functions.reset_grads(D_curr, False)        D_curr.eval()        Gs.append(G_curr)        Zs.append(z_curr)        NoiseAmp.append(opt.noise_amp)        torch.save(Zs, "%s/Zs.pth" % (opt.out_))        torch.save(Gs, "%s/Gs.pth" % (opt.out_))        torch.save(reals, "%s/reals.pth" % (opt.out_))        torch.save(NoiseAmp, "%s/NoiseAmp.pth" % (opt.out_))        scale_num += 1        nfc_prev = opt.nfc        del D_curr, G_curr    returndef train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):    real = reals[len(Gs)]    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)    if opt.mode == "animation_train":        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)        pad_noise = 0    #     对Tensor使用0进行边界填充    m_noise = nn.ZeroPad2d(int(pad_noise))    m_image = nn.ZeroPad2d(int(pad_image))    alpha = opt.alpha    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=device)    # 返回一个大小为fill_value的张量    z_opt = torch.full(fixed_noise.shape, 0, device=device)    z_opt = m_noise(z_opt)    # setup optimizer    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))    # 按需调整学习率    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)    errD2plot = []    errG2plot = []    D_real2plot = []    D_fake2plot = []    z_opt2plot = []    # 它是从噪声生成图像的    for epoch in range(opt.niter):        if (Gs == []) & (opt.mode != "SR_train"):            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=device)            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=device)            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))        else:            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=device)            noise_ = m_noise(noise_)        ############################        # (1) Update D network: maximize D(x) + D(G(z))        ###########################        # Dsteps "Discriminator inner steps",default=3        for j in range(opt.Dsteps):            # train with real            netD.zero_grad()            output = netD(real).to(device)            # D_real_map = output.detach()            errD_real = -output.mean()  # -a            errD_real.backward(retain_graph=True)            D_x = -errD_real.item()            # train with fake            if (j == 0) & (epoch == 0):                if (Gs == []) & (opt.mode != "SR_train"):                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=device)                    in_s = prev                    prev = m_image(prev)                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=device)                    z_prev = m_noise(z_prev)                    opt.noise_amp = 1                elif opt.mode == "SR_train":                    z_prev = in_s                    criterion = nn.MSELoss()                    RMSE = torch.sqrt(criterion(real, z_prev))                    opt.noise_amp = opt.noise_amp_init * RMSE                    z_prev = m_image(z_prev)                    prev = z_prev                else:                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, "rand", m_noise, m_image, opt)                    prev = m_image(prev)                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, "rec", m_noise, m_image, opt)                    criterion = nn.MSELoss()                    RMSE = torch.sqrt(criterion(real, z_prev))                    opt.noise_amp = opt.noise_amp_init * RMSE                    z_prev = m_image(z_prev)            else:                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, "rand", m_noise, m_image, opt)                prev = m_image(prev)            if opt.mode == "paint_train":                prev = functions.quant2centers(prev, centers)                plt.imsave("%s/prev.png" % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)            if (Gs == []) & (opt.mode != "SR_train"):                noise = noise_            else:                noise = opt.noise_amp * noise_ + prev            fake = netG(noise.detach(), prev)            output = netD(fake.detach())            errD_fake = output.mean()            errD_fake.backward(retain_graph=True)            D_G_z = output.mean().item()            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, device)            gradient_penalty.backward()            errD = errD_real + errD_fake + gradient_penalty            optimizerD.step()        errD2plot.append(errD.detach())        ############################        # (2) Update G network: 最大化 D(G(z))        ###########################        for j in range(opt.Gsteps):            netG.zero_grad()            output = netD(fake)            D_fake_map = output.detach()            errG = -output.mean()            # errG.backward(retain_graph=True)            if alpha != 0:                loss = nn.MSELoss()                if opt.mode == "paint_train":                    z_prev = functions.quant2centers(z_prev, centers)                    plt.imsave("%s/z_prev.png" % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)                Z_opt = opt.noise_amp * z_opt + z_prev                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)                rec_loss.backward(retain_graph=True)                rec_loss = rec_loss.detach()            else:                Z_opt = z_opt                rec_loss = 0            optimizerG.step()        errG2plot.append(errG.detach() + rec_loss)        D_real2plot.append(D_x)        D_fake2plot.append(D_G_z)        z_opt2plot.append(rec_loss)        if epoch % 25 == 0 or epoch == (opt.niter - 1):            print("scale %d:[%d/%d]" % (len(Gs), epoch, opt.niter))        if epoch % 500 == 0 or epoch == (opt.niter - 1):            plt.imsave("%s/fake_sample.png" % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)            plt.imsave("%s/G(z_opt).png" % (opt.outf),                       functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)            # plt.imsave("%s/D_fake.png"   % (opt.outf), functions.convert_image_np(D_fake_map))            # plt.imsave("%s/D_real.png"   % (opt.outf), functions.convert_image_np(D_real_map))            # plt.imsave("%s/z_opt.png"    % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)            # plt.imsave("%s/prev.png"     %  (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)            # plt.imsave("%s/noise.png"    %  (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)            # plt.imsave("%s/z_prev.png"   % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)            torch.save(z_opt, "%s/z_opt.pth" % (opt.outf))        schedulerD.step()        schedulerG.step()    functions.save_networks(netG, netD, z_opt, opt)    return z_opt, in_s, netGdef draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt):    G_z = in_s    if len(Gs) > 0:        if mode 
            
                     
             
               

文章版权归作者所有,未经允许请勿转载,若此文章存在违规行为,您可以联系管理员删除。

转载请注明本文地址:https://www.ucloud.cn/yun/122151.html

相关文章

  • AI新时代-大牛教你使用python+Opencv完成人脸解锁(源码)

    摘要:创建人脸集合,并加入里面填的就是刚才奥巴马图片返回的,不要填错哦面部集合的名字也要记住,待会查询的时候也会用到的。判断是否为管理员我们将奥巴马的面部值放入的里面,将要去查询的面部集合相册设置为我们前面获取的,然后我们随便照一张照片匹配一下。 好吧,伙计们,我回来了。说我拖更不写文章的可以过来用你的小拳拳狠命地捶我胸口.... 那么今天我们来讲关于使用python+opencv+face...

    LinkedME2016 评论0 收藏0
  • ❤️Python实用工具之制作证件照(有界面、源码、赞关藏)❤️

    镇楼图 文章目录 一点想说的必要配置removebg配置安装对应的 Python 库获取API Key 无界面版修改图片背景色修改图片大小 升级版设置窗口主题与标题添加选择图片路径组件添加选择保存路径组件添加选择背景颜色组件添加填写图片尺寸组件添加填写API-KEY组件添加控制按钮添加输出框显示设置组件位置设置组件样式 源码打包❤️源码获取方式❤️ 一点想说的 想当年我不会...

    Doyle 评论0 收藏0
  • 神还原女神照片!GAN为百年旧照上色

    摘要:如何把女神的黑白照片变成彩照今日小编发现新加坡数据科学与人工智能部门在上介绍了一个为百年旧照上色的项目。照片为新加坡华人女子学校,摄于年期间。来自新加坡国家档案馆的原始照片左和上色后的照片右。利用给年的汤加太平洋岛国旧照上色。 一键点击,百年旧照变彩色。如何把女神的黑白照片变成彩照?今日小编发现新加坡 GovTech 数据科学与人工智能部门在 Medium 上介绍了一个为百年旧照上色的项目。...

    gaomysion 评论0 收藏0
  • opencv+mtcnn+facenet+python+tensorflow 实现实时人脸识别(20

    摘要:实现实时人脸识别更新新增测试方法直接使用特征进行计算对比此次更新主要想法上一个版本是使用对准备好的若干张照片进行训练,首先准确率不是很高还没细究问题,猜测原因是自己准备的图片问题,以及实时采集实时的环境影响,但最主要的原因还是对每个目标对象 opencv+mtcnn+facenet+python+tensorflow 实现实时人脸识别(2018.9.26更新) 新增测试方法直接使用em...

    jindong 评论0 收藏0
  • 精选:3个可以下载免费高质量照片网站

    摘要:通过这些网站可以改善你的设计项目,这些是网上提供免费的高质量图片最好的网站。我们挑选了三个提供高质量免费图片的最佳网站,您可以随意下载和使用。我们确保所有发布的图片都是高质量的,并根据许可证进行许可。  在您的照片库中加入成千上万张高品质的照片,涵盖各种主题和风格!下面列出的网站提供可用于任何项目的图像,没有限制。您不必担心因为一张小图片的版权而导致麻烦。通过这些网站可以改善你的设计项目,这...

    crossea 评论0 收藏0

发表评论

0条评论

最新活动
阅读需要支付1元查看
<