import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
# ----- experiment configuration -----
workers = 2                      # DataLoader worker processes
batch_size = 64                  # images per mini-batch
nz = 100                         # dimensionality of the latent vector
nch_g = 64                       # base channel count of the generator
nch_d = 64                       # base channel count of the discriminator
n_epoch = 20                     # number of training epochs
lr = 0.0002                      # Adam learning rate
beta1 = 0.5                      # Adam beta1 (DCGAN-recommended value)
outf = './result_lsgan'          # output directory for samples and checkpoints
display_interval = 100           # iterations between console log lines
save_fake_image_interval = 1500  # iterations between mid-epoch sample dumps

plt.rcParams['figure.figsize'] = 10, 6

# Create the output directory; with exist_ok=True only unexpected OS errors
# (e.g. permissions) can raise — report them but keep going.
try:
    os.makedirs(outf, exist_ok=True)
except OSError as error:
    print(error)

# Seed every RNG in play for reproducible runs.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
def weights_init(m):
    """DCGAN-style initializer, intended for `model.apply(weights_init)`.

    Conv*/Linear weights are drawn from N(0, 0.02), BatchNorm weights
    from N(1, 0.02); all matching biases are zeroed. Modules whose class
    name matches none of these patterns are left untouched.
    """
    name = m.__class__.__name__
    if 'Conv' in name or 'Linear' in name:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif 'BatchNorm' in name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
class Generator(nn.Module):
    """DCGAN-style generator: (nz, 1, 1) latent vector -> (nch, 64, 64) image.

    Output passes through Tanh, so pixel values lie in [-1, 1] — the data
    pipeline must normalize real images into the same range.
    """

    def __init__(self, nz=100, nch_g=64, nch=3):
        super(Generator, self).__init__()
        # (in_channels, out_channels, stride, padding) for each upsampling
        # stage that is followed by BatchNorm + ReLU.
        stage_specs = [
            (nz,        nch_g * 8, 1, 0),  # (nz, 1, 1)   -> (512, 4, 4)
            (nch_g * 8, nch_g * 4, 2, 1),  # (512, 4, 4)  -> (256, 8, 8)
            (nch_g * 4, nch_g * 2, 2, 1),  # (256, 8, 8)  -> (128, 16, 16)
            (nch_g * 2, nch_g,     2, 1),  # (128, 16, 16)-> (64, 32, 32)
        ]
        stages = {}
        for idx, (cin, cout, stride, pad) in enumerate(stage_specs):
            stages['layer{}'.format(idx)] = nn.Sequential(
                nn.ConvTranspose2d(cin, cout, 4, stride, pad),
                nn.BatchNorm2d(cout),
                nn.ReLU()
            )
        # Final stage maps to image channels and squashes into [-1, 1].
        stages['layer4'] = nn.Sequential(
            nn.ConvTranspose2d(nch_g, nch, 4, 2, 1),
            nn.Tanh()
        )  # (64, 32, 32) -> (nch, 64, 64)
        # Keys match the original hand-written layout, so state_dict keys
        # (layers.layer0.0.weight, ...) are unchanged.
        self.layers = nn.ModuleDict(stages)

    def forward(self, z):
        """Run the latent tensor through every stage in insertion order."""
        for stage in self.layers.values():
            z = stage(z)
        return z
class Discriminator(nn.Module):
    """DCGAN-style discriminator: (nch, 64, 64) image -> one real-valued score.

    There is no final sigmoid: the raw score is consumed by an MSE loss
    (LSGAN), not a BCE loss.
    """

    def __init__(self, nch=3, nch_d=64):
        super(Discriminator, self).__init__()
        blocks = {}
        # First stage deliberately has no BatchNorm (standard DCGAN practice).
        blocks['layer0'] = nn.Sequential(
            nn.Conv2d(nch, nch_d, 4, 2, 1),
            nn.LeakyReLU(negative_slope=0.2)
        )  # (nch, 64, 64) -> (nch_d, 32, 32)
        # Three strided stages, each halving spatial size and doubling channels:
        # (64, 32, 32) -> (128, 16, 16) -> (256, 8, 8) -> (512, 4, 4)
        widths = [nch_d, nch_d * 2, nch_d * 4, nch_d * 8]
        for idx in range(1, 4):
            blocks['layer{}'.format(idx)] = nn.Sequential(
                nn.Conv2d(widths[idx - 1], widths[idx], 4, 2, 1),
                nn.BatchNorm2d(widths[idx]),
                nn.LeakyReLU(negative_slope=0.2)
            )
        # Final 4x4 conv collapses (nch_d*8, 4, 4) to one score per image.
        blocks['layer4'] = nn.Conv2d(nch_d * 8, 1, 4, 1, 0)
        # Keys match the original layout, so state_dict keys are unchanged.
        self.layers = nn.ModuleDict(blocks)

    def forward(self, x):
        """Return per-image scores with all singleton dimensions squeezed out."""
        for block in self.layers.values():
            x = block(x)
        return x.squeeze()
