Skip to content

Commit c1a2b19

Browse files
StephanieLarocquenotoraptor
authored andcommitted
accuracy import fix
1 parent bbdf0e5 commit c1a2b19

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

code/cnn_1D_segm/train_fcn1D.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from theano import config
2121
import lasagne
2222
from lasagne.regularization import regularize_network_params
23-
from lasagne.objectives import categorical_crossentropy, accuracy
23+
from lasagne.objectives import categorical_crossentropy
2424

2525
import PIL.Image as Image
2626
from matplotlib import pyplot as plt
@@ -32,6 +32,32 @@
3232

3333

3434

35+
def accuracy_metric(y_pred, y_true, void_labels, one_hot=False):
36+
37+
assert (y_pred.ndim == 2) or (y_pred.ndim == 1)
38+
39+
# y_pred to indices
40+
if y_pred.ndim == 2:
41+
y_pred = T.argmax(y_pred, axis=1)
42+
43+
if one_hot:
44+
y_true = T.argmax(y_true, axis=1)
45+
46+
# Compute accuracy
47+
acc = T.eq(y_pred, y_true).astype(_FLOATX)
48+
49+
# Create mask
50+
mask = T.ones_like(y_true, dtype=_FLOATX)
51+
for el in void_labels:
52+
indices = T.eq(y_true, el).nonzero()
53+
if any(indices):
54+
mask = T.set_subtensor(mask[indices], 0.)
55+
56+
# Apply mask
57+
acc *= mask
58+
acc = T.sum(acc) / T.sum(mask)
59+
60+
return acc
3561

3662
# In[2]:
3763

@@ -157,7 +183,7 @@ def jaccard(y_pred, y_true, n_classes, one_hot=False):
157183
which_set='train',
158184
smooth_or_raw = smooth_or_raw,
159185
batch_size=batch_size[0],
160-
data_augm_kwargs=data_augmentation,
186+
data_augm_kwargs={},
161187
shuffle_at_each_epoch = True,
162188
return_one_hot=False,
163189
return_01c=False,
@@ -211,7 +237,7 @@ def jaccard(y_pred, y_true, n_classes, one_hot=False):
211237
simple_net_output, lasagne.regularization.l2)
212238
loss += weight_decay * weightsl2
213239

214-
train_acc = accuracy(prediction, target_var, void_labels)
240+
train_acc = accuracy_metric(prediction, target_var, void_labels)
215241

216242
params = lasagne.layers.get_all_params(simple_net_output, trainable=True)
217243
updates = lasagne.updates.adam(loss, params, learning_rate=learn_step)
@@ -224,7 +250,7 @@ def jaccard(y_pred, y_true, n_classes, one_hot=False):
224250
print "Defining and compiling valid functions"
225251
valid_prediction = lasagne.layers.get_output(simple_net_output[0],deterministic=True)
226252
valid_loss = categorical_crossentropy(valid_prediction, target_var).mean()
227-
valid_acc = accuracy(valid_prediction, target_var, void_labels)
253+
valid_acc = accuracy_metric(valid_prediction, target_var, void_labels)
228254
valid_jacc = jaccard(valid_prediction, target_var, n_classes)
229255

230256
valid_fn = theano.function([input_var, target_var], [valid_loss, valid_acc,valid_jacc])

0 commit comments

Comments
 (0)