Skip to content

Commit 7a89fae

Browse files
ploting loses to tensorboard
1 parent 1053517 commit 7a89fae

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

GAN/DCGAN/training.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def training(opt):
4949

5050
writer_fake = SummaryWriter(f"{str(log_dir)}/fake")
5151
writer_real = SummaryWriter(f"{str(log_dir)}/real")
52+
loss_writer = SummaryWriter(f"{str(log_dir)}/loss")
5253

5354
# ~~~~~~~~~~~~~~~~~~~ loading the model ~~~~~~~~~~~~~~~~~~~ #
5455

@@ -76,10 +77,12 @@ def training(opt):
7677
criterion = torch.nn.BCELoss()
7778

7879
# ~~~~~~~~~~~~~~~~~~~ training loop ~~~~~~~~~~~~~~~~~~~ #
80+
D_loss_prev = math.inf
81+
G_loss_prev = math.inf
82+
D_loss = 0
83+
G_loss = 0
7984

8085
for epoch in range(EPOCHS):
81-
D_loss_prev = math.inf
82-
G_loss_prev = math.inf
8386

8487
for batch_idx, (real, _) in enumerate(tqdm(loader)):
8588
disc.train()
@@ -148,6 +151,10 @@ def training(opt):
148151
writer_real.add_image(
149152
"Mnist Real Images", img_grid_real, global_step=epoch
150153
)
154+
loss_writer.add_scalar(
155+
'discriminator', D_loss, global_step=epoch)
156+
loss_writer.add_scalar(
157+
'generator', G_loss, global_step=epoch)
151158

152159
# ~~~~~~~~~~~~~~~~~~~ saving the weights ~~~~~~~~~~~~~~~~~~~ #
153160
if opt.weights:

0 commit comments

Comments
 (0)