Skip to content

Commit e71482c

Browse files
committed
refactoring field check
1 parent 06ffb1b commit e71482c

File tree

3 files changed

+174
-119
lines changed

3 files changed

+174
-119
lines changed

wfdb/io/_header.py

Lines changed: 50 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111

1212
int_types = (int, np.int64, np.int32, np.int16, np.int8)
1313
float_types = int_types + (float, np.float64, np.float32)
14-
int_dtypes = ('int64', 'uint64', 'int32', 'uint32','int16','uint16')
15-
1614

1715
"""
1816
WFDB field specifications for each field.
@@ -47,7 +45,6 @@
4745
to None if they are not written in, unless the fields are essential, in
4846
which case an actual default value will be set.
4947
50-
5148
The read vs write default values are different for 2 reasons:
5249
1. We want to force the user to be explicit with certain important
5350
fields when writing WFDB records fields, without affecting
@@ -144,19 +141,26 @@ class BaseHeaderMixin(object):
144141
MultiRecord classes
145142
"""
146143

147-
def get_write_subset(self, spec_fields):
144+
def get_write_subset(self, spec_type):
148145
"""
149-
Helper function for get_write_fields.
146+
Get the fields used to write the header,
147+
148+
149+
Helper function for `get_write_fields`.
150150
151-
- spec_fields is the set of specification fields
152-
For record specs, it returns a list of all fields needed.
153-
For signal specs, it returns a dictionary of all fields needed,
151+
Parameters
152+
----------
153+
spec_type : str
154+
The set of specification fields desired. Either 'record' or
155+
'signal'.
156+
157+
- For record fields, returns a list of all fields needed.
158+
- For signal fields, it returns a dictionary of all fields needed,
154159
with keys = field and value = list of 1 or 0 indicating channel for the field
155-
"""
156160
157-
# record specification fields
158-
if spec_fields == 'record':
159-
write_fields=[]
161+
"""
162+
if spec_type == 'record':
163+
write_fields = []
160164
fieldspecs = OrderedDict(reversed(list(rec_field_specs.items())))
161165
# Remove this requirement for single segs
162166
if not hasattr(self, 'n_seg'):
@@ -177,11 +181,11 @@ def get_write_subset(self, spec_fields):
177181
write_fields.append('comments')
178182

179183
# signal spec field. Need to return a potentially different list for each channel.
180-
elif spec_fields == 'signal':
184+
elif spec_type == 'signal':
181185
# List of lists for each channel
182-
write_fields=[]
186+
write_fields = []
183187

184-
allwrite_fields=[]
188+
allwrite_fields = []
185189
fieldspecs = OrderedDict(reversed(list(SIGNAL_FIELDS.items())))
186190

187191
for ch in range(self.n_sig):
@@ -198,7 +202,7 @@ def get_write_subset(self, spec_fields):
198202
# Add the field and its recursive dependencies
199203
while rf is not None:
200204
write_fieldsch.append(rf)
201-
rf=fieldspecs[rf].dependency
205+
rf = fieldspecs[rf].dependency
202206

203207
write_fields.append(write_fieldsch)
204208

@@ -244,46 +248,52 @@ def set_defaults(self):
244248
def wrheader(self, write_dir=''):
245249

246250
# Get all the fields used to write the header
247-
recwrite_fields, sigwrite_fields = self.get_write_fields()
251+
rec_write_fields, sig_write_fields = self.get_write_fields()
248252

249253
# Check the validity of individual fields used to write the header
250254

251255
# Record specification fields (and comments)
252-
for f in recwrite_fields:
253-
self.check_field(f)
256+
for field in rec_write_fields:
257+
self.check_field(field)
254258

255259
# Signal specification fields.
256-
for f in sigwrite_fields:
257-
self.check_field(f, sigwrite_fields[f])
260+
for field in sig_write_fields:
261+
self.check_field(field, channels=sig_write_fields[field])
258262

259263
# Check the cohesion of fields used to write the header
260-
self.check_field_cohesion(recwrite_fields, list(sigwrite_fields))
264+
self.check_field_cohesion(rec_write_fields, list(sig_write_fields))
261265

262266
# Write the header file using the specified fields
263-
self.wr_header_file(recwrite_fields, sigwrite_fields, write_dir)
267+
self.wr_header_file(rec_write_fields, sig_write_fields, write_dir)
264268

265269

266-
# Get the list of fields used to write the header. (Does NOT include d_signal or e_d_signal.)
267-
# Separate items by record and signal specification field.
268-
# Returns the default required fields, the user defined fields, and their dependencies.
269-
# recwrite_fields includes 'comment' if present.
270270
def get_write_fields(self):
271+
"""
272+
Get the list of fields used to write the header, separating
273+
record and signal specification fields.
274+
275+
Does NOT include `d_signal` or `e_d_signal`.
276+
277+
Returns the default required fields, the user defined fields, and their dependencies.
278+
rec_write_fields includes 'comment' if present.
279+
"""
271280

272281
# Record specification fields
273-
recwrite_fields=self.get_write_subset('record')
282+
rec_write_fields = self.get_write_subset('record')
274283

275284
# Add comments if any
276285
if self.comments != None:
277-
recwrite_fields.append('comments')
286+
rec_write_fields.append('comments')
278287

279-
# Determine whether there are signals. If so, get their required fields.
288+
# Determine whether there are signals. If so, get their required
289+
# fields.
280290
self.check_field('n_sig')
281-
if self.n_sig>0:
282-
sigwrite_fields=self.get_write_subset('signal')
291+
if self.n_sig > 0:
292+
sig_write_fields = self.get_write_subset('signal')
283293
else:
284-
sigwrite_fields = None
294+
sig_write_fields = None
285295

286-
return recwrite_fields, sigwrite_fields
296+
return rec_write_fields, sig_write_fields
287297

288298
# Set the object's attribute to its default value if it is missing
289299
# and there is a default. Not responsible for initializing the
@@ -320,14 +330,14 @@ def set_default(self, field):
320330
setattr(self, field, [SIGNAL_FIELDS[field].write_def]*self.n_sig)
321331

322332
# Check the cohesion of fields used to write the header
323-
def check_field_cohesion(self, recwrite_fields, sigwrite_fields):
333+
def check_field_cohesion(self, rec_write_fields, sig_write_fields):
324334

325335
# If there are no signal specification fields, there is nothing to check.
326336
if self.n_sig>0:
327337

328338
# The length of all signal specification fields must match n_sig
329339
# even if some of its elements are None.
330-
for f in sigwrite_fields:
340+
for f in sig_write_fields:
331341
if len(getattr(self, f)) != self.n_sig:
332342
raise ValueError('The length of field: '+f+' must match field n_sig.')
333343

@@ -354,7 +364,7 @@ def check_field_cohesion(self, recwrite_fields, sigwrite_fields):
354364

355365

356366

357-
def wr_header_file(self, recwrite_fields, sigwrite_fields, write_dir):
367+
def wr_header_file(self, rec_write_fields, sig_write_fields, write_dir):
358368
# Write a header file using the specified fields
359369
header_lines=[]
360370

@@ -363,7 +373,7 @@ def wr_header_file(self, recwrite_fields, sigwrite_fields, write_dir):
363373
# Traverse the ordered dictionary
364374
for field in rec_field_specs:
365375
# If the field is being used, add it with its delimiter
366-
if field in recwrite_fields:
376+
if field in rec_write_fields:
367377
stringfield = str(getattr(self, field))
368378
# If fs is float, check whether it as an integer
369379
if field == 'fs' and isinstance(self.fs, float):
@@ -379,7 +389,7 @@ def wr_header_file(self, recwrite_fields, sigwrite_fields, write_dir):
379389
# Traverse the ordered dictionary
380390
for field in SIGNAL_FIELDS:
381391
# If the field is being used, add each of its elements with the delimiter to the appropriate line
382-
if field in sigwrite_fields and sigwrite_fields[field][ch]:
392+
if field in sig_write_fields and sig_write_fields[field][ch]:
383393
signallines[ch]=signallines[ch] + SIGNAL_FIELDS[field].delimiter + str(getattr(self, field)[ch])
384394
# The 'baseline' field needs to be closed with ')'
385395
if field== 'baseline':
@@ -388,7 +398,7 @@ def wr_header_file(self, recwrite_fields, sigwrite_fields, write_dir):
388398
header_lines = header_lines + signallines
389399

390400
# Create comment lines (if any)
391-
if 'comments' in recwrite_fields:
401+
if 'comments' in rec_write_fields:
392402
comment_lines = ['# '+comment for comment in self.comments]
393403
header_lines = header_lines + comment_lines
394404

wfdb/io/annotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def check_field(self, field):
220220
raise TypeError('The '+field+' field must be one of the following types:', ann_field_types[field])
221221

222222
if field in int_ann_fields:
223-
if item.dtype not in _header.int_dtypes:
223+
if not hasattr(field, '__index__'):
224224
raise TypeError('The '+field+' field must have an integer-based dtype.')
225225

226226
# Field specific checks

0 commit comments

Comments
 (0)