Skip to content

Commit c13e27c

Browse files
committed
use correct dtype in multi_to_single
1 parent e531dfa commit c13e27c

File tree

3 files changed

+54
-87
lines changed

3 files changed

+54
-87
lines changed

tests/test_record.py

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -473,50 +473,33 @@ def test_5e(self):
473473

474474
def test_5f(self):
475475
"""
476-
477-
Gotta write this
478-
479-
480-
Multi-segment, variable layout, entire signal, digital
481-
482-
The reference signal creation cannot be made with rdsamp
483-
directly because the wfdb c package (10.5.24) applies the single
484-
adcgain and baseline values from the layout specification
485-
header, which is undesired in multi-segment signals with
486-
different adcgain/baseline values across segments.
476+
Multi-segment, variable layout, selected duration, selected
477+
channels, digital. There are two channels: PLETH and II. Their
478+
fmt, adc_gain, and baseline do not change between the segments.
487479
488480
Target file created with:
489-
```
490-
for i in {01..18}
491-
do
492-
rdsamp -r sample-data/multi-segment/s25047/3234460_00$i -P | cut -f 2- >> record-5e
493-
done
494-
```
495-
481+
rdsamp -r sample-data/multi-segment/p000878/p000878-2137-10-26-16-57 -f s3550 -t s7500 -s 0 1 | cut -f 2- | perl -p -e 's/-32768/ -128/g;' > record-5f
496482
497483
"""
498-
sig, fields = wfdb.rdsamp('p000878-2137-10-26-16-57',
484+
record = wfdb.rdrecord('sample-data/multi-segment/p000878/p000878-2137-10-26-16-57',
485+
sampfrom=3550, sampto=7500, channels=[0,1],
486+
physical=False)
487+
sig = record.d_signal
488+
489+
# Compare data streaming from physiobank
490+
record_pb = wfdb.rdrecord('p000878-2137-10-26-16-57',
499491
pb_dir='mimic3wdb/matched/p00/p000878/',
500-
sampto=5000)
492+
sampfrom=3550, sampto=7500, channels=[0,1],
493+
physical=False)
501494
sig_target = np.genfromtxt('tests/target-output/record-5f')
502495

503496
np.testing.assert_equal(sig, sig_target)
497+
assert record.__eq__(record_pb)
504498

505499

