Skip to content

Commit c29b431

Browse files
committed
continue refactor
1 parent 21764cf commit c29b431

File tree

3 files changed

+97
-78
lines changed

3 files changed

+97
-78
lines changed

wfdb/io/_header.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,20 +153,23 @@ class BaseHeaderMixin(object):
153153

154154
def get_write_subset(self, spec_type):
155155
"""
156-
Get the fields used to write the header,
157-
158-
159-
Helper function for `get_write_fields`.
156+
Get a set of fields used to write the header; either 'record'
157+
or 'signal' specification fields. Helper function for
158+
`get_write_fields`.
160159
161160
Parameters
162161
----------
163162
spec_type : str
164163
The set of specification fields desired. Either 'record' or
165164
'signal'.
166165
167-
- For record fields, returns a list of all fields needed.
168-
- For signal fields, it returns a dictionary of all fields needed,
169-
with keys = field and value = list of 1 or 0 indicating channel for the field
166+
Returns
167+
-------
168+
write_fields : list or dict
169+
For record fields, returns a list of all fields needed. For
170+
signal fields, it returns a dictionary of all fields needed,
171+
with keys = field and value = list of 1 or 0 indicating
172+
channel for the field
170173
171174
"""
172175
if spec_type == 'record':
@@ -181,11 +184,11 @@ def get_write_subset(self, spec_type):
181184
continue
182185
# If the field is required by default or has been defined by the user
183186
if fieldspecs[f].write_req or getattr(self, f) is not None:
184-
rf=f
187+
rf = f
185188
# Add the field and its recursive dependencies
186189
while rf is not None:
187190
write_fields.append(rf)
188-
rf=fieldspecs[rf].dependency
191+
rf = fieldspecs[rf].dependency
189192
# Add comments if any
190193
if getattr(self, 'comments') is not None:
191194
write_fields.append('comments')
@@ -262,6 +265,10 @@ def wrheader(self, write_dir=''):
262265
- Get the fields used to write the header for this instance.
263266
- Check each required field.
264267
268+
Parameters
269+
----------
270+
write_dir : str, optional
271+
The output directory in which the header is written.
265272
"""
266273

267274
# Get all the fields used to write the header
@@ -289,12 +296,22 @@ def wrheader(self, write_dir=''):
289296
def get_write_fields(self):
290297
"""
291298
Get the list of fields used to write the header, separating
292-
record and signal specification fields.
299+
record and signal specification fields. Returns the default
300+
required fields, the user defined fields,
301+
and their dependencies.
293302
294303
Does NOT include `d_signal` or `e_d_signal`.
295304
296-
Returns the default required fields, the user defined fields, and their dependencies.
297-
rec_write_fields includes 'comment' if present.
305+
Returns
306+
-------
307+
rec_write_fields : list
308+
Record specification fields to be written. Includes
309+
'comment' if present.
310+
sig_write_fields : dict
311+
Dictionary of signal specification fields to be written,
312+
with values equal to the channels that need to be present
313+
for each field.
314+
298315
"""
299316

300317
# Record specification fields
@@ -304,8 +321,7 @@ def get_write_fields(self):
304321
if self.comments != None:
305322
rec_write_fields.append('comments')
306323

307-
# Determine whether there are signals. If so, get their required
308-
# fields.
324+
# Get required signal fields if signals are present.
309325
self.check_field('n_sig')
310326

311327
if self.n_sig > 0:

wfdb/io/_signal.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -574,38 +574,40 @@ def calc_checksum(self, expanded=False):
574574
cs = [int(c) for c in cs]
575575
return cs
576576

577-
# Write each of the specified dat files
578577
def wr_dat_files(self, expanded=False, write_dir=''):
578+
"""
579+
Write each of the specified dat files
579580
581+
"""
580582
# Get the set of dat files to be written, and
581583
# the channels to be written to each file.
582-
file_names, datchannels = orderedsetlist(self.file_name)
584+
file_names, dat_channels = describe_list_indices(self.file_name)
583585

584586
# Get the fmt and byte offset corresponding to each dat file
585-
datfmts={}
586-
datoffsets={}
587+
dat_fmts = {}
588+
dat_offsets = {}
587589
for fn in file_names:
588-
datfmts[fn] = self.fmt[datchannels[fn][0]]
590+
dat_fmts[fn] = self.fmt[dat_channels[fn][0]]
589591

590592
# byte_offset may not be present
591593
if self.byte_offset is None:
592-
datoffsets[fn] = 0
594+
dat_offsets[fn] = 0
593595
else:
594-
datoffsets[fn] = self.byte_offset[datchannels[fn][0]]
596+
dat_offsets[fn] = self.byte_offset[dat_channels[fn][0]]
595597

596598
# Write the dat files
597599
if expanded:
598600
for fn in file_names:
599-
wr_dat_file(fn, datfmts[fn], None , datoffsets[fn], True,
600-
[self.e_d_signal[ch] for ch in datchannels[fn]],
601+
wr_dat_file(fn, dat_fmts[fn], None , dat_offsets[fn], True,
602+
[self.e_d_signal[ch] for ch in dat_channels[fn]],
601603
self.samps_per_frame, write_dir=write_dir)
602604
else:
603605
# Create a copy to prevent overwrite
604606
dsig = self.d_signal.copy()
605607
for fn in file_names:
606-
wr_dat_file(fn, datfmts[fn],
607-
dsig[:, datchannels[fn][0]:datchannels[fn][-1]+1],
608-
datoffsets[fn], write_dir=write_dir)
608+
wr_dat_file(fn, dat_fmts[fn],
609+
dsig[:, dat_channels[fn][0]:dat_channels[fn][-1]+1],
610+
dat_offsets[fn], write_dir=write_dir)
609611

610612

611613
def smooth_frames(self, sigtype='physical'):
@@ -689,7 +691,7 @@ def rd_segment(file_name, dirname, pb_dir, n_sig, fmt, sig_len, byte_offset,
689691

690692
# Get the set of dat files, and the
691693
# channels that belong to each file.
692-
file_name, datchannel = orderedsetlist(file_name)
694+
file_name, datchannel = describe_list_indices(file_name)
693695

694696
# Some files will not be read depending on input channels.
695697
# Get the the wanted fields only.
@@ -1487,25 +1489,36 @@ def wr_dat_file(file_name, fmt, d_signal, byte_offset, expanded=False,
14871489
f.close()
14881490

14891491

1490-
def orderedsetlist(fulllist):
1492+
def describe_list_indices(full_list):
14911493
"""
1492-
Returns the unique elements in a list in the order that they appear.
1493-
Also returns the indices of the original list that correspond to each output element.
1494+
Parameters
1495+
----------
1496+
full_list : list
1497+
The list of items to order and describe.
1498+
1499+
Returns
1500+
-------
1501+
unique_elements : list
1502+
A list of the unique elements of the list, in the order in which
1503+
they first appear.
1504+
element_indices : dict
1505+
A dictionary of lists for each unique element, giving all the
1506+
indices in which they appear in the original list.
1507+
14941508
"""
1495-
uniquelist = []
1496-
original_inds = {}
1509+
unique_elements = []
1510+
element_indices = {}
14971511

1498-
for i in range(0, len(fulllist)):
1499-
item = fulllist[i]
1512+
for i in range(len(full_list)):
1513+
item = full_list[i]
15001514
# new item
1501-
if item not in uniquelist:
1502-
uniquelist.append(item)
1503-
original_inds[item] = [i]
1515+
if item not in unique_elements:
1516+
unique_elements.append(item)
1517+
element_indices[item] = [i]
15041518
# previously seen item
15051519
else:
1506-
original_inds[item].append(i)
1507-
return uniquelist, original_inds
1508-
1520+
element_indices[item].append(i)
1521+
return unique_elements, element_indices
15091522

15101523

15111524
def downround(x, base):

wfdb/io/record.py

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
# The check_field_cohesion() function will be called in wrheader which checks all the header fields.
66
# The check_sig_cohesion() function will be called in wrsamp in wrdat to check the d_signal against the header fields.
77

8-
9-
from collections import OrderedDict
108
import datetime
119
import multiprocessing
1210
import numpy as np
@@ -46,7 +44,7 @@ def check_field(self, field, required_channels=None):
4644
----------
4745
field : str
4846
The field name
49-
channels : list, optional
47+
required_channels : list, optional
5048
Used for signal specification fields. Species the channels
5149
to check. Other channels can be None.
5250
@@ -57,8 +55,8 @@ def check_field(self, field, required_channels=None):
5755
if item is None:
5856
raise Exception('Missing field required: %s' % field)
5957

60-
if channels == 'all':
61-
channels = range(len(item))
58+
if required_channels == 'all':
59+
required_channels = range(len(item))
6260

6361
# We should have a list specifying these automatically.
6462

@@ -138,10 +136,10 @@ def check_field(self, field, required_channels=None):
138136
# Check for file_name characters
139137
accepted_string = re.match('[-\w]+\.?[\w]+', item[ch])
140138
if not accepted_string or accepted_string.string != item[ch]:
141-
raise ValueError('File names should only contain alphanumerics, hyphens, and an extension. eg. record_100.dat')
139+
raise ValueError('File names should only contain alphanumerics, hyphens, and an extension. eg. record-100.dat')
142140
# Check that dat files are grouped together
143-
if orderedsetlist(self.file_name)[0] != orderednoconseclist(self.file_name):
144-
raise ValueError('file_name error: all entries for signals that share a given file must be consecutive')
141+
if not is_monotonic(self.file_name):
142+
raise ValueError('Signals in a record that share a given file must be consecutive.')
145143
elif field == 'fmt':
146144
if item[ch] not in _signal.dat_fmts:
147145
raise ValueError('File formats must be valid WFDB dat formats:', _signal.dat_fmts)
@@ -156,7 +154,7 @@ def check_field(self, field, required_channels=None):
156154
raise ValueError('byte_offset values must be non-negative integers')
157155
elif field == 'adc_gain':
158156
if item[ch] <= 0:
159-
raise ValueError('adc_gain values must be positive numbers')
157+
raise ValueError('adc_gain values must be positive')
160158
elif field == 'baseline':
161159
# Original WFDB library 10.5.24 only has 4 bytes for
162160
# baseline.
@@ -286,7 +284,7 @@ def check_item_type(item, field_name, allowed_types, expect_list=False,
286284
"""
287285
Check the item's type against a set of allowed types.
288286
Vary the print message regarding whether the item can be None.
289-
Helper to `BaseRecord.check_field_type`.
287+
Helper to `BaseRecord.check_field`.
290288
291289
Parameters
292290
----------
@@ -315,7 +313,7 @@ def check_item_type(item, field_name, allowed_types, expect_list=False,
315313

316314
for ch in range(len(item)):
317315
# Check whether the field may be None
318-
if ch_in required_channels:
316+
if ch in required_channels:
319317
allowed_types_ch = allowed_types + (type(None),)
320318
else:
321319
allowed_types_ch = allowed_types
@@ -1200,8 +1198,9 @@ def rdsamp(record_name, sampfrom=0, sampto='end', channels='all', pb_dir=None):
12001198
channel =[1,3])
12011199
12021200
"""
1203-
1204-
record = rdrecord(record_name, sampfrom, sampto, channels, True, pb_dir, True)
1201+
record = rdrecord(record_name=record_name, sampfrom=sampfrom,
1202+
sampto=sampto, channels=channels, physical=True,
1203+
pb_dir=pb_dir, m2s=True)
12051204

12061205
signals = record.p_signal
12071206
fields = {}
@@ -1347,33 +1346,24 @@ def wrsamp(record_name, fs, units, sig_name, p_signal=None, d_signal=None,
13471346
record.wrsamp(write_dir=write_dir)
13481347

13491348

1350-
# Returns the unique elements in a list in the order that they appear.
1351-
# Also returns the indices of the original list that correspond to each output element.
1352-
def orderedsetlist(fulllist):
1353-
uniquelist = []
1354-
original_inds = {}
1349+
def is_monotonic(full_list):
1350+
"""
1351+
Determine whether elements in a list are monotonic, i.e. unique
1352+
elements are clustered together.
1353+
1354+
e.g. [5,5,3,4] is monotonic, but [5,3,5] is not.
1355+
"""
1356+
prev_elements = set({full_list[0]})
1357+
prev_item = full_list[0]
13551358

1356-
for i in range(0, len(fulllist)):
1357-
item = fulllist[i]
1358-
# new item
1359-
if item not in uniquelist:
1360-
uniquelist.append(item)
1361-
original_inds[item] = [i]
1362-
# previously seen item
1363-
else:
1364-
original_inds[item].append(i)
1365-
return uniquelist, original_inds
1366-
1367-
# Returns elements in a list without consecutive repeated values.
1368-
def orderednoconseclist(fulllist):
1369-
noconseclist = [fulllist[0]]
1370-
if len(fulllist) == 1:
1371-
return noconseclist
1372-
for i in fulllist:
1373-
if i!= noconseclist[-1]:
1374-
noconseclist.append(i)
1375-
return noconseclist
1359+
for item in full_list:
1360+
if item != prev_item:
1361+
if item in prev_elements:
1362+
return False
1363+
prev_item = item
1364+
prev_elements.add(item)
13761365

1366+
return True
13771367

13781368
def dl_database(db_dir, dl_dir, records='all', annotators='all',
13791369
keep_subdirs=True, overwrite = False):

0 commit comments

Comments
 (0)