Skip to content

Commit 3f8dd64

Browse files
StephanieLarocquenotoraptor
authored andcommitted
fix input dim and details
1 parent 70495d8 commit 3f8dd64

File tree

1 file changed

+39
-21
lines changed

1 file changed

+39
-21
lines changed

code/unet/train_unet.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def jaccard_metric(y_pred, y_true, n_classes, one_hot=False):
3434
y_true = T.argmax(y_true, axis=1)
3535

3636
# Compute confusion matrix
37+
# cm = T.nnet.confusion_matrix(y_pred, y_true)
3738
cm = T.zeros((n_classes, n_classes))
3839
for i in range(n_classes):
3940
for j in range(n_classes):
@@ -156,12 +157,16 @@ def train(dataset, learn_step=0.005,
156157
if batch_size is not None:
157158
bs = batch_size
158159
else:
159-
bs = [10, 1, 1]
160+
bs = [1, 1, 1]
160161

161162
train_iter, val_iter, test_iter = \
162163
load_data(dataset, data_augmentation,
163164
one_hot=False, batch_size=bs, return_0_255=train_from_0_255)
164165

166+
batch = train_iter.next()
167+
input_dim = (np.shape(batch[0])[2], np.shape(batch[0])[3])
168+
print 'batch size ', np.shape(batch[0])
169+
165170
n_batches_train = train_iter.nbatches
166171
n_batches_val = val_iter.nbatches
167172
n_batches_test = test_iter.nbatches if test_iter is not None else 0
@@ -176,29 +181,30 @@ def train(dataset, learn_step=0.005,
176181
#
177182
# Build network
178183
#
179-
convmodel = build_UNet(nb_in_channels, input_var, n_classes=n_classes,
180-
void_labels=void_labels, trainable=True,
181-
load_weights=resume, pascal=True, layer=['probs'])
182184

185+
net = build_UNet(n_input_channels= nb_in_channels,# BATCH_SIZE = batch_size,
186+
num_output_classes = n_classes, base_n_filters = 64, do_dropout=False,
187+
input_dim =input_dim) #(512,512))
188+
189+
output_layer = net["output_flattened"]
183190
#
184191
# Define and compile theano functions
185192
#
186193
print "Defining and compiling training functions"
187-
prediction = lasagne.layers.get_output(convmodel)[0]
194+
prediction = lasagne.layers.get_output(output_layer, input_var)
188195
loss = crossentropy_metric(prediction, target_var, void_labels)
189196

190197
if weight_decay > 0:
191-
weightsl2 = regularize_network_params(
192-
convmodel, lasagne.regularization.l2)
198+
weightsl2 = regularize_network_params(output_layer, lasagne.regularization.l2)
193199
loss += weight_decay * weightsl2
194200

195-
params = lasagne.layers.get_all_params(convmodel, trainable=True)
201+
params = lasagne.layers.get_all_params(output_layer, trainable=True)
196202
updates = lasagne.updates.adam(loss, params, learning_rate=learn_step)
197203

198204
train_fn = theano.function([input_var, target_var], loss, updates=updates)
199205

200206
print "Defining and compiling test functions"
201-
test_prediction = lasagne.layers.get_output(convmodel, deterministic=True)[0]
207+
test_prediction = lasagne.layers.get_output(output_layer, input_var,deterministic=True)
202208
test_loss = crossentropy_metric(test_prediction, target_var, void_labels)
203209
test_acc = accuracy_metric(test_prediction, target_var, void_labels)
204210
test_jacc = jaccard_metric(test_prediction, target_var, n_classes)
@@ -221,8 +227,12 @@ def train(dataset, learn_step=0.005,
221227
start_time = time.time()
222228
cost_train_tot = 0
223229

230+
n_batches_train = 2
231+
n_batches_val = 2
224232
# Train
233+
print 'Training steps '
225234
for i in range(n_batches_train):
235+
print i
226236
# Get minibatch
227237
X_train_batch, L_train_batch = train_iter.next()
228238
L_train_batch = np.reshape(L_train_batch, np.prod(L_train_batch.shape))
@@ -238,7 +248,10 @@ def train(dataset, learn_step=0.005,
238248
cost_val_tot = 0
239249
acc_val_tot = 0
240250
jacc_val_tot = np.zeros((2, n_classes))
251+
252+
print 'Validation steps'
241253
for i in range(n_batches_val):
254+
print i
242255
# Get minibatch
243256
X_val_batch, L_val_batch = val_iter.next()
244257
L_val_batch = np.reshape(L_val_batch, np.prod(L_val_batch.shape))
@@ -277,12 +290,12 @@ def train(dataset, learn_step=0.005,
277290
elif epoch > 1 and jacc_valid[epoch] > best_jacc_val:
278291
best_jacc_val = jacc_valid[epoch]
279292
patience = 0
280-
np.savez(os.path.join(savepath, 'new_unet_model_best.npz'), *lasagne.layers.get_all_param_values(convmodel))
293+
np.savez(os.path.join(savepath, 'new_unet_model_best.npz'), *lasagne.layers.get_all_param_values(output_layer))
281294
np.savez(os.path.join(savepath + "unet_errors_best.npz"),
282295
err_valid, err_train, acc_valid, jacc_valid)
283296
else:
284297
patience += 1
285-
np.savez(os.path.join(savepath, 'new_unet_model_last.npz'), *lasagne.layers.get_all_param_values(convmodel))
298+
np.savez(os.path.join(savepath, 'new_unet_model_last.npz'), *lasagne.layers.get_all_param_values(output_layer))
286299
np.savez(os.path.join(savepath + "unet_errors_last.npz"),
287300
err_valid, err_train, acc_valid, jacc_valid)
288301
# Finish training if patience has expired or max nber of epochs
@@ -292,8 +305,8 @@ def train(dataset, learn_step=0.005,
292305
# Load best model weights
293306
with np.load(os.path.join(savepath, 'new_unet_model_best.npz')) as f:
294307
param_values = [f['arr_%d' % i] for i in range(len(f.files))]
295-
nlayers = len(lasagne.layers.get_all_params(convmodel))
296-
lasagne.layers.set_all_param_values(convmodel, param_values[:nlayers])
308+
nlayers = len(lasagne.layers.get_all_params(output_layer))
309+
lasagne.layers.set_all_param_values(output_layer, param_values[:nlayers])
297310
# Test
298311
cost_test_tot = 0
299312
acc_test_tot = 0
@@ -318,13 +331,11 @@ def train(dataset, learn_step=0.005,
318331
jacc_test = np.mean(jacc_num_test_tot / jacc_denom_test_tot)
319332

320333
out_str = "FINAL MODEL: err test % f, acc test %f, jacc test %f"
321-
out_str = out_str % (err_test,
322-
acc_test,
323-
jacc_test)
334+
out_str = out_str % (err_test, acc_test, jacc_test)
324335
print out_str
325-
# if savepath != loadpath:
326-
# print('Copying model and other training files to {}'.format(loadpath))
327-
# copy_tree(savepath, loadpath)
336+
if savepath != loadpath:
337+
print('Copying model and other training files to {}'.format(loadpath))
338+
copy_tree(savepath, loadpath)
328339

329340
# End
330341
return
@@ -353,11 +364,18 @@ def main():
353364
help='Max patience')
354365
parser.add_argument('-batch_size',
355366
type=int,
356-
default=[10, 1, 1],
367+
default=[1, 1, 1],
357368
help='Batch size [train, val, test]')
358369
parser.add_argument('-data_augmentation',
359370
type=dict,
360-
default={'crop_size': (224, 224), 'horizontal_flip': True, 'fill_mode':'constant'},
371+
default={'rotation_range':25,
372+
'shear_range':0.41,
373+
'horizontal_flip':True,
374+
'vertical_flip':True,
375+
'fill_mode':'reflect',
376+
'spline_warp':True,
377+
'warp_sigma':10,
378+
'warp_grid_size':3},
361379
help='use data augmentation')
362380
parser.add_argument('-early_stop_class',
363381
type=int,

0 commit comments

Comments
 (0)