20
20
from theano import config
21
21
import lasagne
22
22
from lasagne .regularization import regularize_network_params
23
- from lasagne .objectives import categorical_crossentropy , accuracy
23
+ from lasagne .objectives import categorical_crossentropy
24
24
25
25
import PIL .Image as Image
26
26
from matplotlib import pyplot as plt
32
32
33
33
34
34
35
+ def accuracy_metric (y_pred , y_true , void_labels , one_hot = False ):
36
+
37
+ assert (y_pred .ndim == 2 ) or (y_pred .ndim == 1 )
38
+
39
+ # y_pred to indices
40
+ if y_pred .ndim == 2 :
41
+ y_pred = T .argmax (y_pred , axis = 1 )
42
+
43
+ if one_hot :
44
+ y_true = T .argmax (y_true , axis = 1 )
45
+
46
+ # Compute accuracy
47
+ acc = T .eq (y_pred , y_true ).astype (_FLOATX )
48
+
49
+ # Create mask
50
+ mask = T .ones_like (y_true , dtype = _FLOATX )
51
+ for el in void_labels :
52
+ indices = T .eq (y_true , el ).nonzero ()
53
+ if any (indices ):
54
+ mask = T .set_subtensor (mask [indices ], 0. )
55
+
56
+ # Apply mask
57
+ acc *= mask
58
+ acc = T .sum (acc ) / T .sum (mask )
59
+
60
+ return acc
35
61
36
62
# In[2]:
37
63
@@ -157,7 +183,7 @@ def jaccard(y_pred, y_true, n_classes, one_hot=False):
157
183
which_set = 'train' ,
158
184
smooth_or_raw = smooth_or_raw ,
159
185
batch_size = batch_size [0 ],
160
- data_augm_kwargs = data_augmentation ,
186
+ data_augm_kwargs = {} ,
161
187
shuffle_at_each_epoch = True ,
162
188
return_one_hot = False ,
163
189
return_01c = False ,
@@ -211,7 +237,7 @@ def jaccard(y_pred, y_true, n_classes, one_hot=False):
211
237
simple_net_output , lasagne .regularization .l2 )
212
238
loss += weight_decay * weightsl2
213
239
214
- train_acc = accuracy (prediction , target_var , void_labels )
240
+ train_acc = accuracy_metric (prediction , target_var , void_labels )
215
241
216
242
params = lasagne .layers .get_all_params (simple_net_output , trainable = True )
217
243
updates = lasagne .updates .adam (loss , params , learning_rate = learn_step )
@@ -224,7 +250,7 @@ def jaccard(y_pred, y_true, n_classes, one_hot=False):
224
250
print "Defining and compiling valid functions"
225
251
valid_prediction = lasagne .layers .get_output (simple_net_output [0 ],deterministic = True )
226
252
valid_loss = categorical_crossentropy (valid_prediction , target_var ).mean ()
227
- valid_acc = accuracy (valid_prediction , target_var , void_labels )
253
+ valid_acc = accuracy_metric (valid_prediction , target_var , void_labels )
228
254
valid_jacc = jaccard (valid_prediction , target_var , n_classes )
229
255
230
256
valid_fn = theano .function ([input_var , target_var ], [valid_loss , valid_acc ,valid_jacc ])
0 commit comments