Skip to content

Commit 9afa121

Browse files
committed
refactor helpers for multi to single and allow digital input
1 parent 23ce8e7 commit 9afa121

File tree

1 file changed

+35
-53
lines changed

1 file changed

+35
-53
lines changed

wfdb/io/record.py

Lines changed: 35 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -650,13 +650,13 @@ def _get_required_channels(self, seg_numbers, channels, dirname, pb_dir):
650650
for i in range(0, len(seg_numbers)):
651651
# Skip empty segments
652652
if self.seg_name[seg_numbers[i]] == '~':
653-
required_channels.append(None)
653+
required_channels.append([])
654654
else:
655655
# Get the signal names of the current segment
656656
s_sig_names = rdheader(
657657
os.path.join(dirname, self.seg_name[seg_numbers[i]]),
658658
pb_dir=pb_dir).sig_name
659-
required_channels.append(get_wanted_channel_inds(
659+
required_channels.append(get_wanted_channels(
660660
w_sig_names, s_sig_names))
661661

662662
return required_channels
@@ -722,8 +722,8 @@ def multi_to_single(self, physical, return_res=64):
722722
del(fields[attr])
723723

724724
# Get the formats, signal names and units from the first segment
725-
for attr in ['fmt', 'sig_name', 'units']:
726-
setattr(fields, attr, getattr(self.segments[0], attr))
725+
for attr in ['fmt', 'adcgain', 'baseline', 'units', 'sig_name']:
726+
fields[attr] = getattr(self.segments[0], attr)
727727

728728
# Figure out attribute to set, and dtype.
729729
if physical:
@@ -734,27 +734,26 @@ def multi_to_single(self, physical, return_res=64):
734734
dtype = 'float32'
735735
else:
736736
dtype = 'float16'
737+
nan_vals = self.n_sig * [np.nan]
737738
else:
738739
# Figure out if this conversion can be performed. All
739740
# signals must have the same fmt, gain, and baseline for
740741
# all segments. Fixed layout signals automatically
741-
# pass the test. Also get the input output mapping for each
742-
# segment.
742+
# pass the test.
743743
if self.layout == 'variable':
744-
channel_map = (self.n_seg-1) * []
745-
746744
for seg in self.segments[1:]:
747-
if seg is None:
748-
continue
749-
750-
751-
752-
if not_the_same:
753-
raise Exception('This variable layout multi-segment record cannot be converted to single segment, in digital format')
754-
745+
segment_channels = get_wanted_channels(fields['sig_name'], seg, pad=True)
746+
for attr in ['fmt', 'adcgain', 'baseline', 'units', 'sig_name']:
747+
for ch in range(self.n_sig):
748+
# Skip if the signal is not contained in the segment
749+
if segment_channels[ch] is None:
750+
continue
751+
if getattr(seg, attr)[segment_channels[ch]] != fields[attr][ch]:
752+
raise Exception('This variable layout multi-segment record cannot be converted to single segment, in digital format.')
755753

756754
sig_attr = 'd_signal'
757755
dtype = ???
756+
nan_vals = _signal.digi_nan(fields['fmt'])
758757

759758
combined_signal = np.zeros([self.sig_len, self.n_sig], dtype=dtype)
760759

@@ -763,50 +762,33 @@ def multi_to_single(self, physical, return_res=64):
763762
start_samps = [0] + list(np.cumsum(self.seg_len)[0:-1])
764763
end_samps = list(np.cumsum(self.seg_len))
765764

766-
767-
768765
if self.layout == 'fixed':
769766
# Copy over the signals directly. Recall there are no
770767
# empty segments in fixed layout records.
771768
for i in range(self.n_seg):
772-
combined_signal[start_samps[i]:end_samps[i],:] = getattr(self.segments[i], sig_attr)
769+
combined_signal[start_samps[i]:end_samps[i], :] = getattr(self.segments[i], sig_attr)
773770
else:
774771
# Copy over the signals into the matching channels
775772
for i in range(1, self.n_seg):
776773
seg = self.segments[i]
777774

778775
# Empty segment
779776
if seg is None:
780-
combined_signal[start_samps[i]:end_samps[i], :] = np.nan #### Or digital values .............
781-
782-
777+
combined_signal[start_samps[i]:end_samps[i], :] = nan_vals
783778
# Non-empty segment
784779
else:
785-
# Figure out if there are any channels wanted and
786-
# the output channels they are to be stored in
787-
inchannels = []
788-
outchannels = []
789-
for s in fields['sig_name']:
790-
if s in seg.sig_name:
791-
inchannels.append(seg.sig_name.index(s))
792-
outchannels.append(fields['sig_name'].index(s))
793-
794-
# Segment contains no wanted channels. Fill with nans.
795-
if inchannels == []:
796-
p_signal[startsamps[i]:endsamps[i],:] = np.nan
797-
# Segment contains wanted channel(s). Transfer samples.
798-
else:
799-
# This statement is necessary in case this function is not called
800-
# directly from rdsamp with m2s=True.
801-
if not hasattr(seg, 'p_signal'):
802-
seg.p_signal = seg.dac(return_res=return_res)
803-
804-
for ch in range(0, fields['n_sig']):
805-
if ch not in outchannels:
806-
p_signal[startsamps[i]:endsamps[i],ch] = np.nan
807-
else:
808-
p_signal[startsamps[i]:endsamps[i],ch] = seg.p_signal[:, inchannels[outchannels.index(ch)]]
809-
780+
# Get the segment channels to copy over for each
781+
# overall channel
782+
segment_channels = get_wanted_channels(fields['sig_name'],
783+
seg, pad=True)
784+
for ch in range(self.n_sig):
785+
# Fill with invalids if segment does not contain
786+
# signal
787+
if segment_channels[ch] is None:
788+
combined_signal[start_samps[i]:end_samps[i], ch] = nan_vals[ch]
789+
# Copy over relevant signal
790+
else:
791+
combined_signal[start_samps[i]:end_samps[i], ch] = getattr(seg, sig_attr)
810792

811793
# Create the single segment Record object and set attributes
812794
record = Record()
@@ -1075,13 +1057,13 @@ def rdrecord(record_name, sampfrom=0, sampto='end', channels='all',
10751057

10761058
# Read the desired samples in the relevant segments
10771059
for i in range(len(seg_numbers)):
1078-
segnum = seg_numbers[i]
1060+
seg_num = seg_numbers[i]
10791061
# Empty segment or segment with no relevant channels
1080-
if record.seg_name[segnum] == '~' or seg_channels[i] is None:
1081-
record.segments[segnum] = None
1062+
if record.seg_name[seg_num] == '~' or len(seg_channels[i]) == 0:
1063+
record.segments[seg_num] = None
10821064
else:
1083-
record.segments[segnum] = rdrecord(
1084-
os.path.join(dirname, record.seg_name[segnum]),
1065+
record.segments[seg_num] = rdrecord(
1066+
os.path.join(dirname, record.seg_name[seg_num]),
10851067
sampfrom=seg_ranges[i][0], sampto=seg_ranges[i][1],
10861068
channels=seg_channels[i], physical=physical, pb_dir=pb_dir)
10871069

@@ -1171,7 +1153,7 @@ def rdsamp(record_name, sampfrom=0, sampto='end', channels='all', pb_dir=None):
11711153
return signals, fields
11721154

11731155

1174-
def get_wanted_channel_inds(wanted_sig_names, record_sig_names, pad=False):
1156+
def get_wanted_channels(wanted_sig_names, record_sig_names, pad=False):
11751157
"""
11761158
Given some wanted signal names, and the signal names contained in a
11771159
record, return the indices of the record channels that intersect.

0 commit comments

Comments
 (0)