506-
# Test 12 - Multi-segment variable layout/Selected duration/Selected Channels/Physical
507-
# Target file created with: rdsamp -r sample-data/multi-segment/s00001/s00001-2896-10-10-00-31 -f s -t 4000 -s 3 0 -P | cut -f 2- > target12
508-
#def test_12(self):
509-
# record=rdsamp('sample-data/multi-segment/s00001/s00001-2896-10-10-00-31', sampfrom=8750, sampto=500000)
510-
# sig_round = np.round(record.p_signal, decimals=8)
511-
# sig_target = np.genfromtxt('tests/target-output/target12')
512-
#
513-
# assert np.array_equal(sig, sig_target)
514-
515-
516-
# Cleanup written files
517500
@classmethod
518-
def tearDownClass(self):
519-
501+
def tearDownClass(cls):
502+
"Clean up written files"
520503
writefiles = ['03700181.dat','03700181.hea','100.atr','100.dat',
521504
'100.hea','1003.atr','100_3chan.dat','100_3chan.hea',
522505
'12726.anI','a103l.hea','a103l.mat','s0010_re.dat',

wfdb/io/_signal.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,7 @@ def rd_segment(file_name, dirname, pb_dir, n_sig, fmt, sig_len, byte_offset,
743743
# Return uniform numpy array
744744
if smooth_frames or sum(samps_per_frame) == n_sig:
745745
# Figure out the largest required dtype for the segment to minimize memory usage
746-
max_dtype = np_dtype(fmt_res(fmt, maxres=True), discrete=True)
746+
max_dtype = np_dtype(fmt_res(fmt, max_res=True), discrete=True)
747747
# Allocate signal array. Minimize dtype
748748
signals = np.zeros([sampto-sampfrom, len(channels)], dtype=max_dtype)
749749

@@ -1332,32 +1332,38 @@ def wfdbfmt(res, single_fmt=True):
13321332
return '32'
13331333

13341334

1335-
def fmt_res(fmt, maxres=False):
1335+
def fmt_res(fmt, max_res=False):
13361336
"""
13371337
Return the resolution of the WFDB format(s).
1338+
1339+
Parameters
1340+
----------
1341+
fmt : str or list
1342+
The wfdb format. Can be a list of valid fmts. If it is a list,
1343+
and `max_res` is True, the list may contain None.
1344+
max_res : bool, optional
1345+
If given a list of fmts, whether to return the highest
1346+
resolution.
1347+
13381348
"""
13391349
if isinstance(fmt, list):
1340-
res = [fmt_res(f) for f in fmt]
1341-
if maxres is True:
1342-
res = np.max(res)
1350+
if max_res:
1351+
# Allow None
1352+
res = np.max([fmt_res(f) for f in fmt if f is not None])
1353+
else:
1354+
res = [fmt_res(f) for f in fmt]
13431355
return res
13441356

1345-
if fmt in ['8', '80']:
1346-
return 8
1347-
elif fmt in ['310', '311']:
1348-
return 10
1349-
elif fmt == '212':
1350-
return 12
1351-
elif fmt in ['16', '61']:
1352-
return 16
1353-
elif fmt == '24':
1354-
return 24
1355-
elif fmt == '32':
1356-
return 32
1357+
res = {'8':8, '80':8, '310':10, '311':10, '212':12, '16':16, '61':16,
1358+
'24':24, '32':32}
1359+
1360+
if fmt in res:
1361+
return res[fmt]
13571362
else:
13581363
raise ValueError('Invalid WFDB format.')
13591364

13601365

1366+
13611367
def np_dtype(res, discrete):
13621368
"""
13631369
Given the resolution of a signal, return the minimum

wfdb/io/record.py

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import multiprocessing
33
import posixpath
44
import re
5-
import warnings
65

76
import numpy as np
87
import os
@@ -662,15 +661,11 @@ def _arrange_fields(self, seg_numbers, seg_ranges, channels,
662661
# Update the layout specification segment. At this point it
663662
# should match the full original header
664663

665-
# Forcing the signal specifications to match the input
666-
# `channels` variable.
667-
if force_channels:
668-
669664
# Have to inspect existing channels of segments; requested
670665
# input channels will not be enough on their own because not
671666
# all signals may be present, depending on which section of
672667
# the signal was read.
673-
else:
668+
if not force_channels:
674669
# The desired signal names.
675670
desired_sig_names = [self.segments[0].sig_name[ch] for ch in channels]
676671
# Actual contained signal names of individual segments
@@ -686,15 +681,15 @@ def _arrange_fields(self, seg_numbers, seg_ranges, channels,
686681
item = getattr(self.segments[0], field)
687682
setattr(self.segments[0], field, [item[c] for c in channels])
688683

689-
self.segments[0].n_sig = self.n_sig = len(sig_name)
684+
self.segments[0].n_sig = self.n_sig = len(channels)
685+
if self.n_sig == 0:
686+
print('No signals of the desired channels are contained in the specified sample range.')
690687

691688
# Update record specification parameters
692689
self.sig_len = sum([sr[1]-sr[0] for sr in seg_ranges])
693690
self.n_seg = len(self.segments)
694691
self._adjust_datetime(sampfrom=sampfrom)
695692

696-
if self.n_sig == 0:
697-
warnings.warn('No signals of the desired channels are contained in the specified sample range.')
698693

699694
def multi_to_single(self, physical, return_res=64):
700695
"""
@@ -708,8 +703,8 @@ def multi_to_single(self, physical, return_res=64):
708703
physical : bool
709704
Whether to convert the physical or digital signal.
710705
return_res : int, optional
711-
The return resolution of the `p_signal` field. Options are 64, 32,
712-
and 16.
706+
The return resolution of the `p_signal` field. Options are:
707+
64, 32, and 16.
713708
714709
Returns
715710
-------
@@ -750,7 +745,7 @@ def multi_to_single(self, physical, return_res=64):
750745
# This will be the field dictionary to copy over.
751746
reference_fields = {'fmt':n_sig*[None], 'adc_gain':n_sig*[None],
752747
'baseline':n_sig*[None],
753-
'units':n_sig*[None], }
748+
'units':n_sig*[None]}
754749

755750
# For physical signals, mismatched fields will not be copied
756751
# over. For digital, mismatches will cause an exception.
@@ -786,20 +781,16 @@ def multi_to_single(self, physical, return_res=64):
786781
if physical:
787782
sig_attr = 'p_signal'
788783
# Figure out the largest required dtype
789-
dtype = _signal.np_dtype(_signal.fmt_res(fields['fmt'],
790-
maxres=True),
791-
discrete=False)
792-
nan_vals = self.n_sig * [np.nan]
784+
dtype = _signal.np_dtype(return_res, discrete=False)
785+
nan_vals = np.array([self.n_sig * [np.nan]], dtype=dtype)
793786
else:
794787
sig_attr = 'd_signal'
795788
# Figure out the largest required dtype
796-
dtype = _signal.np_dtype(_signal.fmt_res(fields['fmt'],
797-
maxres=True),
798-
discrete=True)
799-
nan_vals = _signal.digi_nan(fields['fmt'])
789+
dtype = _signal.np_dtype(return_res, discrete=True)
790+
nan_vals = np.array([_signal.digi_nan(fields['fmt'])], dtype=dtype)
800791

801-
# Create and set the full signal array
802-
combined_signal = np.zeros([self.sig_len, self.n_sig], dtype=dtype)
792+
# Initialize the full signal array
793+
combined_signal = np.repeat(nan_vals, self.sig_len, axis=0)
803794

804795
# Start and end samples in the overall array to place the
805796
# segment samples into
@@ -815,24 +806,15 @@ def multi_to_single(self, physical, return_res=64):
815806
# Copy over the signals into the matching channels
816807
for i in range(1, self.n_seg):
817808
seg = self.segments[i]
818-
819-
# Empty segment
820-
if seg is None:
821-
combined_signal[start_samps[i]:end_samps[i], :] = nan_vals
822-
# Non-empty segment
823-
else:
809+
if seg is not None:
824810
# Get the segment channels to copy over for each
825811
# overall channel
826812
segment_channels = get_wanted_channels(fields['sig_name'],
827813
seg.sig_name,
828814
pad=True)
829815
for ch in range(self.n_sig):
830-
# Fill with invalids if segment does not contain
831-
# signal
832-
if segment_channels[ch] is None:
833-
combined_signal[start_samps[i]:end_samps[i], ch] = nan_vals[ch]
834816
# Copy over relevant signal
835-
else:
817+
if segment_channels[ch] is not None:
836818
combined_signal[start_samps[i]:end_samps[i], ch] = getattr(seg, sig_attr)[:, segment_channels[ch]]
837819

838820
# Create the single segment Record object and set attributes
@@ -841,10 +823,6 @@ def multi_to_single(self, physical, return_res=64):
841823
setattr(record, field, fields[field])
842824
setattr(record, sig_attr, combined_signal)
843825

844-
# Consider the case where no signals are present. ie. only empty.
845-
# pdb.set_trace()
846-
# print(sig_attr)
847-
# print(getattr(record, sig_attr).dtype)
848826
# Use the signal to set record features
849827
if physical:
850828
record.set_p_features()

0 commit comments

Comments
 (0)