Skip to content

Commit 652d869

Browse files
committed
continue refactor, readwrite records works
1 parent 0ce255a commit 652d869

File tree

5 files changed

+615
-178
lines changed

5 files changed

+615
-178
lines changed

demo.ipynb

Lines changed: 422 additions & 33 deletions
Large diffs are not rendered by default.

wfdb/io/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .record import (Record, MultiRecord, rdheader, rdrecord, rdsamp, wrsamp,
2-
dl_database, sig_classes)
2+
dl_database, SIGNAL_CLASSES)
33
from ._signal import est_res, wr_dat_file
44
from .annotation import (Annotation, rdann, wrann, show_ann_labels,
55
show_ann_classes)

wfdb/io/_header.py

Lines changed: 120 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
from . import download
99
from . import _signal
1010

11-
11+
import pdb
1212
int_types = (int, np.int64, np.int32, np.int16, np.int8)
1313
float_types = int_types + (float, np.float64, np.float32)
1414

1515
"""
16-
WFDB field specifications for each field.
16+
WFDB field specifications for each field. The indexes are the field
17+
names.
1718
1819
Parameters
1920
----------
@@ -63,6 +64,7 @@
6364
index=['record_name', 'n_seg', 'n_sig', 'fs', 'counter_freq',
6465
'base_counter', 'sig_len', 'base_time', 'base_date'],
6566
columns=_SPECIFICATION_COLUMNS,
67+
dtype='object',
6668
data=[[(str,), '', None, True, None, None], # record_name
6769
[int_types, '/', 'record_name', True, None, None], # n_seg
6870
[int_types, ' ', 'record_name', True, None, None], # n_sig
@@ -80,6 +82,7 @@
8082
'adc_gain', 'baseline', 'units', 'adc_res', 'adc_zero',
8183
'init_value', 'checksum', 'block_size', 'sig_name'],
8284
columns=_SPECIFICATION_COLUMNS,
85+
dtype='object',
8386
data=[[(str,), '', None, True, None, None], # file_name
8487
[(str,), ' ', 'file_name', True, None, None], # fmt
8588
[int_types, 'x', 'fmt', False, 1, None], # samps_per_frame
@@ -100,6 +103,7 @@
100103
SEGMENT_SPECS = pd.DataFrame(
101104
index=['seg_name', 'seg_len'],
102105
columns=_SPECIFICATION_COLUMNS,
106+
dtype='object',
103107
data=[[(str), '', None, True, None, None], # seg_name
104108
[int_types, ' ', 'seg_name', True, None, None], # seg_len
105109
]
@@ -170,7 +174,7 @@ def get_write_subset(self, spec_type):
170174

171175
# Remove the n_seg requirement for single segment items
172176
if not hasattr(self, 'n_seg'):
173-
del(record_specs['n_seg'])
177+
record_specs.drop('n_seg', inplace=True)
174178

175179
for field in record_specs.index[-1::-1]:
176180
# Continue if the field has already been included
@@ -198,13 +202,13 @@ def get_write_subset(self, spec_type):
198202
for ch in range(self.n_sig):
199203
# The fields needed for this channel
200204
write_fields_ch = []
201-
for field in signal_specs[-1::-1]:
205+
for field in signal_specs.index[-1::-1]:
202206
if field in write_fields_ch:
203207
continue
204208

205209
item = getattr(self, field)
206210
# If the field is required by default or has been defined by the user
207-
if signal_specs.loc[field, 'write_req'] or (item is not None and item[ch] is not None):
211+
if signal_specs.loc[field, 'write_required'] or (item is not None and item[ch] is not None):
208212
req_field = field
209213
# Add the field and its recursive dependencies
210214
while req_field is not None:
@@ -238,32 +242,43 @@ class HeaderMixin(BaseHeaderMixin):
238242

239243
def set_defaults(self):
240244
"""
241-
Set defaults for fields needed to write the header if they have defaults.
242-
This is NOT called by rdheader. It is only automatically called by the gateway wrsamp for convenience.
243-
It is also not called by wrhea (this may be changed in the future) since
244-
it is supposed to be an explicit function.
245+
Set defaults for fields needed to write the header if they have
246+
defaults.
247+
248+
Notes
249+
-----
250+
- This is NOT called by `rdheader`. It is only automatically
251+
called by the gateway `wrsamp` for convenience.
252+
- This is also not called by `wrheader` since it is supposed to
253+
be an explicit function.
254+
- This is not responsible for initializing the attributes. That
255+
is done by the constructor.
245256
246-
Not responsible for initializing the
247-
attributes. That is done by the constructor.
248257
"""
249258
rfields, sfields = self.get_write_fields()
250259
for f in rfields:
251260
self.set_default(f)
252261
for f in sfields:
253262
self.set_default(f)
254263

255-
256264
def wrheader(self, write_dir=''):
257265
"""
258266
Write a wfdb header file. The signals are not used. Before
259267
writing:
260268
- Get the fields used to write the header for this instance.
261269
- Check each required field.
270+
- Check that the fields are cohesive with one another.
262271
263272
Parameters
264273
----------
265274
write_dir : str, optional
266275
The output directory in which the header is written.
276+
277+
Notes
278+
-----
279+
This function does NOT call `set_defaults`. Essential fields
280+
must be set beforehand.
281+
267282
"""
268283

269284
# Get all the fields used to write the header
@@ -338,29 +353,31 @@ def set_default(self, field):
338353

339354
# Record specification fields
340355
if field in RECORD_SPECS.index:
341-
# Return if no default to set, or if the field is already present.
356+
# Return if no default to set, or if the field is already
357+
# present.
342358
if RECORD_SPECS.loc[field, 'write_default'] is None or getattr(self, field) is not None:
343359
return
344360
setattr(self, field, RECORD_SPECS.loc[field, 'write_default'])
345361

346362
# Signal specification fields
347363
# Setting entire list default, not filling in blanks in lists.
348-
elif field in SIGNAL_FIELDS.index:
364+
elif field in SIGNAL_SPECS.index:
349365

350366
# Specific dynamic case
351367
if field == 'file_name' and self.file_name is None:
352-
self.file_name = self.n_sig*[self.record_name+'.dat']
368+
self.file_name = self.n_sig * [self.record_name + '.dat']
353369
return
354370

355371
item = getattr(self, field)
356372

357-
# Return if no default to set, or if the field is already present.
373+
# Return if no default to set, or if the field is already
374+
# present.
358375
if SIGNAL_SPECS.loc[field, 'write_default'] is None or item is not None:
359376
return
360377

361378
# Set more specific defaults if possible
362379
if field == 'adc_res' and self.fmt is not None:
363-
self.adc_res=_signal.wfdbfmtres(self.fmt)
380+
self.adc_res = _signal.wfdbfmtres(self.fmt)
364381
return
365382

366383
setattr(self, field,
@@ -403,46 +420,60 @@ def check_field_cohesion(self, rec_write_fields, sig_write_fields):
403420
raise ValueError('Each file_name (dat file) specified must have the same byte offset')
404421

405422

406-
407423
def wr_header_file(self, rec_write_fields, sig_write_fields, write_dir):
408-
# Write a header file using the specified fields
409-
header_lines = []
424+
"""
425+
Write a header file using the specified fields
426+
427+
Parameters
428+
----------
429+
430+
rec_write_fields : list
431+
List of record specification fields to write
432+
sig_write_fields : dict
433+
Dictionary of signal specification fields to write, values
434+
being equal to a list of channels to write for each field.
435+
write_dir : str
436+
The directory in which to write the header file
410437
438+
"""
411439
# Create record specification line
412440
record_line = ''
413441
# Traverse the ordered dictionary
414-
for field in RECORD_SPECS:
442+
for field in RECORD_SPECS.index:
415443
# If the field is being used, add it with its delimiter
416444
if field in rec_write_fields:
417445
stringfield = str(getattr(self, field))
418446
# If fs is float, check whether it as an integer
419447
if field == 'fs' and isinstance(self.fs, float):
420448
if round(self.fs, 8) == float(int(self.fs)):
421449
stringfield = str(int(self.fs))
422-
record_line = record_line + RECORD_SPECS[field].delimiter + stringfield
423-
header_lines.append(record_line)
450+
record_line += RECORD_SPECS.loc[field, 'delimiter'] + stringfield
451+
452+
header_lines = [record_line]
424453

425454
# Create signal specification lines (if any) one channel at a time
426-
if self.n_sig>0:
427-
signallines = self.n_sig*['']
455+
if self.n_sig > 0:
456+
signal_lines = self.n_sig * ['']
428457
for ch in range(self.n_sig):
429-
# Traverse the ordered dictionary
430-
for field in SIGNAL_FIELDS:
431-
# If the field is being used, add each of its elements with the delimiter to the appropriate line
432-
if field in sig_write_fields and sig_write_fields[field][ch]:
433-
signallines[ch]=signallines[ch] + SIGNAL_FIELDS[field].delimiter + str(getattr(self, field)[ch])
458+
# Traverse the signal fields
459+
for field in SIGNAL_SPECS.index:
460+
# If the field is being used, add each of its
461+
# elements with the delimiter to the appropriate
462+
# line
463+
if field in sig_write_fields and ch in sig_write_fields[field]:
464+
signal_lines[ch] += SIGNAL_SPECS.loc[field, 'delimiter'] + str(getattr(self, field)[ch])
434465
# The 'baseline' field needs to be closed with ')'
435-
if field== 'baseline':
436-
signallines[ch]=signallines[ch] +')'
466+
if field == 'baseline':
467+
signal_lines[ch] += ')'
437468

438-
header_lines = header_lines + signallines
469+
header_lines += signal_lines
439470

440471
# Create comment lines (if any)
441472
if 'comments' in rec_write_fields:
442-
comment_lines = ['# '+comment for comment in self.comments]
443-
header_lines = header_lines + comment_lines
473+
comment_lines = ['# ' + comment for comment in self.comments]
474+
header_lines += comment_lines
444475

445-
lines_to_file(self.record_name+'.hea', write_dir, header_lines)
476+
lines_to_file(self.record_name + '.hea', write_dir, header_lines)
446477

447478

448479
class MultiHeaderMixin(BaseHeaderMixin):
@@ -523,35 +554,38 @@ def check_field_cohesion(self):
523554
raise ValueError("The sum of the 'seg_len' fields do not match the 'sig_len' field")
524555

525556

526-
# Write a header file using the specified fields
527-
def wr_header_file(self, write_fields, write_dir):
528557

529-
header_lines=[]
558+
def wr_header_file(self, write_fields, write_dir):
559+
"""
560+
Write a header file using the specified fields
530561
562+
"""
531563
# Create record specification line
532564
record_line = ''
533565
# Traverse the ordered dictionary
534-
for field in RECORD_SPECS:
566+
for field in RECORD_SPECS.index:
535567
# If the field is being used, add it with its delimiter
536568
if field in write_fields:
537-
record_line = record_line + RECORD_SPECS[field].delimiter + str(getattr(self, field))
538-
header_lines.append(record_line)
569+
record_line += RECORD_SPECS.loc[field, 'delimiter'] + str(getattr(self, field))
570+
571+
header_lines = [record_line]
539572

540573
# Create segment specification lines
541-
segmentlines = self.n_seg*['']
542-
# For both fields, add each of its elements with the delimiter to the appropriate line
543-
for field in ['seg_name', 'seg_name']:
544-
for segnum in range(0, self.n_seg):
545-
segmentlines[segnum] = segmentlines[segnum] + SEGMENT_SPECS[field].delimiter + str(getattr(self, field)[segnum])
574+
segment_lines = self.n_seg * ['']
575+
# For both fields, add each of its elements with the delimiter
576+
# to the appropriate line
577+
for field in SEGMENT_SPECS.index:
578+
for seg_num in range(self.n_seg):
579+
segment_lines[seg_num] += SEGMENT_SPECS.loc[field, 'delimiter'] + str(getattr(self, field)[seg_num])
546580

547-
header_lines = header_lines + segmentlines
581+
header_lines = header_lines + segment_lines
548582

549583
# Create comment lines (if any)
550584
if 'comments' in write_fields:
551-
comment_lines = ['# '+comment for comment in self.comments]
552-
header_lines = header_lines + comment_lines
585+
comment_lines = ['# '+ comment for comment in self.comments]
586+
header_lines += comment_lines
553587

554-
lines_to_file(self.record_name+'.hea', header_lines, write_dir)
588+
lines_to_file(self.record_name + '.hea', header_lines, write_dir)
555589

556590

557591
def get_sig_segments(self, sig_name=None):
@@ -662,45 +696,49 @@ def _read_record_line(record_line):
662696
return record_fields
663697

664698

665-
# Extract fields from signal line strings into a dictionary
666699
def _read_signal_lines(signal_lines):
700+
"""
701+
Extract fields from a list of signal line strings into a dictionary.
702+
703+
"""
704+
n_sig = len(signal_lines)
667705
# Dictionary for signal fields
668706
signal_fields = {}
669707

670708
# Each dictionary field is a list
671-
for field in SIGNAL_FIELDS:
672-
signal_fields[field] = [None]*len(signal_lines)
709+
for field in SIGNAL_SPECS.index:
710+
signal_fields[field] = n_sig * [None]
673711

674712
# Read string fields from signal line
675-
for i in range(len(signal_lines)):
676-
(signal_fields['file_name'][i], signal_fields['fmt'][i],
677-
signal_fields['samps_per_frame'][i], signal_fields['skew'][i],
678-
signal_fields['byte_offset'][i], signal_fields['adc_gain'][i],
679-
signal_fields['baseline'][i], signal_fields['units'][i],
680-
signal_fields['adc_res'][i], signal_fields['adc_zero'][i],
681-
signal_fields['init_value'][i], signal_fields['checksum'][i],
682-
signal_fields['block_size'][i],
683-
signal_fields['sig_name'][i]) = _rx_signal.findall(signal_lines[i])[0]
684-
685-
for field in SIGNAL_FIELDS:
713+
for ch in range(n_sig):
714+
(signal_fields['file_name'][ch], signal_fields['fmt'][ch],
715+
signal_fields['samps_per_frame'][ch], signal_fields['skew'][ch],
716+
signal_fields['byte_offset'][ch], signal_fields['adc_gain'][ch],
717+
signal_fields['baseline'][ch], signal_fields['units'][ch],
718+
signal_fields['adc_res'][ch], signal_fields['adc_zero'][ch],
719+
signal_fields['init_value'][ch], signal_fields['checksum'][ch],
720+
signal_fields['block_size'][ch],
721+
signal_fields['sig_name'][ch]) = _rx_signal.findall(signal_lines[ch])[0]
722+
723+
for field in SIGNAL_SPECS.index:
686724
# Replace empty strings with their read defaults (which are mostly None)
687725
# Note: Never set a field to None. [None]* n_sig is accurate, indicating
688726
# that different channels can be present or missing.
689-
if signal_fields[field][i] == '':
690-
signal_fields[field][i] = SIGNAL_FIELDS[field].read_default
727+
if signal_fields[field][ch] == '':
728+
signal_fields[field][ch] = SIGNAL_SPECS.loc[field, 'read_default']
691729

692730
# Special case: missing baseline defaults to ADCzero if present
693-
if field == 'baseline' and signal_fields['adc_zero'][i] != '':
694-
signal_fields['baseline'][i] = int(signal_fields['adc_zero'][i])
731+
if field == 'baseline' and signal_fields['adc_zero'][ch] != '':
732+
signal_fields['baseline'][ch] = int(signal_fields['adc_zero'][ch])
695733
# Typecast non-empty strings for numerical fields
696734
else:
697-
if SIGNAL_FIELDS[field].allowed_types is int_types:
698-
signal_fields[field][i] = int(signal_fields[field][i])
699-
elif SIGNAL_FIELDS[field].allowed_types is float_types:
700-
signal_fields[field][i] = float(signal_fields[field][i])
701-
# Special case: gain of 0 means 200
702-
if field == 'adc_gain' and signal_fields['adc_gain'][i] == 0:
703-
signal_fields['adc_gain'][i] = 200.
735+
if SIGNAL_SPECS.loc[field, 'allowed_types'] is int_types:
736+
signal_fields[field][ch] = int(signal_fields[field][ch])
737+
elif SIGNAL_SPECS.loc[field, 'allowed_types'] is float_types:
738+
signal_fields[field][ch] = float(signal_fields[field][ch])
739+
# Special case: adc_gain of 0 means 200
740+
if field == 'adc_gain' and signal_fields['adc_gain'][ch] == 0:
741+
signal_fields['adc_gain'][ch] = 200.
704742

705743
return signal_fields
706744

@@ -714,21 +752,16 @@ def _read_segment_lines(segment_lines):
714752
segment_fields = {}
715753

716754
# Each dictionary field is a list
717-
for field in SEGMENT_SPECS:
718-
segment_fields[field] = [None]*len(segment_lines)
755+
for field in SEGMENT_SPECS.index:
756+
segment_fields[field] = [None] * len(segment_lines)
719757

720758
# Read string fields from signal line
721-
for i in range(0, len(segment_lines)):
759+
for i in range(len(segment_lines)):
722760
(segment_fields['seg_name'][i], segment_fields['seg_len'][i]) = _rx_segment.findall(segment_lines[i])[0]
723761

724-
for field in SEGMENT_SPECS:
725-
# Replace empty strings with their read defaults (which are mostly None)
726-
if segment_fields[field][i] == '':
727-
segment_fields[field][i] = SEGMENT_SPECS[field].read_default
728-
# Typecast non-empty strings for numerical field
729-
else:
730-
if field == 'seg_len':
731-
segment_fields[field][i] = int(segment_fields[field][i])
762+
# Typecast strings for numerical field
763+
if field == 'seg_len':
764+
segment_fields['seg_len'][i] = int(segment_fields['seg_len'][i])
732765

733766
return segment_fields
734767

0 commit comments

Comments
 (0)