1
1
import numpy as np
2
-
3
- class Comparitor (object ):
2
+ import matplotlib .pyplot as plt
4
3
5
4
5
+ class Comparitor (object ):
6
6
7
7
def __init__ (self , ref_sample , test_sample , window_width ):
8
8
"""
@@ -19,65 +19,103 @@ def __init__(self, ref_sample, test_sample, window_width):
19
19
20
20
self .ref_sample = ref_sample
21
21
self .test_sample = test_sample
22
-
23
- self .fp = 0
24
- self .tp = 0
25
-
26
- self .n_missed = 0
27
- self .n_detected = 0
28
-
29
- # How many there are
30
22
self .n_ref = len (ref_sample )
31
23
self .n_comp = len (test_sample )
32
24
33
-
34
- # # Just derive these 4 at the end?
35
- # # Index info about the reference samples
36
- # self.detected_inds = []
37
- # self.missed_inds = []
38
- # # About the testing samples
39
- # self.correct_test_inds = []
40
- # self.wrong_test_inds = []
41
-
42
-
43
25
# The matching test sample numbers. -1 for indices with no match
44
26
self .matching_sample_nums = - 1 * np .ones (n_ref )
45
27
46
28
# TODO: rdann return annotations.where
47
29
48
- def compare (self ):
30
def calc_stats(self):
    """
    Calculate performance statistics after the two sets of annotations
    are compared.

    Reads ``self.matching_sample_nums`` (one entry per reference sample:
    the matched test sample number, or -1 when there is no match),
    ``self.n_ref`` and ``self.test_sample``. The results are stored as
    attributes on the object; nothing is returned.

    Example:
    -------------------
    ref=500  test=480
    {  30 { 470 } 10  }
    -------------------

    tp = 470
    fp = 10
    fn = 30

    sensitivity = 470 / 500
    positive_predictivity = 470 / 480
    false_positive_rate = 10 / 480

    Attributes set
    --------------
    detected_ref_inds, missed_ref_inds : numpy arrays
        Indices of the reference samples with / without a match.
    matched_test_inds, unmatched_test_inds : numpy arrays
        Test sample numbers that were / were not matched.
        ``unmached_test_inds`` is kept as an alias of the latter for
        backward compatibility with the original spelling.
    tp, fp, fn : int
        Matched reference samples, unmatched test samples and
        unmatched reference samples. There is no ``tn`` attribute.
    sensitivity, positive_predictivity, false_positive_rate : float
        tp / n_ref, tp / n_test and fp / n_test. ``specificity`` is
        kept as an alias of ``sensitivity`` because that is the name
        this value was originally stored under. A ratio whose
        denominator is zero is stored as NaN.

    """
    def _ratio(numerator, denominator):
        # Force true division (also under Python 2) and avoid raising
        # when one of the annotation sets is empty.
        if not denominator:
            return float('nan')
        return float(numerator) / denominator

    matches = np.asarray(self.matching_sample_nums)
    is_matched = matches != -1

    # np.where returns a tuple holding one index array per dimension;
    # keep the index array itself so that len() counts the samples.
    self.detected_ref_inds = np.where(is_matched)[0]
    self.missed_ref_inds = np.where(~is_matched)[0]

    # The match array may hold floats (e.g. when built from np.ones), so
    # cast the matched sample numbers to int to make them usable as
    # indices.
    self.matched_test_inds = matches[is_matched].astype(int)

    # Derive the number of test samples from the samples themselves: the
    # visible part of the constructor stores it as n_comp, not n_test.
    n_test = len(self.test_sample)
    self.unmatched_test_inds = np.setdiff1d(np.arange(n_test),
                                            self.matched_test_inds)
    # Original (misspelled) attribute name, kept for existing readers.
    self.unmached_test_inds = self.unmatched_test_inds

    # True positives = matched reference samples
    self.tp = len(self.detected_ref_inds)
    # False positives = extra test samples not matched
    self.fp = n_test - self.tp
    # False negatives = undetected reference samples
    self.fn = self.n_ref - self.tp
    # No tn attribute

    self.sensitivity = _ratio(self.tp, self.n_ref)
    # tp / n_ref is the sensitivity; the old attribute name is retained.
    self.specificity = self.sensitivity
    self.positive_predictivity = _ratio(self.tp, n_test)
    self.false_positive_rate = _ratio(self.fp, n_test)
49
68
50
69
70
def compare(self):
    """
    Pair each reference sample with at most one test sample, then
    compute the summary statistics.

    The reference samples are visited in order. For each one, the
    closest test sample at or after the current search position is
    looked up with ``get_closest_samp_num``. If the next reference
    sample has the same closest test sample and is strictly closer to
    it, the test sample is left for that next reference sample, and the
    current one falls back to the closest of the earlier test samples
    that are still available, if any. A pair is recorded in
    ``self.matching_sample_nums`` only when the sample difference is
    smaller than ``self.window_width``. ``calc_stats`` is called once
    the walk is finished.

    NOTE(review): the visible part of ``__init__`` does not show
    ``window_width`` being stored and stores the test count as
    ``n_comp``, while ``get_closest_samp_num`` reads ``self.n_test`` -
    confirm that the constructor sets both.

    """
    n_ref = self.n_ref
    n_test = len(self.test_sample)

    test_samp_num = 0
    ref_samp_num = 0

    while ref_samp_num < n_ref and test_samp_num < n_test:

        closest_samp_num, smallest_samp_diff = (
            self.get_closest_samp_num(ref_samp_num, test_samp_num,
                                      n_test))

        # The look-ahead only makes sense when a next reference sample
        # exists; the last reference sample has no competitor.
        yield_to_next = False
        if ref_samp_num + 1 < n_ref:
            closest_samp_num_next, smallest_samp_diff_next = (
                self.get_closest_samp_num(ref_samp_num + 1,
                                          test_samp_num, n_test))
            # Contested test sample that is closer to the next
            # reference sample: leave it for that one.
            yield_to_next = (
                closest_samp_num == closest_samp_num_next
                and smallest_samp_diff_next < smallest_samp_diff)

        if yield_to_next:
            if closest_samp_num > test_samp_num:
                # Fall back to the closest of the test samples that
                # come before the contested one.
                closest_samp_num, smallest_samp_diff = (
                    self.get_closest_samp_num(ref_samp_num,
                                              test_samp_num,
                                              closest_samp_num))
            else:
                # Nothing left before the contested sample: this
                # reference sample stays unmatched, and the search
                # position is kept so that the next reference sample
                # can still claim the contested test sample.
                ref_samp_num += 1
                continue

        # Assign the reference-test pair if close enough
        if smallest_samp_diff < self.window_width:
            self.matching_sample_nums[ref_samp_num] = closest_samp_num

        ref_samp_num += 1
        test_samp_num = closest_samp_num + 1

    self.calc_stats()
75
111
76
112
77
- def get_closest_samp_num (self , ref_samp_num , start_test_samp_num ):
113
+ def get_closest_samp_num (self , ref_samp_num , start_test_samp_num ,
114
+ stop_test_samp_num ):
78
115
"""
79
116
Return the closest testing sample number for the given reference
80
- sample number. Begin the search from start_test_samp_num.
117
+ sample number. Limit the search between start_test_samp_num and
118
+ stop_test_samp_num.
81
119
"""
82
120
83
121
if start_test_samp_num >= self .n_test :
@@ -92,7 +130,7 @@ def get_closest_samp_num(self, ref_samp_num, start_test_samp_num):
92
130
smallest_samp_diff = abs (samp_diff )
93
131
94
132
# Iterate through the testing samples
95
- for test_samp_num in range (start_test_samp_num , self . n_test ):
133
+ for test_samp_num in range (start_test_samp_num , stop_test_samp_num ):
96
134
test_samp = self .test_sample [test_samp_num ]
97
135
samp_diff = ref_samp - test_samp
98
136
abs_samp_diff = abs (samp_diff )
@@ -109,32 +147,22 @@ def get_closest_samp_num(self, ref_samp_num, start_test_samp_num):
109
147
return closest_samp_num , smallest_samp_diff
110
148
111
149
112
-
113
-
114
-
115
-
116
-
117
- def compare_annotations (ind_ref , ind_comp ):
150
def compare_annotations(ref_sample, test_sample, window_width):
    """
    Build a Comparitor for two sets of samples and run its comparison.

    Parameters
    ----------
    ref_sample : sequence
        The reference samples, passed unchanged to ``Comparitor``.
    test_sample : sequence
        The test samples, passed unchanged to ``Comparitor``.
    window_width : number
        The matching window, passed unchanged to ``Comparitor``.

    Returns
    -------
    comparitor : Comparitor object
        Object containing parameters about the two sets of annotations,
        returned after its ``compare`` method has been called.

    """
    result = Comparitor(ref_sample, test_sample, window_width)
    result.compare()
    return result
138
166
139
167
140
168
def plot_comparitor (comparitor , sig = None ):
0 commit comments