Skip to content

Commit b7f69bc

Browse files
committed
work on evaluator
1 parent 5f96ed5 commit b7f69bc

File tree

2 files changed

+84
-54
lines changed

2 files changed

+84
-54
lines changed

wfdb/processing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
"""
44
from .basic import (resample_ann, resample_sig, resample_singlechan,
55
resample_multichan, normalize)
6+
from .evaluate import
67
from .gqrs import gqrs_detect
78
from .hr import compute_hr
89
from .peaks import find_peaks, correct_peaks
910
from .qrs import Conf, XQRS, xqrs_detect
11+

wfdb/processing/evaluate.py

Lines changed: 82 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import numpy as np
2-
3-
class Comparitor(object):
2+
import matplotlib.pyplot as plt
43

54

5+
class Comparitor(object):
66

77
def __init__(self, ref_sample, test_sample, window_width):
88
"""
@@ -19,65 +19,103 @@ def __init__(self, ref_sample, test_sample, window_width):
1919

2020
self.ref_sample = ref_sample
2121
self.test_sample = test_sample
22-
23-
self.fp = 0
24-
self.tp = 0
25-
26-
self.n_missed = 0
27-
self.n_detected = 0
28-
29-
# How many there are
3022
self.n_ref = len(ref_sample)
3123
self.n_comp = len(test_sample)
3224

33-
34-
# # Just derive these 4 at the end?
35-
# # Index info about the reference samples
36-
# self.detected_inds = []
37-
# self.missed_inds = []
38-
# # About the testing samples
39-
# self.correct_test_inds = []
40-
# self.wrong_test_inds = []
41-
42-
4325
# The matching test sample numbers. -1 for indices with no match
4426
self.matching_sample_nums = -1 * np.ones(n_ref)
4527

4628
# TODO: rdann return annotations.where
4729

48-
def compare(self):
30+
def calc_stats(self):
31+
"""
32+
Calculate performance statistics after the two sets of annotations
33+
are compared.
34+
35+
Example:
36+
-------------------
37+
ref=500 test=480
38+
{ 30 { 470 } 10 }
39+
-------------------
40+
41+
tp = 470
42+
fp = 10
43+
fn = 30
44+
45+
specificity = 470 / 500
46+
positive_predictivity = 470 / 480
47+
false_positive_rate = 10 / 480
48+
49+
"""
50+
self.detected_ref_inds = np.where(self.matching_sample_nums != -1)
51+
self.missed_ref_inds = np.where(self.matching_sample_nums == -1)
52+
self.matched_test_inds = self.matching_sample_nums(
53+
self.matching_sample_nums != -1)
54+
self.unmached_test_inds = np.setdiff1d(np.array(range(self.n_test)),
55+
self.matched_test_inds, assume_unique=True)
56+
57+
# True positives = matched reference samples
58+
self.tp = len(detected_ref_inds)
59+
# False positives = extra test samples not matched
60+
self.fp = self.n_test - self.tp
61+
# False negatives = undetected reference samples
62+
self.fn = self.n_ref - self.tp
63+
# No tn attribute
64+
65+
self.specificity = self.tp / self.n_ref
66+
self.positive_predictivity = self.tp / self.n_test
67+
self.false_positive_rate = self.fp / self.n_test
4968

5069

70+
def compare(self):
5171

5272
test_samp_num = 0
5373
ref_samp_num = 0
5474

55-
while ref_samp_num < n_ref:
75+
# Why can't this just be a for loop of ref_samp_num?
76+
while ref_samp_num < n_ref and test_samp_num < n_test:
5677

5778
closest_samp_num, smallest_samp_diff = (
58-
self.get_closest_samp_num(ref_samp_num, test_samp_num))
79+
self.get_closest_samp_num(ref_samp_num, test_samp_num,
80+
self.n_test))
81+
# This needs to work for last index
5982
closest_samp_num_next, smallest_samp_diff_next = (
60-
self.get_closest_samp_num(ref_samp_num + 1, test_samp_num))
83+
self.get_closest_samp_num(ref_samp_num + 1, test_samp_num,
84+
self.n_test))
6185

6286
# Found a contested test sample number. Decide which reference
6387
# sample it belongs to.
6488
if closest_samp_num == closest_samp_num_next:
65-
pass
66-
# No clash. Assign the reference-test pair
67-
else:
89+
# If the sample is closer to the next reference sample, get
90+
# the next closest sample for this reference sample.
91+
if smallest_samp_diff_next < smallest_samp_diff:
92+
# Get the next closest sample.
93+
# Can this be empty? Need to catch case where nothing left?
94+
closest_samp_num, smallest_samp_diff = (
95+
self.get_closest_samp_num(ref_samp_num, test_samp_num,
96+
closest_samp_num))
97+
98+
6899
self.matching_sample_nums[ref_samp_num] = closest_samp_num
69100

70-
ref_samp_num += 1
71-
test_samp_num = closest_samp_num + 1
101+
# If no clash, it is straightforward.
102+
103+
# Assign the reference-test pair if close enough
104+
if smallest_sample_diff < self.window_width:
105+
self.matching_sample_nums[ref_samp_num] = closest_samp_num
72106

107+
ref_samp_num += 1
108+
test_samp_num = closest_samp_num + 1
73109

74110
self.calc_stats()
75111

76112

77-
def get_closest_samp_num(self, ref_samp_num, start_test_samp_num):
113+
def get_closest_samp_num(self, ref_samp_num, start_test_samp_num,
114+
stop_test_samp_num):
78115
"""
79116
Return the closest testing sample number for the given reference
80-
sample number. Begin the search from start_test_samp_num.
117+
sample number. Limit the search between start_test_samp_num and
118+
stop_test_samp_num.
81119
"""
82120

83121
if start_test_samp_num >= self.n_test:
@@ -92,7 +130,7 @@ def get_closest_samp_num(self, ref_samp_num, start_test_samp_num):
92130
smallest_samp_diff = abs(samp_diff)
93131

94132
# Iterate through the testing samples
95-
for test_samp_num in range(start_test_samp_num, self.n_test):
133+
for test_samp_num in range(start_test_samp_num, stop_test_samp_num):
96134
test_samp = self.test_sample[test_samp_num]
97135
samp_diff = ref_samp - test_samp
98136
abs_samp_diff = abs(samp_diff)
@@ -109,32 +147,22 @@ def get_closest_samp_num(self, ref_samp_num, start_test_samp_num):
109147
return closest_samp_num, smallest_samp_diff
110148

111149

112-
113-
114-
115-
116-
117-
def compare_annotations(ind_ref, ind_comp):
150+
def compare_annotations(ref_sample, test_sample, window_width):
118151
"""
152+
153+
Parameters
154+
----------
119155
120-
"""
121-
122-
123-
detected_inds
124-
missed_inds
125-
126-
127-
tp
128-
tn
129-
fp
130-
fn
156+
Returns
157+
-------
158+
comparitor : Comparitor object
159+
Object containing parameters about the two sets of annotations
131160
132-
tpr
133-
tnr
134-
fpr
135-
fnr
161+
"""
162+
comparitor = Comparitor(ref_sample, test_sample, window_width)
163+
comparitor.compare()
136164

137-
return evaluation
165+
return comparitor
138166

139167

140168
def plot_comparitor(comparitor, sig=None):

0 commit comments

Comments
 (0)