Skip to content

Commit 6637ab1

Browse files
committed
complete stats calculation and plotting for comparitors
1 parent ab5e3ee commit 6637ab1

File tree

2 files changed

+110
-16
lines changed

2 files changed

+110
-16
lines changed

wfdb/processing/evaluate.py

Lines changed: 109 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,37 @@
11
import numpy as np
22
import matplotlib.pyplot as plt
33

4+
import pdb
5+
46

57
class Comparitor(object):
68

7-
def __init__(self, ref_sample, test_sample, window_width):
9+
def __init__(self, ref_sample, test_sample, window_width, signal=None):
810
"""
911
Parameters
1012
----------
1113
ref_sample : numpy array
1214
An array of the reference sample locations
1315
test_sample : numpy array
1416
An array of the comparison sample locations
17+
window_width : int
18+
The width of the window
1519
"""
1620
if min(np.diff(ref_sample)) < 0 or min(np.diff(test_sample)) < 0:
17-
raise ValueError(("The sample locations must be monotonically"
18-
+ " increasing"))
21+
raise ValueError(('The sample locations must be monotonically'
22+
+ ' increasing'))
1923

2024
self.ref_sample = ref_sample
2125
self.test_sample = test_sample
2226
self.n_ref = len(ref_sample)
2327
self.n_test = len(test_sample)
2428
self.window_width = window_width
2529

26-
# The matching test sample numbers. -1 for indices with no match
27-
self.matching_sample_nums = -1 * np.ones(self.n_ref)
30+
# The matching test sample number for each reference annotation.
31+
# -1 for indices with no match
32+
self.matching_sample_nums = -1 * np.ones(self.n_ref, dtype='int')
2833

34+
self.signal = signal
2935
# TODO: rdann return annotations.where
3036

