[深度学习初识 您所在的位置:网站首页 人脸动漫头像生成器下载 [深度学习初识

[深度学习初识

2024-06-24 10:36| 来源: 网络整理| 查看: 265

生成对抗神经网络 1. 流派:GAN、FLOW流、VAE、pixeLCM、pixeLRM 2.GAN(无监督)

(1)判别网络 用神经网络充当loss。即判别器获取输入图片(随机分布生成)和真实图片对比,利用反向传播算法使随机生成图片逼近真实图片。 (2)生成网络 第一次固定权重,按随机分布生成图片,放入判别器。生成的图片服从分布。 (3)训练目标 判别器训练:判别生成图片和真实图片的真假越来越准确。 生成器训练:生成的图片,使得判别器判别不出真假。 判别器和生成器对抗式训练。 判别器和生成器一开始不能够太聪明,否则太笨的一边得不到训练。 (4)损失交替迭代优化

3.常用模型

(1)按模式划分 Conditional GAN、 Cycle GAN、 InfoGan、 LAPGAN。 (2)按高精度划分 BEGAN、MSGGAN、PGGAN、StyleGAN、BlgGAN、BigbiGAN。 (3)按稳定性划分 WGAN、WGAN-GP、SGAN、LSGAN、RGAN。

4. DCGAN(反卷积)

pytorch - 动漫人脸生成项目

(1) 数据集(动漫人脸数据集 - 网上有开源数据 这里是96963的)

from torch.utils.data import Dataset import cv2 import os import numpy as np class FaceMyData(Dataset): def __init__(self,root): super().__init__() self.root = root self.dataset = os.listdir(root)#返回指定路径下的文件和文件夹列表。 def __len__(self): return len(self.dataset) def __getitem__(self, index): pic_name = self.dataset[index] img_data = cv2.imread(f"{self.root}/{pic_name}") img_data = img_data[...,::-1] # BGR - RGB img_data = img_data.transpose([2,0,1]) # 96*96*3 - 3*96*96 img_data = ((img_data / 255. -0.5)*2).astype(np.float32) #归一化 #print(img_data.dtype) return img_data

(2)网络以及损失 判别网络:96*96的图片 输出 1 (即真/假)

def __init__(self): super().__init__() # 96*96 self.sequential = nn.Sequential( nn.Conv2d(3,64,5,3,padding = 1, bias = False), nn.BatchNorm2d(64), nn.LeakyReLU(0.2,inplace = True), nn.Conv2d(64,128,4,2,padding = 1, bias = False), nn.BatchNorm2d(128), nn.LeakyReLU(0.2,inplace = True), nn.Conv2d(128,256,4,2,padding = 1, bias = False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2,inplace = True), nn.Conv2d(256,512,4,2,padding = 1, bias = False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2,inplace = True), nn.Conv2d(512,1,4,1,padding = 0, bias = False), nn.Sigmoid() ) def forward(self,img): h = self.sequential_2(img) print(h.shape) return h.reshape(-1)

生成器:输出一个随机分布生成的图片

class GNet(nn.Module): def __init__(self): super().__init__() #输入是一个随机噪声 [NCHW] batch 128 1 1 -> batcg 512 4 4 self.sequential = nn.Sequential( nn.ConvTranspose2d(128,512,kernel_size = 7,stride = 1, padding = 0, bias = False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2,inplace = True), nn.ConvTranspose2d(512,256,kernel_size = 4,stride = 2, padding = 1, bias = False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2,inplace = True), nn.ConvTranspose2d(256,128,kernel_size = 4,stride = 2, padding = 1, bias = False), nn.BatchNorm2d(128), nn.LeakyReLU(0.2,inplace = True), nn.ConvTranspose2d(128,64,kernel_size = 5,stride = 3, padding = 1, bias = False), nn.BatchNorm2d(64), nn.LeakyReLU(0.2,inplace = True), nn.ConvTranspose2d(64,3,kernel_size = 7,stride = 3, padding = 0, bias = False), nn.Tanh() ) def forward(self,noise): return self.sequential(noise)

损失:

class DCGAN(nn.Module): def __init__(self): super().__init__() self.dnet = DNet() self.gnet = GNet() self.loss_fn = nn.BCEWithLogitsLoss()#二值交叉熵 def forward(self,noise): return self.gnet(noise) # 判别器的损失 : 判别器得到的标签与原标签的损失 def get_D_loss(self,noise_d,real_img): real_y = self.dnet(real_img)#真照片判别 g_img = self.gnet(noise_d)#生成假照片 fake_y = self.dnet(g_img)#假照片判别 real_tag = torch.ones(real_img.size(0)).cuda() fake_tag = torch.zeros(noise_d.size(0)).cuda() print(real_y.shape,real_tag.shape) loss_real = self.loss_fn(real_y, real_tag) loss_fake = self.loss_fn(fake_y, fake_tag) loss_d = loss_fake + loss_real return loss_d #生成器的损失 : 训练生成器假定生成的都是正确的 即得到的标签与正确标签的损失 def get_G_loss(self,noise_g): _g_img = self.gnet(noise_g)#生成照片 _real_y = self.dnet(_g_img)#照片判别 _real_tag = torch.ones(_g_img.size(0)).cuda() loss_g = self.loss_fn(_real_y, _real_tag) return loss_g

(3)训练 技巧:预处理环节,将图像scale到tanh的[-1,1];所有的参数初始化有(0,0.02)的正态分布中随机得到;LeakyReLU的斜率是0.2(默认),不要用池化等会造成信息丢失的操作;优化器使用调好超参的Adam optimizer,lr 为0.0002;将动量参数beta从0.9降为0.5,防止震荡和不稳定。

self.d_opt = optim.Adam(self.net.dnet.parameters(),0.0002, betas = (0.5,0.9)) self.g_opt = optim.Adam(self.net.gnet.parameters(),0.0002, betas = (0.5,0.9))

(4)训练结果 每一轮次数据训练完,生成一组图片,保存。 可以看到随着训练轮次增加,可以看清动漫人脸。(由于电脑不是很好,网络构建也不够完善,生成出来的效果一般。)

noise = torch.normal(0,0.1,(8,128,1,1)).cuda() y = self.net(noise) utils.save_image(y,f"img_face/{epoch}.jpg",normalize=True,range=(-1,1))

在这里插入图片描述



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有