GAN学习

茴香豆 Lv5

由于需要阅读的论文中用到了GAN,所以对GAN进行简单的学习。

学习资源:

  1. 李宏毅对抗生成网络(GAN)国语教程(2018)
  2. 百度AI Studio课程
  3. 生成对抗网络GAN开山之作论文精读

1.Introduction

GAN动物园(The GAN Zoo):https://github.com/hindupuravinash/the-gan-zoo ,收录了大量已命名的GAN变体及其对应论文。

GAN交互式可视化(GAN Lab):https://poloclub.github.io/ganlab/ ,可在浏览器中交互式地观察GAN的训练过程。

OpenMMLab开源GAN算法库MMGeneration https://github.com/open-mmlab/mmgeneration

2.GAN简洁实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File : test_gan.py
# Author : none <none>
# Date : 14.04.2022
# Last Modified Date: 15.04.2022
# Last Modified By : none <none>
""" 基于MNIST 实现对抗生成网络 (GAN) """

import torch
import torchvision
import torch.nn as nn
import numpy as np

# Per-sample image shape (channels, height, width); used to size and reshape
# the generator output and the discriminator input.
image_size = [1, 28, 28]
# Dimension of the random latent (noise) vector. The value is arbitrary;
# 96 is used here.
latent_dim = 96
# Number of samples per mini-batch.
batch_size = 64
# True when a CUDA device is available; the training code then moves
# models, loss and tensors to the GPU.
use_gpu = torch.cuda.is_available()

