10
10
from pathlib import Path
11
11
import math
12
12
from tqdm import tqdm
13
+ import shutil
13
14
14
15
15
16
def training (opt ):
@@ -24,9 +25,16 @@ def training(opt):
24
25
Z_DIM = 100
25
26
GEN_TRAIN_STEPS = 2
26
27
BATCH_SIZE = opt .batch_size
27
- Weight_dir = Path (f'{ opt .weights } ' ).resolve ()
28
- if not Weight_dir .exists ():
29
- Weight_dir .mkdir ()
28
+
29
+ if opt .logs :
30
+ log_dir = Path (f'{ opt .logs } ' ).resolve ()
31
+ if log_dir .exists ():
32
+ shutil .rmtree (str (log_dir ))
33
+
34
+ if opt .weights :
35
+ Weight_dir = Path (f'{ opt .weights } ' ).resolve ()
36
+ if not Weight_dir .exists ():
37
+ Weight_dir .mkdir ()
30
38
# ~~~~~~~~~~~~~~~~~~~ loading the dataset ~~~~~~~~~~~~~~~~~~~ #
31
39
32
40
trans = transforms .Compose (
@@ -39,8 +47,8 @@ def training(opt):
39
47
40
48
# ~~~~~~~~~~~~~~~~~~~ creating tensorboard variables ~~~~~~~~~~~~~~~~~~~ #
41
49
42
- writer_fake = SummaryWriter ("logs /fake" )
43
- writer_real = SummaryWriter ("logs /real" )
50
+ writer_fake = SummaryWriter (f" { str ( log_dir ) } /fake" )
51
+ writer_real = SummaryWriter (f" { str ( log_dir ) } /real" )
44
52
45
53
# ~~~~~~~~~~~~~~~~~~~ loading the model ~~~~~~~~~~~~~~~~~~~ #
46
54
@@ -142,15 +150,16 @@ def training(opt):
142
150
)
143
151
144
152
# ~~~~~~~~~~~~~~~~~~~ saving the weights ~~~~~~~~~~~~~~~~~~~ #
145
- if D_loss_prev > D_loss :
146
- D_loss_prev = D_loss
147
- weight_path = str (Weight_dir / 'dirscriminator.pth' )
148
- torch .save (disc .state_dict (), weight_path )
153
+ if opt .weights :
154
+ if D_loss_prev > D_loss :
155
+ D_loss_prev = D_loss
156
+ weight_path = str (Weight_dir / 'dirscriminator.pth' )
157
+ torch .save (disc .state_dict (), weight_path )
149
158
150
- if G_loss_prev > G_loss :
151
- G_loss_prev = G_loss
152
- weight_path = str (Weight_dir / 'generator.pth' )
153
- torch .save (gen .state_dict (), weight_path )
159
+ if G_loss_prev > G_loss :
160
+ G_loss_prev = G_loss
161
+ weight_path = str (Weight_dir / 'generator.pth' )
162
+ torch .save (gen .state_dict (), weight_path )
154
163
155
164
156
165
if __name__ == "__main__" :
@@ -159,6 +168,8 @@ def training(opt):
159
168
# initial pre training weights
160
169
parser .add_argument ('--weights' , type = str ,
161
170
default = '' , help = 'save and load location of weights' )
171
+ parser .add_argument ('--logs' , type = str ,
172
+ default = '' , help = 'save log files to' )
162
173
parser .add_argument ("epochs" , type = int , default = 20 ,
163
174
help = 'number of epochs to train' )
164
175
parser .add_argument ('--batch-size' , type = int , default = 128 ,
0 commit comments