def main():
    """Train an LSGAN (MSE adversarial loss) on images under ./celeba.

    Uses the module-level hyperparameters (batch_size, nz, lr, n_epoch, ...).
    Side effects: console logging, sample images and model checkpoints
    written under `outf`, and a loss-curve plot saved to Loss_graph.png.
    """
    # ImageFolder expects class subdirectories under ./celeba; mild augmentation,
    # then normalization into [-1, 1] to match the generator's Tanh output range.
    dataset = dset.ImageFolder(root='./celeba',
                               transform=transforms.Compose([
                                   transforms.RandomResizedCrop(64, scale=(0.9, 1.0), ratio=(1., 1.)),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=int(workers))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('device:', device)
    netG = Generator(nz=nz, nch_g=nch_g).to(device)
    netG.apply(weights_init)  # DCGAN-style N(0, 0.02) initialization
    print(netG)
    netD = Discriminator(nch_d=nch_d).to(device)
    netD.apply(weights_init)
    print(netD)
    criterion = nn.MSELoss()  # least-squares adversarial loss — this is what makes it an LSGAN
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999), weight_decay=1e-5)
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999), weight_decay=1e-5)
    fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)  # fixed noise for periodic fake-image snapshots
    Loss_D_list, Loss_G_list = [], []
    save_fake_image_count = 1
    # ----- training loop -----
    for epoch in range(n_epoch):
        for itr, data in enumerate(dataloader):
            real_image = data[0].to(device)   # real images from the dataset
            sample_size = real_image.size(0)  # actual batch size (last batch may be smaller)
            noise = torch.randn(sample_size, nz, 1, 1, device=device)  # latent input vectors (standard-normal noise)
            real_target = torch.full((sample_size,), 1., device=device)  # LSGAN target value for "real"
            fake_target = torch.full((sample_size,), 0., device=device)  # LSGAN target value for "fake"
            # -------- Update Discriminator ---------
            netD.zero_grad()  # reset accumulated gradients
            output = netD(real_image)  # discriminator scores for the real images
            errD_real = criterion(output, real_target)  # squared error of real scores vs. the "real" target
            D_x = output.mean().item()  # mean score on reals (for logging)
            fake_image = netG(noise)  # fake images produced by the generator
            output = netD(fake_image.detach())  # scores for fakes; detach() keeps G out of D's graph
            errD_fake = criterion(output, fake_target)  # squared error of fake scores vs. the "fake" target
            D_G_z1 = output.mean().item()  # mean score on fakes before G's update (for logging)
            errD = errD_real + errD_fake  # total discriminator loss
            errD.backward()  # backpropagate
            optimizerD.step()  # update discriminator parameters
            # --------- Update Generator ----------
            netG.zero_grad()  # reset accumulated gradients
            output = netD(fake_image)  # re-score the fakes with the just-updated discriminator
            errG = criterion(output, real_target)  # G's loss: wants its fakes scored as "real"
            errG.backward()  # backpropagate
            D_G_z2 = output.mean().item()  # mean score on fakes after D's update (for logging)
            optimizerG.step()  # update generator parameters
            if itr % display_interval == 0:
                print('[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(x): {:.3f} D(G(z)): {:.3f}/{:.3f}'
                      .format(epoch + 1, n_epoch,
                              itr + 1, len(dataloader),
                              errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
                # Losses are recorded once per display_interval, hence the
                # 'iter (*100)' x-axis label on the final plot.
                Loss_D_list.append(errD.item())
                Loss_G_list.append(errG.item())
            if epoch == 0 and itr == 0:
                # Save one grid of real samples once, for visual reference.
                vutils.save_image(real_image, '{}/real_samples.png'.format(outf),
                                  normalize=True, nrow=8)
            if itr % save_fake_image_interval == 0 and itr > 0:
                # Mid-epoch snapshot from the fixed noise, to track progress.
                fake_image = netG(fixed_noise)
                vutils.save_image(fake_image.detach(), '{}/fake_samples_{:03d}.png'.format(outf, save_fake_image_count),
                                  normalize=True, nrow=8)
                save_fake_image_count += 1
        # --------- save fake images at the end of each epoch ----------
        fake_image = netG(fixed_noise)
        vutils.save_image(fake_image.detach(), '{}/fake_samples_epoch_{:03d}.png'.format(outf, epoch + 1),
                          normalize=True, nrow=8)
        # --------- save model checkpoints -----------
        if (epoch + 1) % 10 == 0:  # save the models every 10 epochs
            torch.save(netG.state_dict(), '{}/netG_epoch_{}.pth'.format(outf, epoch + 1))
            torch.save(netD.state_dict(), '{}/netD_epoch_{}.pth'.format(outf, epoch + 1))
    # Plot the loss curves recorded every `display_interval` iterations.
    plt.figure()
    plt.plot(range(len(Loss_D_list)), Loss_D_list, color='blue', linestyle='-', label='Loss_D')
    plt.plot(range(len(Loss_G_list)), Loss_G_list, color='red', linestyle='-', label='Loss_G')
    plt.legend()
    plt.xlabel('iter (*100)')
    plt.ylabel('loss')
    plt.title('Loss_D and Loss_G')
    plt.grid()
    plt.savefig('Loss_graph.png')


if __name__ == '__main__':
    main()
# NOTE(review): removed stray page text "コメントを残す" ("leave a comment") —
# web-scrape residue, not Python; it would raise a NameError when the script runs.