Skip to content

Commit 1df7e91

Browse files
committed
continue refactor
1 parent c29b431 commit 1df7e91

File tree

2 files changed

+132
-155
lines changed

2 files changed

+132
-155
lines changed

wfdb/io/_header.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,15 +108,6 @@
108108
# Specifications of all wfdb header fields, except for comments
109109
FIELD_SPECS = pd.concat((RECORD_SPECS, SIGNAL_SPECS, SEGMENT_SPECS))
110110

111-
# Allowed types of wfdb header fields, and also attributes defined in
112-
# this library
113-
ALLOWED_TYPES = dict([[index, FIELD_SPECS.loc[index, 'allowed_types']] for index in FIELD_SPECS.index])
114-
ALLOWED_TYPES.add({'comment': (str,), p_signal, d_signal, e_p_signal, e_d_signal, segments})
115-
116-
# Fields that must be lists
117-
LIST_FIELDS = tuple(SIGNAL_SPECS.index) + ('comments', 'e_p_signal',
118-
'e_d_signal', 'segments')
119-
120111

121112
# Regexp objects for reading headers
122113

wfdb/io/record.py

Lines changed: 132 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from . import download
1919

2020
import pdb
21+
22+
2123
class BaseRecord(object):
2224
# The base WFDB class extended by the Record and MultiRecord classes.
2325
def __init__(self, record_name=None, n_sig=None,
@@ -35,7 +37,7 @@ def __init__(self, record_name=None, n_sig=None,
3537
self.comments = comments
3638
self.sig_name = sig_name
3739

38-
def check_field(self, field, required_channels=None):
40+
def check_field(self, field, required_channels='all'):
3941
"""
4042
Check whether a single field is valid in its basic form. Does
4143
not check compatibility with other fields.
@@ -48,7 +50,11 @@ def check_field(self, field, required_channels=None):
4850
Used for signal specification fields. Species the channels
4951
to check. Other channels can be None.
5052
51-
Be aware that this function is not just called from wrheader.
53+
Notes
54+
-----
55+
This function is called from wrheader to check fields before
56+
writing. It is also supposed to be usable at any point to
57+
check a specific field.
5258
5359
"""
5460
item = getattr(self, field)
@@ -61,21 +67,15 @@ def check_field(self, field, required_channels=None):
6167
# We should have a list specifying these automatically.
6268

6369
# Whether the item should be a list. Watch out for required_channels for `segments`
64-
expect_list = True if field in _header.LIST_FIELDS else False
70+
expect_list = True if field in LIST_FIELDS else False
6571

6672
# Check the type of the field (and of its elements if it should
6773
# be a list)
68-
check_item_type(item, field_name=field,
69-
allowed_types=_header.ALLOWED_TYPES[field],
74+
_check_item_type(item, field_name=field,
75+
allowed_types=ALLOWED_TYPES[field],
7076
expect_list=expect_list,
7177
required_channels=required_channels)
7278

73-
74-
#
75-
# self.check_field_type(field, channels)
76-
77-
78-
7979
# Individual specific field checks
8080

8181
if field in ['d_signal', 'p_signal']:
@@ -87,8 +87,6 @@ def check_field(self, field, required_channels=None):
8787
ndim=1, parent_class=(lambda f: np.integer if f == 'e_d_signal' else np.float64)(field),
8888
channel_num=ch)
8989

90-
#elif field == 'segments': # Nothing to check here.
91-
9290
# Record specification fields
9391

9492
elif field == 'record_name':
@@ -122,16 +120,15 @@ def check_field(self, field, required_channels=None):
122120
elif field == 'base_date':
123121
_ = datetime.datetime.strptime(self.base_date, '%d/%m/%Y')
124122

125-
# Lists of elements to check.
126-
elif expect_list:
123+
# Signal specification fields
124+
elif field in SIGNAL_SPECS.index:
127125

128126
for ch in range(len(item)):
129-
# The channel element is allowed to be None
127+
# If the element is allowed to be None
130128
if ch not in required_channels:
131129
if item[ch] is None:
132130
continue
133131

134-
# Signal specification fields.
135132
if field == 'file_name':
136133
# Check for file_name characters
137134
accepted_string = re.match('[-\w]+\.?[\w]+', item[ch])
@@ -175,9 +172,12 @@ def check_field(self, field, required_channels=None):
175172
if len(set(item)) != len(item):
176173
raise ValueError('sig_name strings must be unique.')
177174

178-
# Segment specification fields
179-
elif field == 'seg_name':
180-
# Segment names must be alphanumerics or just a single '~'
175+
# Segment specification fields and comments
176+
elif field in SEGMENT_SPECS.index:
177+
for ch in range(len(item)):
178+
if field == 'seg_name':
179+
# Segment names must be alphanumerics or just a
180+
# single '~'
181181
if item[ch] == '~':
182182
continue
183183
accepted_string = re.match('[-\w]+', item[ch])
@@ -192,54 +192,19 @@ def check_field(self, field, required_channels=None):
192192
raise ValueError('seg_len values must be positive integers. Only seg_len[0] may be 0 to indicate a layout segment')
193193
# Comment field
194194
elif field == 'comments':
195-
# Allow empty string comment lines
196-
if item[ch] =='':
197-
continue
198195
if item[ch].startswith('#'):
199196
print("Note: comment strings do not need to begin with '#'. This library adds them automatically.")
200197
if re.search('[\t\n\r\f\v]', item[ch]):
201198
raise ValueError('comments may not contain tabs or newlines (they may contain spaces and underscores).')
202199

203200

204-
def check_field_type(self, field, ch=None):
205-
"""
206-
Check the data type of the specified field.
207-
ch is used for signal specification fields
208-
Some fields are lists. This must be checked, along with their elements.
209-
"""
210-
item = getattr(self, field)
211-
212-
# Record specification field. Nonlist.
213-
if field in _header.RECORD_SPECS:
214-
check_item_type(item, field, _header.RECORD_SPECS[field].allowed_types)
215-
216-
# Signal specification field. List.
217-
elif field in _header.SIGNAL_SPECS:
218-
check_item_type(item, field, _header.SIGNAL_SPECS[field].allowed_types, ch)
219-
220-
# Segment specification field. List. All elements cannot be None
221-
elif field in _header.SEGMENT_SPECS:
222-
check_item_type(item, field, _header.SEGMENT_SPECS[field].allowed_types, 'all')
223-
224-
# Comments field. List. Elements cannot be None
225-
elif field == 'comments':
226-
check_item_type(item, field, (str), 'all')
227-
228-
# Signals field.
229-
elif field in ['p_signal','d_signal']:
230-
check_item_type(item, field, (np.ndarray))
231-
232-
elif field in ['e_p_signal', 'e_d_signal']:
233-
check_item_type(item, field, (np.ndarray), 'all')
234-
235-
# Segments field. List. Elements may be None.
236-
elif field == 'segments':
237-
check_item_type(item, field, (Record), 'none')
238-
239-
240201
def check_read_inputs(self, sampfrom, sampto, channels, physical,
241202
smooth_frames, return_res):
242-
# Ensure that input read parameters are valid for the record
203+
"""
204+
Ensure that input read parameters (from rdsamp) are valid for
205+
the record
206+
207+
"""
243208

244209
# Data Type Check
245210
if not hasattr(sampfrom, '__index__'):
@@ -263,9 +228,9 @@ def check_read_inputs(self, sampfrom, sampto, channels, physical,
263228
raise ValueError('sampto must be greater than sampfrom')
264229

265230
# Channel Ranges
266-
if min(c) < 0:
231+
if min(channels) < 0:
267232
raise ValueError('Input channels must all be non-negative integers')
268-
if max(c) > self.n_sig - 1:
233+
if max(channels) > self.n_sig - 1:
269234
raise ValueError('Input channels must all be lower than the total number of channels')
270235

271236
if return_res not in [64, 32, 16, 8]:
@@ -279,89 +244,6 @@ def check_read_inputs(self, sampfrom, sampto, channels, physical,
279244
raise ValueError('This package version cannot expand all samples when reading multi-segment records. Must enable frame smoothing.')
280245

281246

282-
def check_item_type(item, field_name, allowed_types, expect_list=False,
283-
required_channels='all'):
284-
"""
285-
Check the item's type against a set of allowed types.
286-
Vary the print message regarding whether the item can be None.
287-
Helper to `BaseRecord.check_field`.
288-
289-
Parameters
290-
----------
291-
item : any
292-
The item to check.
293-
field_name : str
294-
The field name.
295-
allowed_types : iterable
296-
Iterable of types the item is allowed to be.
297-
expect_list : bool, optional
298-
Whether the item is expected to be a list.
299-
required_channels : list, optional
300-
List of integers specifying which channels of the item must be
301-
present. May be set to 'all' to indicate all channels. Only used
302-
if `expect_list` is True, ie. item is a list, and its
303-
subelements are to be checked.
304-
305-
"""
306-
if expect_list:
307-
if not isinstance(item, list):
308-
raise TypeError('Field `%s` must be a list.' % field_name)
309-
310-
# All channels of the field must be present.
311-
if required_channels == 'all':
312-
required_channels = list(range(len(item)))
313-
314-
for ch in range(len(item)):
315-
# Check whether the field may be None
316-
if ch in required_channels:
317-
allowed_types_ch = allowed_types + (type(None),)
318-
else:
319-
allowed_types_ch = allowed_types
320-
321-
if not isinstance(item[ch], allowed_type):
322-
raise TypeError('Channel %d of field `%s` must be one of the following types:' % (ch, field_name),
323-
allowed_types)
324-
else:
325-
if not isinstance(item, allowed_types):
326-
raise TypeError('Field `%s` must be one of the following types:',
327-
allowed_types)
328-
329-
330-
def check_np_array(item, field_name, ndim, parent_class, channel_num=None):
331-
"""
332-
Check a numpy array's shape and dtype against required
333-
specifications.
334-
335-
Parameters
336-
----------
337-
item : numpy array
338-
The numpy array to check
339-
field_name : str
340-
The name of the field to check
341-
ndim : int
342-
The required number of dimensions
343-
parent_class : type
344-
The parent class of the dtype. ie. np.integer, np.float64.
345-
channel_num : int, optional
346-
If not None, indicates that the item passed in is a subelement
347-
of a list. Indicate this in the error message if triggered.
348-
349-
"""
350-
# Check shape
351-
if item.ndim != ndim:
352-
error_msg = 'Field `%s` must have ndim == %d' % (field_name, ndim)
353-
if channel_num is not None:
354-
error_msg = ('Channel %d of f' % ) + error_msg[1:]
355-
raise TypeError(error_msg)
356-
357-
# Check dtype
358-
if not np.issubdtype(item.dtype, parent_class):
359-
error_msg = 'Field `%s` must have a dtype that subclasses %s' (field_name, parent_class)
360-
if channel_num is not None:
361-
error_msg = ('Channel %d of f' % ) + error_msg[1:]
362-
raise TypeError(error_msg)
363-
364-
365247
class Record(BaseRecord, _header.HeaderMixin, _signal.SignalMixin):
366248
"""
367249
The class representing single segment WFDB records.
@@ -854,8 +736,112 @@ def multi_to_single(self, physical, return_res=64):
854736

855737
return record
856738

739+
# ---------------------- Type Specifications ------------------------- #
740+
741+
742+
# Allowed types of wfdb header fields, and also attributes defined in
743+
# this library
744+
ALLOWED_TYPES = dict([[index, _header.FIELD_SPECS.loc[index, 'allowed_types']] for index in _header.FIELD_SPECS.index])
745+
ALLOWED_TYPES.update({'comment': (str,), 'p_signal': (np.ndarray,),
746+
'd_signal':(np.ndarray,), 'e_p_signal':(np.ndarray,),
747+
'e_d_signal':(np.ndarray,),
748+
'segments':(Record, type(None))})
749+
750+
# Fields that must be lists
751+
LIST_FIELDS = tuple(_header.SIGNAL_SPECS.index) + ('comments', 'e_p_signal',
752+
'e_d_signal', 'segments')
753+
754+
755+
def _check_item_type(item, field_name, allowed_types, expect_list=False,
756+
required_channels='all'):
757+
"""
758+
Check the item's type against a set of allowed types.
759+
Vary the print message regarding whether the item can be None.
760+
Helper to `BaseRecord.check_field`.
761+
762+
Parameters
763+
----------
764+
item : any
765+
The item to check.
766+
field_name : str
767+
The field name.
768+
allowed_types : iterable
769+
Iterable of types the item is allowed to be.
770+
expect_list : bool, optional
771+
Whether the item is expected to be a list.
772+
required_channels : list, optional
773+
List of integers specifying which channels of the item must be
774+
present. May be set to 'all' to indicate all channels. Only used
775+
if `expect_list` is True, ie. item is a list, and its
776+
subelements are to be checked.
777+
778+
Notes
779+
-----
780+
This is called by `check_field`, which determines whether the item
781+
should be a list or not. This function should generally not be
782+
called by the user directly.
783+
784+
"""
785+
if expect_list:
786+
if not isinstance(item, list):
787+
raise TypeError('Field `%s` must be a list.' % field_name)
788+
789+
# All channels of the field must be present.
790+
if required_channels == 'all':
791+
required_channels = list(range(len(item)))
792+
793+
for ch in range(len(item)):
794+
# Check whether the field may be None
795+
if ch in required_channels:
796+
allowed_types_ch = allowed_types + (type(None),)
797+
else:
798+
allowed_types_ch = allowed_types
799+
800+
if not isinstance(item[ch], allowed_type):
801+
raise TypeError('Channel %d of field `%s` must be one of the following types:' % (ch, field_name),
802+
allowed_types)
803+
else:
804+
if not isinstance(item, allowed_types):
805+
raise TypeError('Field `%s` must be one of the following types:',
806+
allowed_types)
807+
808+
809+
def check_np_array(item, field_name, ndim, parent_class, channel_num=None):
810+
"""
811+
Check a numpy array's shape and dtype against required
812+
specifications.
813+
814+
Parameters
815+
----------
816+
item : numpy array
817+
The numpy array to check
818+
field_name : str
819+
The name of the field to check
820+
ndim : int
821+
The required number of dimensions
822+
parent_class : type
823+
The parent class of the dtype. ie. np.integer, np.float64.
824+
channel_num : int, optional
825+
If not None, indicates that the item passed in is a subelement
826+
of a list. Indicate this in the error message if triggered.
827+
828+
"""
829+
# Check shape
830+
if item.ndim != ndim:
831+
error_msg = 'Field `%s` must have ndim == %d' % (field_name, ndim)
832+
if channel_num is not None:
833+
error_msg = ('Channel %d of f' % channel_num) + error_msg[1:]
834+
raise TypeError(error_msg)
835+
836+
# Check dtype
837+
if not np.issubdtype(item.dtype, parent_class):
838+
error_msg = 'Field `%s` must have a dtype that subclasses %s' (field_name, parent_class)
839+
if channel_num is not None:
840+
error_msg = ('Channel %d of f' % channel_num) + error_msg[1:]
841+
raise TypeError(error_msg)
842+
857843

858-
#------------------- Reading Records -------------------#
844+
#------------------------- Reading Records --------------------------- #
859845

860846
def rdheader(record_name, pb_dir=None, rd_segments=False):
861847
"""

0 commit comments

Comments
 (0)