@@ -60,9 +60,10 @@ def plot_items(signal=None, ann_samp=None, ann_sym=None, fs=None,
60
60
title : str, optional
61
61
The title of the graph.
62
62
sig_style : list, optional
63
- A list of strings, specifying the style of the matplotlib plot for each
64
- signal channel. If the list has a length of 1, the style will be used
65
- for all channels.
63
+ A list of strings, specifying the style of the matplotlib plot
64
+ for each signal channel. The list length should match the number
65
+ of signal channels. If the list has a length of 1, the style
66
+ will be used for all channels.
66
67
ann_style : list, optional
67
68
A list of strings, specifying the style of the matplotlib plot for each
68
69
annotation channel. If the list has a length of 1, the style will be
@@ -350,7 +351,6 @@ def plot_wfdb(record=None, annotation=None, plot_sym=False,
350
351
- the sampling frequency, from the `fs` attribute if present, and if fs
351
352
was not already extracted from the `record` argument.
352
353
353
-
354
354
Parameters
355
355
----------
356
356
record : wfdb Record, optional
@@ -370,9 +370,10 @@ def plot_wfdb(record=None, annotation=None, plot_sym=False,
370
370
of signal channels. If the list has a length of 1, the style
371
371
will be used for all channels.
372
372
ann_style : list, optional
373
- A list of strings, specifying the style of the matplotlib plot for each
374
- annotation channel. If the list has a length of 1, the style will be
375
- used for all channels.
373
+ A list of strings, specifying the style of the matplotlib plot
374
+ for each annotation channel. The list length should match the
375
+ number of annotation channels. If the list has a length of 1,
376
+ the style will be used for all channels.
376
377
ecg_grids : list, optional
377
378
A list of integers specifying channels in which to plot ecg grids. May
378
379
also be set to 'all' for all channels. Major grids at 0.5mV, and minor
@@ -400,14 +401,14 @@ def plot_wfdb(record=None, annotation=None, plot_sym=False,
400
401
figsize=(10,4), ecg_grids='all')
401
402
402
403
"""
403
- (signal , ann_samp , ann_sym , fs , sig_name ,
404
- sig_units , record_name ) = get_wfdb_plot_items (record = record ,
405
- annotation = annotation ,
406
- plot_sym = plot_sym )
404
+ (signal , ann_samp , ann_sym , fs ,
405
+ ylabel , record_name ) = get_wfdb_plot_items (record = record ,
406
+ annotation = annotation ,
407
+ plot_sym = plot_sym )
407
408
408
409
return plot_items (signal = signal , ann_samp = ann_samp , ann_sym = ann_sym , fs = fs ,
409
- time_units = time_units , sig_name = sig_name ,
410
- sig_units = sig_units , title = (title or record_name ),
410
+ time_units = time_units , ylabel = ylabel ,
411
+ title = (title or record_name ),
411
412
sig_style = sig_style ,
412
413
ann_style = ann_style , ecg_grids = ecg_grids ,
413
414
figsize = figsize , return_fig = return_fig )
@@ -430,31 +431,27 @@ def get_wfdb_plot_items(record, annotation, plot_sym):
430
431
sig_name = record .sig_name
431
432
sig_units = record .units
432
433
record_name = 'Record: %s' % record .record_name
434
+ ylabel = ['/' .join (pair ) for pair in zip (sig_name , sig_units )]
433
435
else :
434
- signal = fs = sig_name = sig_units = record_name = None
436
+ signal = fs = ylabel = record_name = None
435
437
436
438
# Get annotation attributes
437
439
if annotation :
438
- # Note: There may be instances in which the annotation `chan`
439
- # attribute has non-overlapping channels with the signal.
440
- # In this case, omit empty middle channels.
441
-
442
440
# Get channels
443
- all_chans = set (annotation .chan )
444
-
445
- n_chans = max (all_chans ) + 1
441
+ ann_chans = set (annotation .chan )
442
+ n_ann_chans = max (ann_chans ) + 1
446
443
447
444
# Indices for each channel
448
- chan_inds = n_chans * [np .empty (0 , dtype = 'int' )]
445
+ chan_inds = n_ann_chans * [np .empty (0 , dtype = 'int' )]
449
446
450
- for chan in all_chans :
447
+ for chan in ann_chans :
451
448
chan_inds [chan ] = np .where (annotation .chan == chan )[0 ]
452
449
453
450
ann_samp = [annotation .sample [ci ] for ci in chan_inds ]
454
451
455
452
if plot_sym :
456
- ann_sym = n_chans * [None ]
457
- for ch in all_chans :
453
+ ann_sym = n_ann_chans * [None ]
454
+ for ch in ann_chans :
458
455
ann_sym [ch ] = [annotation .symbol [ci ] for ci in chan_inds [ch ]]
459
456
else :
460
457
ann_sym = None
@@ -468,7 +465,49 @@ def get_wfdb_plot_items(record, annotation, plot_sym):
468
465
ann_samp = None
469
466
ann_sym = None
470
467
471
- return signal , ann_samp , ann_sym , fs , sig_name , sig_units , record_name
468
+ # Cleaning: remove empty channels and set labels and styles.
469
+
470
+ # Wrangle together the signal and annotation channels if necessary
471
+ if record and annotation :
472
+ # There may be instances in which the annotation `chan`
473
+ # attribute has non-overlapping channels with the signal.
474
+ # In this case, omit empty middle channels. This function should
475
+ # already process labels and arrangements before passing into
476
+ # `plot_items`
477
+ sig_chans = set (range (signal .shape [1 ]))
478
+ all_chans = sorted (sig_chans .union (ann_chans ))
479
+
480
+ # Need to update ylabels and annotation values
481
+ if sig_chans != all_chans :
482
+ compact_ann_samp = []
483
+ if plot_sym :
484
+ compact_ann_sym = []
485
+ else :
486
+ compact_ann_sym = None
487
+ ylabel = []
488
+ for ch in all_chans : # ie. 0, 1, 9
489
+ if ch in ann_chans :
490
+ compact_ann_samp .append (ann_samp [ch ])
491
+ if plot_sym :
492
+ compact_ann_sym .append (ann_sym [ch ])
493
+ if ch in sig_chans :
494
+ ylabel .append ('' .join ([sig_name [ch ], sig_units [ch ]]))
495
+ else :
496
+ ylabel .append ('ch_%d/NU' % ch )
497
+ ann_samp = compact_ann_samp
498
+ ann_sym = compact_ann_sym
499
+ # Signals encompass annotations
500
+ else :
501
+ ylabel = ['/' .join (pair ) for pair in zip (sig_name , sig_units )]
502
+
503
+ # Remove any empty middle channels from annotations
504
+ elif annotation :
505
+ ann_samp = [a for a in ann_samp if a .size ]
506
+ if ann_sym is not None :
507
+ ann_sym = [a for a in ann_sym if a ]
508
+ ylabel = ['ch_%d/NU' % ch for ch in ann_chans ]
509
+
510
+ return signal , ann_samp , ann_sym , fs , ylabel , record_name
472
511
473
512
474
513
def plot_all_records (directory = '' ):
0 commit comments