Skip to content

Commit 44ff025

Browse files
StephanieLarocquenotoraptor
authored andcommitted
saving stuff
1 parent e3a4ebf commit 44ff025

File tree

3 files changed

+21
-25
lines changed

3 files changed

+21
-25
lines changed

code/cnn_1D_segm/train_fcn1D.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -337,24 +337,18 @@ def jaccard(y_pred, y_true, n_classes, one_hot=False):
337337
print('saving best (and last) model')
338338
best_jacc_val = jacc_valid[epoch]
339339
patience = 0
340-
np.savez(os.path.join(savepath, 'new_fcn1D_model_best.npz'),
341-
*lasagne.layers.get_all_param_values(simple_net_output))
340+
np.savez(os.path.join(savepath, 'new_fcn1D_model_best.npz'), *lasagne.layers.get_all_param_values(simple_net_output))
342341
np.savez(os.path.join(savepath , "fcn1D_errors_best.npz"),
343342
err_train=err_train, acc_train=acc_train,
344343
err_valid=err_valid, acc_valid=acc_valid, jacc_valid=jacc_valid)
345-
np.savez(os.path.join(savepath, 'new_fcn1D_model_last.npz'),
346-
*lasagne.layers.get_all_param_values(simple_net_output))
347-
np.savez(os.path.join(savepath , "fcn1D_errors_last.npz"),
348-
err_train=err_train, acc_train=acc_train,
349-
err_valid=err_valid, acc_valid=acc_valid, jacc_valid=jacc_valid)
350344
else:
351345
patience += 1
352346
print('saving last model')
353-
np.savez(os.path.join(savepath, 'new_fcn1D_model_last.npz'),
354-
*lasagne.layers.get_all_param_values(simple_net_output))
355-
np.savez(os.path.join(savepath , "fcn1D_errors_last.npz"),
356-
err_train=err_train, acc_train=acc_train,
357-
err_valid=err_valid, acc_valid=acc_valid, jacc_valid=jacc_valid)
347+
348+
np.savez(os.path.join(savepath, 'new_fcn1D_model_last.npz'), *lasagne.layers.get_all_param_values(simple_net_output))
349+
np.savez(os.path.join(savepath , "fcn1D_errors_last.npz"),
350+
err_train=err_train, acc_train=acc_train,
351+
err_valid=err_valid, acc_valid=acc_valid, jacc_valid=jacc_valid)
358352
# Finish training if patience has expired or max nber of epochs reached
359353

360354
if patience == max_patience or epoch == num_epochs-1:

code/fcn_2D_segm/train_fcn8.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,10 @@ def train(dataset, learn_step=0.005,
213213
jacc_valid = []
214214
patience = 0
215215

216+
n_batches_train = 1
217+
n_batches_val = 1
218+
n_batches_test = 1
219+
num_epochs = 1
216220
# Training main loop
217221
print "Start training"
218222
for epoch in range(num_epochs):
@@ -222,10 +226,12 @@ def train(dataset, learn_step=0.005,
222226

223227
# Train
224228
for i in range(n_batches_train):
229+
print 'Training batch ', i
225230
# Get minibatch
226231
X_train_batch, L_train_batch = train_iter.next()
227232
L_train_batch = np.reshape(L_train_batch, np.prod(L_train_batch.shape))
228233

234+
229235
# Training step
230236
cost_train = train_fn(X_train_batch, L_train_batch)
231237
out_str = "cost %f" % (cost_train)
@@ -238,6 +244,7 @@ def train(dataset, learn_step=0.005,
238244
acc_val_tot = 0
239245
jacc_val_tot = np.zeros((2, n_classes))
240246
for i in range(n_batches_val):
247+
print 'Valid batch ', i
241248
# Get minibatch
242249
X_val_batch, L_val_batch = val_iter.next()
243250
L_val_batch = np.reshape(L_val_batch, np.prod(L_val_batch.shape))
@@ -277,13 +284,12 @@ def train(dataset, learn_step=0.005,
277284
best_jacc_val = jacc_valid[epoch]
278285
patience = 0
279286
np.savez(os.path.join(savepath, 'new_fcn8_model_best.npz'), *lasagne.layers.get_all_param_values(convmodel))
280-
np.savez(os.path.join(savepath + "fcn8_errors_best.npz"),
281-
err_valid, err_train, acc_valid, jacc_valid)
287+
np.savez(os.path.join(savepath, "fcn8_errors_best.npz"), err_valid, err_train, acc_valid, jacc_valid)
282288
else:
283289
patience += 1
284-
np.savez(os.path.join(savepath, 'new_fcn8_model_last.npz'), *lasagne.layers.get_all_param_values(convmodel))
285-
np.savez(os.path.join(savepath + "fcn8_errors_last.npz"),
286-
err_valid, err_train, acc_valid, jacc_valid)
290+
291+
np.savez(os.path.join(savepath, 'new_fcn8_model_last.npz'), *lasagne.layers.get_all_param_values(convmodel))
292+
np.savez(os.path.join(savepath, "fcn8_errors_last.npz"), err_valid, err_train, acc_valid, jacc_valid)
287293
# Finish training if patience has expired or max nber of epochs
288294
# reached
289295
if patience == max_patience or epoch == num_epochs-1:

code/unet/train_unet.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,6 @@ def train(dataset, learn_step=0.005,
226226
# Single epoch training and validation
227227
start_time = time.time()
228228
cost_train_tot = 0
229-
230-
n_batches_train = 2
231-
n_batches_val = 2
232229
# Train
233230
print 'Training steps '
234231
for i in range(n_batches_train):
@@ -291,13 +288,12 @@ def train(dataset, learn_step=0.005,
291288
best_jacc_val = jacc_valid[epoch]
292289
patience = 0
293290
np.savez(os.path.join(savepath, 'new_unet_model_best.npz'), *lasagne.layers.get_all_param_values(output_layer))
294-
np.savez(os.path.join(savepath + "unet_errors_best.npz"),
295-
err_valid, err_train, acc_valid, jacc_valid)
291+
np.savez(os.path.join(savepath, 'unet_errors_best.npz'), err_valid, err_train, acc_valid, jacc_valid)
296292
else:
297293
patience += 1
298-
np.savez(os.path.join(savepath, 'new_unet_model_last.npz'), *lasagne.layers.get_all_param_values(output_layer))
299-
np.savez(os.path.join(savepath + "unet_errors_last.npz"),
300-
err_valid, err_train, acc_valid, jacc_valid)
294+
295+
np.savez(os.path.join(savepath, 'new_unet_model_last.npz'), *lasagne.layers.get_all_param_values(output_layer))
296+
np.savez(os.path.join(savepath, 'unet_errors_last.npz'), err_valid, err_train, acc_valid, jacc_valid)
301297
# Finish training if patience has expired or max nber of epochs
302298
# reached
303299
if patience == max_patience or epoch == num_epochs-1:

0 commit comments

Comments
 (0)