Commit b3f6c5a

DCGAN initial commit
1 parent dff770c commit b3f6c5a

File tree

2 files changed: +194 -0 lines changed


GAN/DCGAN/model.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
import torch.nn as nn
from torchsummary import summary


class Discriminator(nn.Module):
    def __init__(self, img_channels, feature_d):
        super().__init__()
        self.net = nn.Sequential(
            # input: N x img_channels x 64 x 64
            nn.Conv2d(img_channels, feature_d,
                      kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            *self.__block(feature_d, feature_d*2,
                          kernel_size=4, stride=2, padding=1),
            *self.__block(feature_d*2, feature_d*4,
                          kernel_size=4, stride=2, padding=1),
            *self.__block(feature_d*4, feature_d*8,
                          kernel_size=4, stride=2, padding=1),
            # 4 x 4 -> 1 x 1: a single score per image, squashed to a probability
            nn.Conv2d(feature_d*8, 1,
                      kernel_size=4, stride=2, padding=0),
            nn.Sigmoid()
        )
        self.initialize_weights()

    def __block(self, in_channels, out_channels, **kwargs):
        # Conv -> BatchNorm -> LeakyReLU, halving the spatial size at stride 2
        return [nn.Conv2d(in_channels, out_channels, **kwargs, bias=False),
                nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2)]

    def initialize_weights(self):
        # DCGAN-style init: all weights drawn from N(0, 0.02)
        for m in self.net.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
                nn.init.normal_(m.weight.data, 0, 0.02)

    def forward(self, x):
        return self.net(x)


class Faker(nn.Module):

    def __init__(self, z_dim, img_channels, feature_d):
        super().__init__()
        self.net = nn.Sequential(
            # input: N x z_dim x 1 x 1 -> 4 x 4
            *self.__block(z_dim, feature_d*8, kernel_size=4,
                          stride=2, padding=0),
            *self.__block(feature_d*8, feature_d*4, kernel_size=4,
                          stride=2, padding=1),
            *self.__block(feature_d*4, feature_d*2, kernel_size=4,
                          stride=2, padding=1),
            *self.__block(feature_d*2, feature_d, kernel_size=4,
                          stride=2, padding=1),
            # 32 x 32 -> 64 x 64, output scaled to [-1, 1] by Tanh
            nn.ConvTranspose2d(feature_d, img_channels,
                               kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )
        self.initialize_weights()

    def __block(self, in_channels, out_channels, **kwargs):
        # ConvTranspose -> BatchNorm -> ReLU, doubling the spatial size at stride 2
        return [nn.ConvTranspose2d(in_channels, out_channels, **kwargs, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()]

    def initialize_weights(self):
        # DCGAN-style init: all weights drawn from N(0, 0.02)
        for m in self.net.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
                nn.init.normal_(m.weight.data, 0, 0.02)

    def forward(self, x):
        return self.net(x)


if __name__ == "__main__":
    disc = Discriminator(3, 128)
    gen = Faker(100, 3, 128)

    summary(disc, (3, 64, 64))
    summary(gen, (100, 1, 1))
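
As a quick sanity check on the architecture (a minimal sketch, not part of this commit): the discriminator should collapse a 64x64 image to a single probability per sample, and the generator should expand a (z_dim, 1, 1) latent vector to a 64x64 image.

import torch
from model import Discriminator, Faker

disc = Discriminator(img_channels=3, feature_d=128)
gen = Faker(z_dim=100, img_channels=3, feature_d=128)

x = torch.randn(8, 3, 64, 64)   # a batch of 8 RGB images
z = torch.randn(8, 100, 1, 1)   # 8 latent vectors

print(disc(x).shape)  # expected: torch.Size([8, 1, 1, 1])
print(gen(z).shape)   # expected: torch.Size([8, 3, 64, 64])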

GAN/DCGAN/training.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
import torch
import torchvision
import torch.optim as optim
from torchvision.transforms import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from model import Discriminator, Faker
from torch.utils.tensorboard import SummaryWriter


# ~~~~~~~~~~~~~~~~~~~ hyper parameters ~~~~~~~~~~~~~~~~~~~ #
EPOCHS = 20
CHANNELS = 1
H, W = 64, 64
IMG_SIZE = CHANNELS * H * W
lr = 2e-4
work_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
FEATURE_D = 128
Z_DIM = 100
GEN_TRAIN_STEPS = 5
BATCH_SIZE = 128

# ~~~~~~~~~~~~~~~~~~~ loading the dataset ~~~~~~~~~~~~~~~~~~~ #

trans = transforms.Compose(
    [transforms.Resize((H, W)), transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

MNIST_data = MNIST('./data', train=True, transform=trans, download=True)

loader = DataLoader(MNIST_data, BATCH_SIZE, shuffle=True, num_workers=1)

# ~~~~~~~~~~~~~~~~~~~ creating tensorboard writers ~~~~~~~~~~~~~~~~~~~ #

writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")

# ~~~~~~~~~~~~~~~~~~~ loading the model ~~~~~~~~~~~~~~~~~~~ #

disc = Discriminator(img_channels=CHANNELS,
                     feature_d=FEATURE_D).to(work_device)
gen = Faker(Z_DIM, CHANNELS, FEATURE_D).to(work_device)

# ~~~~~~~~~~~~~~~~~~~ create optimizer and loss ~~~~~~~~~~~~~~~~~~~ #

disc_optim = optim.Adam(disc.parameters(), lr, (0.5, 0.999))
gen_optim = optim.Adam(gen.parameters(), lr, (0.5, 0.999))
criterion = torch.nn.BCELoss()

# ~~~~~~~~~~~~~~~~~~~ training loop ~~~~~~~~~~~~~~~~~~~ #

for epoch in range(EPOCHS):

    for batch_idx, (real, _) in enumerate(loader):
        disc.train()
        gen.train()
        real = real.to(work_device)
        # standard-normal latent noise; the generator expects (N, Z_DIM, 1, 1)
        noise = torch.randn(real.shape[0], Z_DIM, 1, 1).to(work_device)

        # ~~~~~~~~~~~~~~~~~~~ discriminator loop ~~~~~~~~~~~~~~~~~~~ #

        fake = gen(noise)  # dim of (N, CHANNELS, 64, 64)
        # ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
        real_predict = disc(real).view(-1)  # make it a one-dimensional array
        # detach so the discriminator step does not backprop into the generator
        fake_predict = disc(fake.detach()).view(-1)

        labels = torch.cat([torch.ones_like(real_predict),
                            torch.zeros_like(fake_predict)], dim=0)

        # ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
        # real images are labelled 1, fake images 0
        D_loss = criterion(
            torch.cat([real_predict, fake_predict], dim=0), labels)

        # ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
        disc.zero_grad()
        D_loss.backward()
        disc_optim.step()

        # ~~~~~~~~~~~~~~~~~~~ generator loop ~~~~~~~~~~~~~~~~~~~ #
        for _ in range(GEN_TRAIN_STEPS):
            # ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
            fake = gen(noise).view(-1, CHANNELS, H, W)  # dim of (N, CHANNELS, 64, 64)
            fake_predict = disc(fake).view(-1)  # make it a one-dimensional array
            # ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
            # the generator is trained so the discriminator calls its output real (1)
            G_loss = criterion(fake_predict, torch.ones_like(fake_predict))
            # ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
            gen.zero_grad()
            G_loss.backward()
            gen_optim.step()

        # ~~~~~~~~~~~~~~~~~~~ logging to tensorboard ~~~~~~~~~~~~~~~~~~~ #

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {D_loss:.4f}, loss G: {G_loss:.4f}"
            )

            with torch.no_grad():
                disc.eval()
                gen.eval()
                fake = gen(noise).reshape(-1, CHANNELS, H, W)
                data = real.reshape(-1, CHANNELS, H, W)
                if BATCH_SIZE > 32:
                    # only log the first 32 images of the batch
                    fake = fake[:32]
                    data = data[:32]
                img_grid_fake = torchvision.utils.make_grid(
                    fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(
                    data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=epoch
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=epoch
                )
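
The image grids written above can be inspected with the standard TensorBoard CLI (tensorboard --logdir logs). One possible refinement, sketched below and not part of this commit: sampling a single latent batch once, before the training loop, and reusing it for logging makes the fake-image grids directly comparable across epochs, since the noise above is re-sampled on every batch.

# a sketch only; assumes the same Z_DIM, work_device, gen, writer_fake, and epoch defined above
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(work_device)  # sampled once, reused for every log

with torch.no_grad():
    gen.eval()
    fake = gen(fixed_noise)  # same latent vectors each epoch -> comparable grids
    img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
    writer_fake.add_image("Mnist Fake Images", img_grid_fake, global_step=epoch)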
