1
1
import numpy as np
2
2
import matplotlib .pyplot as plt
3
3
4
+ import pdb
5
+
4
6
5
7
class Comparitor (object ):
6
8
7
- def __init__ (self , ref_sample , test_sample , window_width ):
9
+ def __init__ (self , ref_sample , test_sample , window_width , signal = None ):
8
10
"""
9
11
Parameters
10
12
----------
11
13
ref_sample : numpy array
12
14
An array of the reference sample locations
13
15
test_sample : numpy array
14
16
An array of the comparison sample locations
17
+ window_width : int
18
+ The width of the window
15
19
"""
16
20
if min (np .diff (ref_sample )) < 0 or min (np .diff (test_sample )) < 0 :
17
- raise ValueError ((" The sample locations must be monotonically"
18
- + " increasing" ))
21
+ raise ValueError ((' The sample locations must be monotonically'
22
+ + ' increasing' ))
19
23
20
24
self .ref_sample = ref_sample
21
25
self .test_sample = test_sample
22
26
self .n_ref = len (ref_sample )
23
27
self .n_test = len (test_sample )
24
28
self .window_width = window_width
25
29
26
- # The matching test sample numbers. -1 for indices with no match
27
- self .matching_sample_nums = - 1 * np .ones (self .n_ref )
30
+ # The matching test sample number for each reference annotation.
31
+ # -1 for indices with no match
32
+ self .matching_sample_nums = - 1 * np .ones (self .n_ref , dtype = 'int' )
28
33
34
+ self .signal = signal
29
35
# TODO: rdann return annotations.where
30
36
31
37
def calc_stats (self ):
@@ -38,7 +44,7 @@ def calc_stats(self):
38
44
ref=500 test=480
39
45
{ 30 { 470 } 10 }
40
46
-------------------
41
-
47
+
42
48
tp = 470
43
49
fp = 10
44
50
fn = 30
@@ -48,15 +54,25 @@ def calc_stats(self):
48
54
false_positive_rate = 10 / 480
49
55
50
56
"""
51
- self .detected_ref_inds = np .where (self .matching_sample_nums != - 1 )
52
- self .missed_ref_inds = np .where (self .matching_sample_nums == - 1 )
53
- self .matched_test_inds = self .matching_sample_nums (
54
- self .matching_sample_nums != - 1 )
55
- self .unmached_test_inds = np .setdiff1d (np .array (range (self .n_test )),
57
+ # Reference annotation indices that were detected
58
+ self .matched_ref_inds = np .where (self .matching_sample_nums != - 1 )[0 ]
59
+ # Reference annotation indices that were missed
60
+ self .unmatched_ref_inds = np .where (self .matching_sample_nums == - 1 )[0 ]
61
+ # Test annotation indices that were matched to a reference annotation
62
+ self .matched_test_inds = self .matching_sample_nums [
63
+ self .matching_sample_nums != - 1 ]
64
+ # Test annotation indices that were unmatched to a reference annotation
65
+ self .unmatched_test_inds = np .setdiff1d (np .array (range (self .n_test )),
56
66
self .matched_test_inds , assume_unique = True )
67
+
68
+ # Sample numbers that were matched and unmatched
69
+ self .matched_ref_sample = self .ref_sample [self .matched_ref_inds ]
70
+ self .unmatched_ref_sample = self .ref_sample [self .unmatched_ref_inds ]
71
+ self .matched_test_sample = self .test_sample [self .matched_test_inds ]
72
+ self .unmatched_test_sample = self .test_sample [self .unmatched_test_inds ]
57
73
58
74
# True positives = matched reference samples
59
- self .tp = len (detected_ref_inds )
75
+ self .tp = len (self . matched_ref_inds )
60
76
# False positives = extra test samples not matched
61
77
self .fp = self .n_test - self .tp
62
78
# False negatives = undetected reference samples
@@ -69,7 +85,9 @@ def calc_stats(self):
69
85
70
86
71
87
def compare (self ):
72
-
88
+ """
89
+ Main comparison function
90
+ """
73
91
test_samp_num = 0
74
92
ref_samp_num = 0
75
93
@@ -95,7 +113,6 @@ def compare(self):
95
113
self .get_closest_samp_num (ref_samp_num , test_samp_num ,
96
114
closest_samp_num ))
97
115
98
-
99
116
self .matching_sample_nums [ref_samp_num ] = closest_samp_num
100
117
101
118
# If no clash, it is straightforward.
@@ -146,6 +163,83 @@ def get_closest_samp_num(self, ref_samp_num, start_test_samp_num,
146
163
147
164
return closest_samp_num , smallest_samp_diff
148
165
166
+ def print_summary (self ):
167
+ # True positives = matched reference samples
168
+ self .tp = len (self .matched_ref_inds )
169
+ # False positives = extra test samples not matched
170
+ self .fp = self .n_test - self .tp
171
+ # False negatives = undetected reference samples
172
+ self .fn = self .n_ref - self .tp
173
+ # No tn attribute
174
+
175
+ self .specificity = self .tp / self .n_ref
176
+ self .positive_predictivity = self .tp / self .n_test
177
+ self .false_positive_rate = self .fp / self .n_test
178
+
179
+ print ('%d reference annotations, %d test annotations\n '
180
+ % (self .n_ref , self .n_test ))
181
+ print ('True Positives (matched samples): %d' % self .tp )
182
+ print ('False Positives (unmatched test samples: %d' % self .fp )
183
+ print ('False Negatives (unmatched reference samples): %d\n ' % self .fn )
184
+
185
+ print ('Specificity: %.4f (%d/%d)'
186
+ % (self .specificity , self .tp , self .n_ref ))
187
+ print ('Positive Predictivity: %.4f (%d/%d)'
188
+ % (self .positive_predictivity , self .tp , self .n_test ))
189
+ print ('False Positive Rate: %.4f (%d/%d)'
190
+ % (self .false_positive_rate , self .fp , self .n_test ))
191
+
192
+
193
+ def plot (self , signal = None , sig_style = '' , title = None , figsize = None ,
194
+ return_fig = False ):
195
+ """
196
+ Plot results of two sets of annotations
197
+ """
198
+ if signal is not None :
199
+ self .signal = signal
200
+
201
+ fig = plt .figure (figsize = figsize )
202
+ ax = fig .add_subplot (1 , 1 , 1 )
203
+
204
+ legend = ['Signal' ,
205
+ 'Matched Reference Annotations (%d/%d)' % (self .tp , self .n_ref ),
206
+ 'Unmatched Reference Annotations (%d/%d)' % (self .fn , self .n_ref ),
207
+ 'Matched Test Annotations (%d/%d)' % (self .tp , self .n_test ),
208
+ 'Unmatched Test Annotations (%d/%d)' % (self .fp , self .n_test )
209
+ ]
210
+
211
+ # Plot the signal if any
212
+ if self .signal is not None :
213
+ ax .plot (self .signal , sig_style )
214
+
215
+ # Plot reference annotations
216
+ ax .plot (self .matched_ref_sample ,
217
+ self .signal [self .matched_ref_sample ], 'ko' )
218
+ ax .plot (self .unmatched_ref_sample ,
219
+ self .signal [self .unmatched_ref_sample ], 'ko' , fillstyle = 'none' )
220
+ # Plot test annotations
221
+ ax .plot (self .matched_test_sample ,
222
+ self .signal [self .matched_test_sample ], 'g+' )
223
+ ax .plot (self .unmatched_test_sample ,
224
+ self .signal [self .unmatched_test_sample ], 'rx' )
225
+
226
+ ax .legend (legend )
227
+
228
+ # Just plot annotations
229
+ else :
230
+ # Plot reference annotations
231
+ ax .plot (self .matched_ref_sample , np .ones (self .tp ), 'ko' )
232
+ ax .plot (self .unmatched_ref_sample , np .ones (self .fn ), 'ko' , fillstyle = 'none' )
233
+ # Plot test annotations
234
+ ax .plot (self .matched_test_sample , 0.5 * np .ones (self .tp ), 'g+' )
235
+ ax .plot (self .unmatched_test_sample , 0.5 * np .ones (self .fp ), 'rx' )
236
+ ax .legend (legend [1 :])
237
+
238
+ fig .show ()
239
+
240
+ if return_fig :
241
+ return fig , ax
242
+
149
243
150
244
def compare_annotations (ref_sample , test_sample , window_width ):
151
245
"""
@@ -167,7 +261,7 @@ def compare_annotations(ref_sample, test_sample, window_width):
167
261
# def plot_record(record=None, title=None, annotation=None, time_units='samples',
168
262
# sig_style='', ann_style='r*', plot_ann_sym=False, figsize=None,
169
263
# return_fig=False, ecg_grids=[]):
170
- def plot_comparitor (comparitor , sig = None , sig_style = '' , title = None , figsize = None , return_fig = False ):
264
+ def plot_comparitor (comparitor , ):
171
265
"""
172
266
Plot two sets of annotations
173
267
0 commit comments