Skip to content

Commit baaa319

Browse files
committed
update clarity of benchmark_mitdb
1 parent 332c896 commit baaa319

File tree

2 files changed

+35
-14
lines changed

2 files changed

+35
-14
lines changed

wfdb/processing/evaluate.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def compare_annotations(ref_sample, test_sample, window_width, signal=None):
380380
return comparitor
381381

382382

383-
def benchmark_mitdb(detector, verbose=False):
383+
def benchmark_mitdb(detector, verbose=False, print_results=False):
384384
"""
385385
Benchmark a qrs detector against mitdb's records.
386386
@@ -390,11 +390,15 @@ def benchmark_mitdb(detector, verbose=False):
390390
The detector function.
391391
verbose : bool, optional
392392
The verbose option of the detector function.
393+
print_results : bool, optional
394+
Whether to print the overall performance, and the results for
395+
each record.
393396
394397
Returns
395398
-------
396-
comparitors : list
397-
List of Comparitor objects run on the records.
399+
comparitors : dictionary
400+
Dictionary of Comparitor objects run on the records, keyed on
401+
the record names.
398402
specificity : float
399403
Aggregate specificity.
400404
positive_predictivity : float
@@ -434,8 +438,18 @@ def benchmark_mitdb(detector, verbose=False):
434438
false_positive_rate = np.mean(
435439
[c.false_positive_rate for c in comparitors])
436440

441+
comparitors = dict(zip(record_list, comparitors))
442+
437443
print('Benchmark complete')
438444

445+
if print_results:
446+
print('\nOverall MITDB Performance - Specificity: %.4f, Positive Predictivity: %.4f, False Positive Rate: %.4f\n'
447+
% (specificity, positive_predictivity, false_positive_rate))
448+
for record_name in record_list:
449+
print('Record %s:' % record_name)
450+
comparitors[record_name].print_summary()
451+
print('\n\n')
452+
439453
return comparitors, specificity, positive_predictivity, false_positive_rate
440454

441455

@@ -451,5 +465,6 @@ def benchmark_mitdb_record(rec, detector, verbose):
451465
comparitor = compare_annotations(ref_sample=ann_ref.sample[1:],
452466
test_sample=qrs_inds,
453467
window_width=int(0.1 * fields['fs']))
454-
print('Finished record %s' % rec)
468+
if verbose:
469+
print('Finished record %s' % rec)
455470
return comparitor

wfdb/processing/qrs.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
2-
import numpy as np
2+
import pdb
33

4+
import numpy as np
45
from scipy import signal
56
from sklearn.preprocessing import normalize
67

@@ -146,7 +147,8 @@ def _set_conf(self):
146147

147148
def _bandpass(self, fc_low=5, fc_high=20):
148149
"""
149-
Apply a bandpass filter onto the signal, and save the filtered signal.
150+
Apply a bandpass filter onto the signal, and save the filtered
151+
signal.
150152
"""
151153
self.fc_low = fc_low
152154
self.fc_high = fc_high
@@ -162,24 +164,28 @@ def _bandpass(self, fc_low=5, fc_high=20):
162164

163165
def _mwi(self):
164166
"""
165-
Apply moving wave integration with a ricker (Mexican hat) wavelet onto
166-
the filtered signal, and save the square of the integrated signal.
167+
Apply moving wave integration (mwi) with a ricker (Mexican hat)
168+
wavelet onto the filtered signal, and save the square of the
169+
integrated signal.
167170
168171
The width of the hat is equal to the qrs width
169172
170-
Also find all local peaks in the mwi signal.
173+
After integration, find all local peaks in the mwi signal.
171174
"""
172-
b = signal.ricker(self.qrs_width, 4)
173-
self.sig_i = signal.filtfilt(b, [1], self.sig_f, axis=0) ** 2
175+
wavelet_filter = signal.ricker(self.qrs_width, 8)
174176

175-
# Save the mwi gain (x2 due to double filtering) and the total gain
176-
# from raw to mwi
177-
self.mwi_gain = get_filter_gain(b, [1],
177+
self.sig_i = signal.filtfilt(wavelet_filter, [1], self.sig_f,
178+
axis=0) ** 2
179+
180+
# Save the mwi gain (x2 due to double filtering) and the total
181+
# gain from raw to mwi
182+
self.mwi_gain = get_filter_gain(wavelet_filter, [1],
178183
np.mean([self.fc_low, self.fc_high]), self.fs) * 2
179184
self.transform_gain = self.filter_gain * self.mwi_gain
180185
self.peak_inds_i = find_local_peaks(self.sig_i, radius=self.qrs_radius)
181186
self.n_peaks_i = len(self.peak_inds_i)
182187

188+
183189
def _learn_init_params(self, n_calib_beats=8):
184190
"""
185191
Find a number of consecutive beats and use them to initialize:

0 commit comments

Comments
 (0)