class Generator(nn.Module):
    """MLP generator that maps latent noise vectors to images.

    Linear layers widen the representation latent -> 128 -> 256 -> 512 ->
    1024 -> prod(img_shape). Each hidden layer is followed by BatchNorm1d and
    GELU. The output passes through Sigmoid, so every pixel lies in [0, 1],
    the same range as images produced by torchvision's ``ToTensor()``.

    Args:
        z_dim: Length of each input noise vector. Defaults to the
            module-level ``latent_dim``.
        img_shape: Per-sample output shape, e.g. ``[1, 28, 28]``. Defaults
            to the module-level ``image_size``.
    """

    def __init__(self, z_dim=None, img_shape=None):
        super().__init__()
        # Resolve defaults at construction time so existing ``Generator()``
        # calls keep using the module-level configuration.
        self.latent_dim = latent_dim if z_dim is None else int(z_dim)
        self.img_shape = tuple(image_size if img_shape is None else img_shape)
        # Plain Python int, so nn.Linear is never handed a NumPy scalar.
        num_pixels = int(np.prod(self.img_shape))

        self.model = nn.Sequential(
            nn.Linear(self.latent_dim, 128),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Linear(1024, num_pixels),
            # Sigmoid matches [0, 1] targets; swap in nn.Tanh() if the real
            # images are normalized to [-1, 1] instead.
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Generate a batch of images.

        Args:
            z: Noise tensor of shape ``[batch, z_dim]``.

        Returns:
            Tensor of shape ``[batch, *img_shape]`` with values in [0, 1].
        """
        output = self.model(z)
        return output.reshape(z.shape[0], *self.img_shape)


class Discriminator(nn.Module):
    """MLP discriminator that outputs one probability per input image.

    Each sample is flattened and passed through Linear layers
    prod(img_shape) -> 512 -> 256 -> 128 -> 64 -> 32 -> 1, with GELU between
    hidden layers and a final Sigmoid.

    Args:
        img_shape: Per-sample input shape, e.g. ``[1, 28, 28]``. Defaults to
            the module-level ``image_size``.
    """

    def __init__(self, img_shape=None):
        super().__init__()
        # Resolve the default at construction time so existing
        # ``Discriminator()`` calls keep using the module-level configuration.
        self.img_shape = tuple(image_size if img_shape is None else img_shape)
        # Plain Python int, so nn.Linear is never handed a NumPy scalar.
        num_pixels = int(np.prod(self.img_shape))

        self.model = nn.Sequential(
            nn.Linear(num_pixels, 512),
            nn.GELU(),
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, image):
        """Score a batch of images.

        Args:
            image: Tensor whose first dimension is the batch and whose
                remaining dimensions hold prod(img_shape) elements,
                e.g. ``[batch, 1, 28, 28]``.

        Returns:
            Tensor of shape ``[batch, 1]`` with values in [0, 1].
        """
        # Flatten everything except the batch dimension for the MLP.
        prob = self.model(image.reshape(image.shape[0], -1))
        return prob

# Training
# MNIST training split. Resize(28) sets the shorter side to 28 pixels and
# ToTensor() converts to a float tensor scaled to [0, 1].
dataset = torchvision.datasets.MNIST(
    "mnist_data",
    train=True,
    download=True,
    transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(28),
            torchvision.transforms.ToTensor(),
            # Enable to map images to [-1, 1] (pair with a Tanh generator output).
            # torchvision.transforms.Normalize([0.5], [0.5]),
        ]
    ),
)
# Shuffled mini-batches for training. drop_last=True keeps every batch at
# exactly batch_size, which the fixed-size label tensors below rely on.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, drop_last=True
)

device = torch.device("cuda" if use_gpu else "cpu")
if use_gpu:
    print("use gpu for training")

# Move the models to the target device BEFORE building their optimizers, as
# the torch.optim documentation recommends.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Separate optimizers for the generator and discriminator parameters.
g_optimizer = torch.optim.Adam(
    generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001
)
d_optimizer = torch.optim.Adam(
    discriminator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001
)
# Binary cross-entropy on the discriminator's probability outputs.
loss_fn = nn.BCELoss().to(device)
labels_one = torch.ones(batch_size, 1).to(device)
labels_zero = torch.zeros(batch_size, 1).to(device)

num_epoch = 200
for epoch in range(num_epoch):
    for i, (gt_images, _) in enumerate(dataloader):
        # Global step counter, used for logging and output file names.
        step = len(dataloader) * epoch + i
        gt_images = gt_images.to(device)
        # Noise is drawn on the CPU and then moved, keeping the CPU RNG stream.
        z = torch.randn(batch_size, latent_dim).to(device)

        # ---- Generator update ----
        pred_images = generator(z)
        g_optimizer.zero_grad()
        # Auxiliary L1 distance between the generated batch and the current
        # real batch (not paired sample-by-sample), weighted by 0.05 below.
        recons_loss = torch.abs(pred_images - gt_images).mean()
        # Adversarial term: the generator is rewarded when D labels its
        # output as real (target 1).
        g_loss = recons_loss * 0.05 + loss_fn(discriminator(pred_images), labels_one)
        g_loss.backward()  # back-propagate gradients
        g_optimizer.step()  # update only the generator's parameters

        # ---- Discriminator update ----
        # zero_grad() also discards the gradients g_loss.backward() left in
        # the discriminator's parameters.
        d_optimizer.zero_grad()
        real_loss = loss_fn(discriminator(gt_images), labels_one)
        # detach(): the discriminator loss must not propagate into the generator.
        fake_loss = loss_fn(discriminator(pred_images.detach()), labels_zero)
        d_loss = real_loss + fake_loss

        # Watch real_loss and fake_loss: if they fall together, reach their
        # minimum and are about equal, the discriminator has stabilised.

        d_loss.backward()
        d_optimizer.step()

        if i % 50 == 0:
            print(f"step:{step}, recons_loss:{recons_loss.item()}, g_loss:{g_loss.item()}, d_loss:{d_loss.item()}, real_loss:{real_loss.item()}, fake_loss:{fake_loss.item()}")

        if i % 200 == 0:
            # Save the first 16 generated images as a 4-per-row grid.
            image = pred_images[:16].detach()
            torchvision.utils.save_image(image, f"image_{step}.png", nrow=4)


  • Title: GAN学习
  • Author: 茴香豆
  • Created at : 2022-10-24 10:06:52
  • Updated at : 2022-10-29 16:21:46
  • Link: https://hxiangdou.github.io/2022/10/24/GAN-learning/
  • License: This work is licensed under CC BY-NC-SA 4.0.
Comments
On this page
GAN学习