Skip to content

Commit 8eff864

Browse files
added functinality to save weights,
based on the loss. removed the reshape function as the output of the model is, as per the requirment.
1 parent b3f6c5a commit 8eff864

File tree

1 file changed

+107
-87
lines changed

1 file changed

+107
-87
lines changed

GAN/DCGAN/training.py

Lines changed: 107 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -6,112 +6,132 @@
66
from torch.utils.data import DataLoader
77
from model import Discriminator, Faker
88
from torch.utils.tensorboard import SummaryWriter
9+
from pathlib import Path
910

1011

11-
# ~~~~~~~~~~~~~~~~~~~ hyper parameters ~~~~~~~~~~~~~~~~~~~ #
12-
EPOCHS = 20
13-
CHANNELS = 1
14-
H, W = 64, 64
15-
IMG_SIZE = CHANNELS * H * W
16-
lr = 2e-4
17-
work_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18-
FEATURE_D = 128
19-
Z_DIM = 100
20-
GEN_TRAIN_STEPS = 5
21-
BATCH_SIZE = 128
22-
# ~~~~~~~~~~~~~~~~~~~ loading the dataset ~~~~~~~~~~~~~~~~~~~ #
12+
if __name__ == "__main__":
2313

24-
trans = transforms.Compose(
25-
[transforms.Resize((H, W)), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
14+
# ~~~~~~~~~~~~~~~~~~~ hyper parameters ~~~~~~~~~~~~~~~~~~~ #
15+
EPOCHS = 20
16+
CHANNELS = 1
17+
H, W = 64, 64
18+
IMG_SIZE = CHANNELS * H * W
19+
lr = 2e-4
20+
work_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21+
FEATURE_D = 128
22+
Z_DIM = 100
23+
GEN_TRAIN_STEPS = 5
24+
BATCH_SIZE = 128
25+
Weight_dir = Path('/weights').resolve()
26+
if not Weight_dir.exists():
27+
Weight_dir.mkdir()
28+
# ~~~~~~~~~~~~~~~~~~~ loading the dataset ~~~~~~~~~~~~~~~~~~~ #
2629

27-
MNIST_data = MNIST('./data', True, transform=trans, download=True)
30+
trans = transforms.Compose(
31+
[transforms.Resize((H, W)), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
2832

29-
loader = DataLoader(MNIST_data, BATCH_SIZE, True, num_workers=1)
33+
MNIST_data = MNIST('./data', True, transform=trans, download=True)
3034

31-
# ~~~~~~~~~~~~~~~~~~~ creating tensorboard variables ~~~~~~~~~~~~~~~~~~~ #
35+
loader = DataLoader(MNIST_data, BATCH_SIZE, True, num_workers=1)
3236

33-
writer_fake = SummaryWriter("logs/fake")
34-
writer_real = SummaryWriter("logs/real")
37+
# ~~~~~~~~~~~~~~~~~~~ creating tensorboard variables ~~~~~~~~~~~~~~~~~~~ #
3538

36-
# ~~~~~~~~~~~~~~~~~~~ loading the model ~~~~~~~~~~~~~~~~~~~ #
39+
writer_fake = SummaryWriter("logs/fake")
40+
writer_real = SummaryWriter("logs/real")
3741

38-
disc = Discriminator(img_channels=CHANNELS,
39-
feature_d=FEATURE_D).to(work_device)
40-
gen = Faker(Z_DIM, CHANNELS, FEATURE_D).to(work_device)
42+
# ~~~~~~~~~~~~~~~~~~~ loading the model ~~~~~~~~~~~~~~~~~~~ #
4143

42-
# ~~~~~~~~~~~~~~~~~~~ create optimizer and loss ~~~~~~~~~~~~~~~~~~~ #
44+
disc = Discriminator(img_channels=CHANNELS,
45+
feature_d=FEATURE_D).to(work_device)
46+
gen = Faker(Z_DIM, CHANNELS, FEATURE_D).to(work_device)
4347

44-
disc_optim = optim.Adam(disc.parameters(), lr, (0.5, 0.999))
45-
gen_optim = optim.Adam(gen.parameters(), lr, (0.5, 0.999))
46-
criterion = torch.nn.BCELoss()
48+
# ~~~~~~~~~~~~~~~~~~~ create optimizer and loss ~~~~~~~~~~~~~~~~~~~ #
4749

48-
# ~~~~~~~~~~~~~~~~~~~ training loop ~~~~~~~~~~~~~~~~~~~ #
50+
disc_optim = optim.Adam(disc.parameters(), lr, (0.5, 0.999))
51+
gen_optim = optim.Adam(gen.parameters(), lr, (0.5, 0.999))
52+
criterion = torch.nn.BCELoss()
4953

50-
for epoch in range(EPOCHS):
54+
# ~~~~~~~~~~~~~~~~~~~ training loop ~~~~~~~~~~~~~~~~~~~ #
5155

52-
for batch_idx, (real, _) in enumerate(loader):
53-
disc.train()
54-
gen.train()
55-
real = real.to(work_device)
56-
fixed_noise = torch.rand(real.shape[0], Z_DIM, H, W).to(work_device)
57-
# ~~~~~~~~~~~~~~~~~~~ discriminator loop ~~~~~~~~~~~~~~~~~~~ #
56+
for epoch in range(EPOCHS):
57+
D_loss_prev = inf
58+
G_loss_prev = inf
5859

59-
fake = gen(fixed_noise) # dim of (N,1,28,28)
60-
# ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
61-
real_predict = disc(real).view(-1) # make it one dimensional array
62-
fake_predict = disc(fake).view(-1) # make it one dimensional array
60+
for batch_idx, (real, _) in enumerate(loader):
61+
disc.train()
62+
gen.train()
63+
real = real.to(work_device)
64+
fixed_noise = torch.rand(
65+
real.shape[0], Z_DIM, 1, 1).to(work_device)
66+
# ~~~~~~~~~~~~~~~~~~~ discriminator loop ~~~~~~~~~~~~~~~~~~~ #
6367

64-
labels = torch.cat([torch.ones_like(real_predict),
65-
torch.zeros_like(fake_predict)], dim=0)
66-
67-
# ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
68-
D_loss = criterion(
69-
torch.cat([real_predict, fake_predict], dim=0), labels)
70-
71-
# ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
72-
disc.zero_grad()
73-
D_loss.backward()
74-
disc_optim.step()
75-
76-
# ~~~~~~~~~~~~~~~~~~~ generator loop ~~~~~~~~~~~~~~~~~~~ #
77-
for _ in range(GEN_TRAIN_STEPS):
78-
# ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
79-
fake = gen(fixed_noise).view(-1, CHANNELS,
80-
H, W) # dim of (N,1,32,32)
68+
fake = gen(fixed_noise) # dim of (N,1,28,28)
8169
# ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
70+
real_predict = disc(real).view(-1) # make it one dimensional array
8271
fake_predict = disc(fake).view(-1) # make it one dimensional array
72+
73+
labels = torch.cat([torch.ones_like(real_predict),
74+
torch.zeros_like(fake_predict)], dim=0)
75+
8376
# ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
77+
D_loss = criterion(
78+
torch.cat([real_predict, fake_predict], dim=0), labels)
8479

85-
G_loss = criterion(fake_predict, torch.ones_like(fake_predict))
8680
# ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
87-
gen.zero_grad()
88-
G_loss.backward()
89-
gen_optim.step()
90-
91-
# ~~~~~~~~~~~~~~~~~~~ loading the tensorboard ~~~~~~~~~~~~~~~~~~~ #
92-
93-
if batch_idx == 0:
94-
print(
95-
f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)} \
96-
Loss D: {D_loss:.4f}, loss G: {G_loss:.4f}"
97-
)
98-
99-
with torch.no_grad():
100-
disc.eval()
101-
gen.eval()
102-
fake = gen(fixed_noise).reshape(-1, CHANNELS, H, W)
103-
data = real.reshape(-1, CHANNELS, H, W)
104-
if BATCH_SIZE > 32:
105-
fake = fake[:32]
106-
data = data[:32]
107-
img_grid_fake = torchvision.utils.make_grid(
108-
fake, normalize=True)
109-
img_grid_real = torchvision.utils.make_grid(
110-
data, normalize=True)
111-
112-
writer_fake.add_image(
113-
"Mnist Fake Images", img_grid_fake, global_step=epoch
114-
)
115-
writer_real.add_image(
116-
"Mnist Real Images", img_grid_real, global_step=epoch
81+
disc.zero_grad()
82+
D_loss.backward()
83+
disc_optim.step()
84+
85+
# ~~~~~~~~~~~~~~~~~~~ generator loop ~~~~~~~~~~~~~~~~~~~ #
86+
for _ in range(GEN_TRAIN_STEPS):
87+
# ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
88+
fake = gen(fixed_noise)
89+
# ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
90+
# make it one dimensional array
91+
fake_predict = disc(fake).view(-1)
92+
# ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
93+
94+
G_loss = criterion(fake_predict, torch.ones_like(fake_predict))
95+
# ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
96+
gen.zero_grad()
97+
G_loss.backward()
98+
gen_optim.step()
99+
100+
# ~~~~~~~~~~~~~~~~~~~ loading the tensorboard ~~~~~~~~~~~~~~~~~~~ #
101+
102+
if batch_idx == 0:
103+
print(
104+
f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)} \
105+
Loss D: {D_loss:.4f}, loss G: {G_loss:.4f}"
117106
)
107+
108+
with torch.no_grad():
109+
disc.eval()
110+
gen.eval()
111+
fake = gen(fixed_noise).reshape(-1, CHANNELS, H, W)
112+
data = real.reshape(-1, CHANNELS, H, W)
113+
if BATCH_SIZE > 32:
114+
fake = fake[:32]
115+
data = data[:32]
116+
img_grid_fake = torchvision.utils.make_grid(
117+
fake, normalize=True)
118+
img_grid_real = torchvision.utils.make_grid(
119+
data, normalize=True)
120+
121+
writer_fake.add_image(
122+
"Mnist Fake Images", img_grid_fake, global_step=epoch
123+
)
124+
writer_real.add_image(
125+
"Mnist Real Images", img_grid_real, global_step=epoch
126+
)
127+
128+
# ~~~~~~~~~~~~~~~~~~~ saving the weights ~~~~~~~~~~~~~~~~~~~ #
129+
if D_loss_prev > D_loss:
130+
D_loss_prev = D_loss
131+
weight_path = str(Weight_dir/'dirscriminator.pth')
132+
torch.save(disc.state_dict(), weight_path)
133+
134+
if G_loss_prev > G_loss:
135+
G_loss_prev = G_loss
136+
weight_path = str(Weight_dir/'generator.pth')
137+
torch.save(gen.state_dict(), weight_path)

0 commit comments

Comments
 (0)