3137
def calc_stats(self):
@@ -38,7 +44,7 @@ def calc_stats(self):
3844
ref=500 test=480
3945
{ 30 { 470 } 10 }
4046
-------------------
41-
47+
4248
tp = 470
4349
fp = 10
4450
fn = 30
@@ -48,15 +54,25 @@ def calc_stats(self):
4854
false_positive_rate = 10 / 480
4955
5056
"""
51-
self.detected_ref_inds = np.where(self.matching_sample_nums != -1)
52-
self.missed_ref_inds = np.where(self.matching_sample_nums == -1)
53-
self.matched_test_inds = self.matching_sample_nums(
54-
self.matching_sample_nums != -1)
55-
self.unmached_test_inds = np.setdiff1d(np.array(range(self.n_test)),
57+
# Reference annotation indices that were detected
58+
self.matched_ref_inds = np.where(self.matching_sample_nums != -1)[0]
59+
# Reference annotation indices that were missed
60+
self.unmatched_ref_inds = np.where(self.matching_sample_nums == -1)[0]
61+
# Test annotation indices that were matched to a reference annotation
62+
self.matched_test_inds = self.matching_sample_nums[
63+
self.matching_sample_nums != -1]
64+
# Test annotation indices that were unmatched to a reference annotation
65+
self.unmatched_test_inds = np.setdiff1d(np.array(range(self.n_test)),
5666
self.matched_test_inds, assume_unique=True)
67+
68+
# Sample numbers that were matched and unmatched
69+
self.matched_ref_sample = self.ref_sample[self.matched_ref_inds]
70+
self.unmatched_ref_sample = self.ref_sample[self.unmatched_ref_inds]
71+
self.matched_test_sample = self.test_sample[self.matched_test_inds]
72+
self.unmatched_test_sample = self.test_sample[self.unmatched_test_inds]
5773

5874
# True positives = matched reference samples
59-
self.tp = len(detected_ref_inds)
75+
self.tp = len(self.matched_ref_inds)
6076
# False positives = extra test samples not matched
6177
self.fp = self.n_test - self.tp
6278
# False negatives = undetected reference samples
@@ -69,7 +85,9 @@ def calc_stats(self):
6985

7086

7187
def compare(self):
72-
88+
"""
89+
Main comparison function
90+
"""
7391
test_samp_num = 0
7492
ref_samp_num = 0
7593

@@ -95,7 +113,6 @@ def compare(self):
95113
self.get_closest_samp_num(ref_samp_num, test_samp_num,
96114
closest_samp_num))
97115

98-
99116
self.matching_sample_nums[ref_samp_num] = closest_samp_num
100117

101118
# If no clash, it is straightforward.
@@ -146,6 +163,83 @@ def get_closest_samp_num(self, ref_samp_num, start_test_samp_num,
146163

147164
return closest_samp_num, smallest_samp_diff
148165

166+
def print_summary(self):
167+
# True positives = matched reference samples
168+
self.tp = len(self.matched_ref_inds)
169+
# False positives = extra test samples not matched
170+
self.fp = self.n_test - self.tp
171+
# False negatives = undetected reference samples
172+
self.fn = self.n_ref - self.tp
173+
# No tn attribute
174+
175+
self.specificity = self.tp / self.n_ref
176+
self.positive_predictivity = self.tp / self.n_test
177+
self.false_positive_rate = self.fp / self.n_test
178+
179+
print('%d reference annotations, %d test annotations\n'
180+
% (self.n_ref, self.n_test))
181+
print('True Positives (matched samples): %d' % self.tp)
182+
print('False Positives (unmatched test samples: %d' % self.fp)
183+
print('False Negatives (unmatched reference samples): %d\n' % self.fn)
184+
185+
print('Specificity: %.4f (%d/%d)'
186+
% (self.specificity, self.tp, self.n_ref))
187+
print('Positive Predictivity: %.4f (%d/%d)'
188+
% (self.positive_predictivity, self.tp, self.n_test))
189+
print('False Positive Rate: %.4f (%d/%d)'
190+
% (self.false_positive_rate, self.fp, self.n_test))
191+
192+
193+
def plot(self, signal=None, sig_style='', title=None, figsize=None,
194+
return_fig=False):
195+
"""
196+
Plot results of two sets of annotations
197+
"""
198+
if signal is not None:
199+
self.signal = signal
200+
201+
fig = plt.figure(figsize=figsize)
202+
ax = fig.add_subplot(1, 1, 1)
203+
204+
legend = ['Signal',
205+
'Matched Reference Annotations (%d/%d)' % (self.tp, self.n_ref),
206+
'Unmatched Reference Annotations (%d/%d)' % (self.fn, self.n_ref),
207+
'Matched Test Annotations (%d/%d)' % (self.tp, self.n_test),
208+
'Unmatched Test Annotations (%d/%d)' % (self.fp, self.n_test)
209+
]
210+
211+
# Plot the signal if any
212+
if self.signal is not None:
213+
ax.plot(self.signal, sig_style)
214+
215+
# Plot reference annotations
216+
ax.plot(self.matched_ref_sample,
217+
self.signal[self.matched_ref_sample], 'ko')
218+
ax.plot(self.unmatched_ref_sample,
219+
self.signal[self.unmatched_ref_sample], 'ko', fillstyle='none')
220+
# Plot test annotations
221+
ax.plot(self.matched_test_sample,
222+
self.signal[self.matched_test_sample], 'g+')
223+
ax.plot(self.unmatched_test_sample,
224+
self.signal[self.unmatched_test_sample], 'rx')
225+
226+
ax.legend(legend)
227+
228+
# Just plot annotations
229+
else:
230+
# Plot reference annotations
231+
ax.plot(self.matched_ref_sample, np.ones(self.tp), 'ko')
232+
ax.plot(self.unmatched_ref_sample, np.ones(self.fn), 'ko', fillstyle='none')
233+
# Plot test annotations
234+
ax.plot(self.matched_test_sample, 0.5 * np.ones(self.tp), 'g+')
235+
ax.plot(self.unmatched_test_sample, 0.5 * np.ones(self.fp), 'rx')
236+
ax.legend(legend[1:])
237+
238+
fig.show()
239+
240+
if return_fig:
241+
return fig, ax
242+
149243

150244
def compare_annotations(ref_sample, test_sample, window_width):
151245
"""
@@ -167,7 +261,7 @@ def compare_annotations(ref_sample, test_sample, window_width):
167261
# def plot_record(record=None, title=None, annotation=None, time_units='samples',
168262
# sig_style='', ann_style='r*', plot_ann_sym=False, figsize=None,
169263
# return_fig=False, ecg_grids=[]):
170-
def plot_comparitor(comparitor, sig=None, sig_style='', title=None, figsize=None, return_fig=False):
264+
def plot_comparitor(comparitor, ):
171265
"""
172266
Plot two sets of annotations
173267

wfdb/processing/gqrs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import numpy
33
import copy
44

5-
65
from ..io.record import Record
76

7+
88
def time_to_sample_number(seconds, frequency):
99
return seconds * frequency + 0.5
1010

0 commit comments

Comments
 (0)