Skip to content

Commit 28845fd

Browse files
added option to save logs
1 parent 4b95d27 commit 28845fd

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

GAN/DCGAN/training.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pathlib import Path
1111
import math
1212
from tqdm import tqdm
13+
import shutil
1314

1415

1516
def training(opt):
@@ -24,9 +25,16 @@ def training(opt):
2425
Z_DIM = 100
2526
GEN_TRAIN_STEPS = 2
2627
BATCH_SIZE = opt.batch_size
27-
Weight_dir = Path(f'{opt.weights}').resolve()
28-
if not Weight_dir.exists():
29-
Weight_dir.mkdir()
28+
29+
if opt.logs:
30+
log_dir = Path(f'{opt.logs}').resolve()
31+
if log_dir.exists():
32+
shutil.rmtree(str(log_dir))
33+
34+
if opt.weights:
35+
Weight_dir = Path(f'{opt.weights}').resolve()
36+
if not Weight_dir.exists():
37+
Weight_dir.mkdir()
3038
# ~~~~~~~~~~~~~~~~~~~ loading the dataset ~~~~~~~~~~~~~~~~~~~ #
3139

3240
trans = transforms.Compose(
@@ -39,8 +47,8 @@ def training(opt):
3947

4048
# ~~~~~~~~~~~~~~~~~~~ creating tensorboard variables ~~~~~~~~~~~~~~~~~~~ #
4149

42-
writer_fake = SummaryWriter("logs/fake")
43-
writer_real = SummaryWriter("logs/real")
50+
writer_fake = SummaryWriter(f"{str(log_dir)}/fake")
51+
writer_real = SummaryWriter(f"{str(log_dir)}/real")
4452

4553
# ~~~~~~~~~~~~~~~~~~~ loading the model ~~~~~~~~~~~~~~~~~~~ #
4654

@@ -142,15 +150,16 @@ def training(opt):
142150
)
143151

144152
# ~~~~~~~~~~~~~~~~~~~ saving the weights ~~~~~~~~~~~~~~~~~~~ #
145-
if D_loss_prev > D_loss:
146-
D_loss_prev = D_loss
147-
weight_path = str(Weight_dir/'dirscriminator.pth')
148-
torch.save(disc.state_dict(), weight_path)
153+
if opt.weights:
154+
if D_loss_prev > D_loss:
155+
D_loss_prev = D_loss
156+
weight_path = str(Weight_dir/'dirscriminator.pth')
157+
torch.save(disc.state_dict(), weight_path)
149158

150-
if G_loss_prev > G_loss:
151-
G_loss_prev = G_loss
152-
weight_path = str(Weight_dir/'generator.pth')
153-
torch.save(gen.state_dict(), weight_path)
159+
if G_loss_prev > G_loss:
160+
G_loss_prev = G_loss
161+
weight_path = str(Weight_dir/'generator.pth')
162+
torch.save(gen.state_dict(), weight_path)
154163

155164

156165
if __name__ == "__main__":
@@ -159,6 +168,8 @@ def training(opt):
159168
# initial pre training weights
160169
parser.add_argument('--weights', type=str,
161170
default='', help='save and load location of weights')
171+
parser.add_argument('--logs', type=str,
172+
default='', help='save log files to')
162173
parser.add_argument("epochs", type=int, default=20,
163174
help='number of epochs to train')
164175
parser.add_argument('--batch-size', type=int, default=128,

0 commit comments

Comments
 (0)