Skip to content

Commit 761dbc0

Browse files
committed
finish plot_wfdb
1 parent 370b0be commit 761dbc0

File tree

2 files changed

+50
-29
lines changed

2 files changed

+50
-29
lines changed

wfdb/io/annotation.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,26 +1198,34 @@ def rdann(record_name, extension, sampfrom=0, sampto=None, shift_samps=False,
11981198
11991199
"""
12001200

1201-
return_label_elements = check_read_inputs(sampfrom, sampto, return_label_elements)
1201+
return_label_elements = check_read_inputs(sampfrom, sampto,
1202+
return_label_elements)
12021203

12031204
# Read the file in byte pairs
12041205
filebytes = load_byte_pairs(record_name, extension, pb_dir)
12051206

12061207
# Get wfdb annotation fields from the file bytes
1207-
sample, label_store, subtype, chan, num, aux_note = proc_ann_bytes(filebytes, sampto)
1208+
(sample, label_store, subtype,
1209+
chan, num, aux_note) = proc_ann_bytes(filebytes, sampto)
12081210

12091211
# Get the indices of annotations that hold definition information about
12101212
# the entire annotation file, and other empty annotations to be removed.
1211-
potential_definition_inds, rm_inds = get_special_inds(sample, label_store, aux_note)
1213+
potential_definition_inds, rm_inds = get_special_inds(sample, label_store,
1214+
aux_note)
12121215

12131216
# Try to extract information describing the annotation file
1214-
fs, custom_labels = interpret_defintion_annotations(potential_definition_inds, aux_note)
1217+
(fs,
1218+
custom_labels) = interpret_defintion_annotations(potential_definition_inds,
1219+
aux_note)
12151220

12161221
# Remove annotations that do not store actual sample and label information
1217-
sample, label_store, subtype, chan, num, aux_note = rm_empty_indices(rm_inds, sample, label_store, subtype, chan, num, aux_note)
1222+
(sample, label_store, subtype,
1223+
chan, num, aux_note) = rm_empty_indices(rm_inds, sample, label_store,
1224+
subtype, chan, num, aux_note)
12181225

1219-
# Convert lists to numpy arrays
1220-
sample, label_store, subtype, chan, num= lists_to_arrays(sample, label_store, subtype, chan, num)
1226+
# Convert lists to numpy arrays dtype='int'
1227+
(sample, label_store, subtype,
1228+
chan, num) = lists_to_int_arrays(sample, label_store, subtype, chan, num)
12211229

12221230
# Obtain annotation sample relative to the starting signal index
12231231
if shift_samps and len(sample) > 0 and sampfrom:
@@ -1232,11 +1240,12 @@ def rdann(record_name, extension, sampfrom=0, sampto=None, shift_samps=False,
12321240
pass
12331241

12341242
# Create the annotation object
1235-
annotation = Annotation(os.path.split(record_name)[1], extension, sample=sample, label_store=label_store,
1236-
subtype=subtype, chan=chan, num=num, aux_note=aux_note, fs=fs,
1243+
annotation = Annotation(record_name=os.path.split(record_name)[1],
1244+
extension=extension, sample=sample,
1245+
label_store=label_store, subtype=subtype,
1246+
chan=chan, num=num, aux_note=aux_note, fs=fs,
12371247
custom_labels=custom_labels)
12381248

1239-
12401249
# Get the set of unique label definitions contained in this annotation
12411250
if summarize_labels:
12421251
annotation.get_contained_labels(inplace=True)
@@ -1509,9 +1518,9 @@ def rm_empty_indices(*args):
15091518

15101519
return [[a[i] for i in keep_inds] for a in args[1:]]
15111520

1512-
def lists_to_arrays(*args):
1521+
def lists_to_int_arrays(*args):
15131522
"""
1514-
Convert lists to numpy arrays
1523+
Convert lists to numpy int arrays
15151524
"""
15161525
return [np.array(a, dtype='int') for a in args]
15171526

wfdb/plot/plot.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
from ..io.annotation import Annotation
99

1010

11-
import pdb
12-
13-
1411
def plot_items(signal=None, ann_samp=None, ann_sym=None, fs=None,
1512
time_units='samples', sig_name=None, sig_units=None,
1613
ylabel=None, title=None, sig_style=[''], ann_style=['r*'],
@@ -296,7 +293,7 @@ def calc_ecg_grids(minsig, maxsig, sig_units, fs, maxt, time_units):
296293

297294

298295
def label_figure(axes, n_subplots, time_units, sig_name, sig_units, ylabel,
299-
title)
296+
title):
300297
"Add title, and axes labels"
301298
if title:
302299
axes[0].set_title(title)
@@ -391,8 +388,8 @@ def plot_wfdb(record=None, annotation=None, plot_sym=False,
391388
392389
"""
393390
(signal, ann_samp, ann_sym, fs, sig_name,
394-
sig_units) = get_wfdb_plot_items(record, annotation)
395-
391+
sig_units) = get_wfdb_plot_items(record=record, annotation=annotation,
392+
plot_sym=plot_sym)
396393

397394
return plot_items(signal=signal, ann_samp=ann_samp, ann_sym=ann_sym, fs=fs,
398395
time_units=time_units, sig_name=sig_name,
@@ -424,18 +421,33 @@ def get_wfdb_plot_items(record, annotation, plot_sym):
424421
if annotation:
425422
# Get channels
426423
all_chans = set(annotation.chan)
424+
425+
n_chans = max(all_chans) + 1
427426

428427
# Just one channel. Place content in one list index.
429-
if len(all_chans) == 1:
430-
ann_samp = annotation.chan[0]*[None] + [annotation.sample]
431-
if plot_sym:
432-
ann_sym = annotation.chan[0]*[None] + [annotation.symbol]
433-
else:
434-
ann_sym = None
435-
# Split annotations by channel
428+
# if len(all_chans) == 1:
429+
# ann_samp = annotation.chan[0]*[None] + [annotation.sample]
430+
# if plot_sym:
431+
# ann_sym = annotation.chan[0]*[None] + [annotation.symbol]
432+
# else:
433+
# ann_sym = None
434+
# # Split annotations by channel
435+
# else:
436+
437+
# Indices for each channel
438+
chan_inds = n_chans * [np.empty(0)]
439+
440+
for chan in all_chans:
441+
chan_inds[chan] = np.where(annotation.chan == chan)[0]
442+
443+
ann_samp = [annotation.sample[ci] for ci in chan_inds]
444+
445+
if plot_sym:
446+
ann_sym = n_chans * [None]
447+
for ch in all_chans:
448+
ann_sym[ch] = [annotation.symbol[ci] for ci in chan_inds[ch]]
436449
else:
437-
for chan in all_chans:
438-
pass
450+
ann_sym = None
439451

440452
# Try to get fs from annotation if not already in record
441453
if fs is None:
@@ -466,7 +478,7 @@ def plot_all_records(directory=os.getcwd()):
466478
records.sort()
467479

468480
for record_name in records:
469-
record = rdrecord(record_name)
481+
record = rdrecord(os.path.join(directory, record_name))
470482

471-
plot_wfdb(record, title='Record - %s' % record.recordname)
483+
plot_wfdb(record, title='Record - %s' % record.record_name)
472484
input('Press enter to continue...')

0 commit comments

Comments
 (0)