Skip to content

Commit d49f6db

Browse files
committed
force user input of used label field when writing annotations
1 parent 4dce30c commit d49f6db

File tree

1 file changed

+66
-57
lines changed

1 file changed

+66
-57
lines changed

wfdb/io/annotation.py

Lines changed: 66 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
'description':(list, np.ndarray),
3434
'custom_labels': (pd.DataFrame, list, tuple)}
3535

36-
str_types = (str, np.str_)
36+
STR_TYPES = (str, np.str_)
3737

3838
# Standard WFDB annotation file extensions
3939
ANN_EXTENSIONS = pd.DataFrame(data=[
@@ -244,49 +244,48 @@ def contained_data_fields(self):
244244
"""
245245
return [f for f in ANN_DATA_FIELDS if hasattr(self, f)]
246246

247-
def wrann(self, write_fs=False, write_dir=''):
247+
def wrann(self, label_field='symbol', write_fs=False, write_dir=''):
248248
"""
249249
Write a WFDB annotation file from this object.
250250
251-
!!! Should we force an argument to choose the label field used?
252-
253251
Parameters
254252
----------
253+
label_field : str, optional
254+
The field used to write the label information. The
255+
annotation object must have this attribute defined.
256+
Must be either 'label_store' or 'symbol'. If 'symbol' is
257+
chosen, that attribute will be used to calculate label_store
258+
which will be written to the file.
255259
write_fs : bool, optional
256260
Whether to write the `fs` attribute to the file.
257-
258-
Notes
259-
-----
260-
The label_store field will be generated if necessary
261+
write_dir : str, optional
262+
The directory in which to write the annotation file
261263
262264
"""
265+
# Validate input parameters
266+
if label_field not in ['label_store', 'symbol']:
267+
raise Exception("'label_field' must be set to either 'symbol' or 'label_store'.")
263268

264-
# Check the presence of vital fields
265-
contained_label_fields = self._contained_label_fields()
266-
if not contained_label_fields:
267-
raise Exception('At least one annotation label field is required to write the annotation: ', ANN_LABEL_FIELDS)
268-
for field in ['record_name', 'extension']:
269-
if getattr(self, field) is None:
270-
raise Exception('Missing required field for writing annotation file: ',field)
271-
272-
# Check the validity of individual fields
273-
self.check_fields()
269+
# Check the validity of individual fields used to write the file
270+
self.check_write_fields(label_field=label_field)
274271

275272
# Standardize the format of the custom_labels field
276273
self._custom_labels_to_df()
277274

275+
# Check the cohesion of the fields
276+
self.check_field_cohesion()
277+
278278
# Create the label map used in this annotaion
279279
self._create_label_map()
280280

281-
# Set the label_store field if necessary
282-
if 'label_store' not in contained_label_fields:
283-
self.convert_label_attribute(source_field=contained_label_fields[0],
281+
# Set the label_store field if needed. Requires the label map.
282+
if label_field =='symbol':
283+
self.convert_label_attribute(source_field='symbol',
284284
target_field='label_store')
285285

286-
# Check the cohesion of the fields
287-
self.check_field_cohesion()
288286

289-
# Write the header file using the specified fields
287+
288+
# Write the annotation file using the specified fields
290289
self.wr_ann_file(write_fs=write_fs, write_dir=write_dir)
291290

292291
def _contained_label_fields(self):
@@ -295,13 +294,27 @@ def _contained_label_fields(self):
295294
"""
296295
return [field for field in ANN_LABEL_FIELDS if getattr(self, field)]
297296

298-
def check_fields(self):
297+
def check_write_fields(self, label_field):
299298
"""
300-
Check the set fields of the annotation object
299+
Check all fields of this object that may be used to write an
300+
annotation file.
301+
302+
Mandatory fields will be checked. Optional fields will be
303+
checked if they are defined, and hence liable to affect the
304+
output file.
305+
301306
"""
302307
for field in ANN_FIELDS:
303-
if hasattr(self, field):
304-
self.check_field(field)
308+
# Mandatory check
309+
if field in ('sample', 'record_name', 'extension', label_field):
310+
if getattr(self, field) is not None:
311+
self.check_field(field)
312+
else:
313+
raise Exception('Missing required field for writing annotation file: ', field)
314+
# Check if defined
315+
else:
316+
if getattr(self, field) is not None:
317+
self.check_field(field)
305318

306319
def check_field(self, field):
307320
"""
@@ -380,13 +393,13 @@ def check_field(self, field):
380393
if not hasattr(label_store[i], '__index__'):
381394
raise TypeError('The label_store values of the '+field+' field must be integer-like')
382395

383-
if not isinstance(symbol[i], str_types) or len(symbol[i]) not in [1,2,3]:
396+
if not isinstance(symbol[i], STR_TYPES) or len(symbol[i]) not in [1,2,3]:
384397
raise ValueError('The symbol values of the '+field+' field must be strings of length 1 to 3')
385398

386399
if bool(re.search('[ \t\n\r\f\v]', symbol[i])):
387400
raise ValueError('The symbol values of the '+field+' field must not contain whitespace characters')
388401

389-
if not isinstance(description[i], str_types):
402+
if not isinstance(description[i], STR_TYPES):
390403
raise TypeError('The description values of the '+field+' field must be strings')
391404

392405
# Would be good to enfore this but existing garbage annotations have tabs and newlines...
@@ -398,7 +411,7 @@ def check_field(self, field):
398411
uniq_elements = set(item)
399412

400413
for e in uniq_elements:
401-
if not isinstance(e, str_types):
414+
if not isinstance(e, STR_TYPES):
402415
raise TypeError("Subelements of the '{}' field must be strings".format(field))
403416

404417
if field == 'symbol':
@@ -552,7 +565,7 @@ def _get_available_label_stores(self):
552565
else:
553566
raise ValueError('No label fields are defined. At least one of the following is required: ', ANN_LABEL_FIELDS)
554567

555-
# We are using 'label_store', the steps are slightly different.
568+
# If we are using 'label_store', the steps are slightly different.
556569

557570
# Get the unused label_store values
558571
if usefield == 'label_store':
@@ -840,9 +853,8 @@ def _compact_fields(self):
840853

841854

842855
def sym_to_aux(self):
843-
# Move non-encoded symbol elements into the aux_note field
856+
"Move non-encoded symbol elements into the aux_note field"
844857
self.check_field('symbol')
845-
846858
# Non-encoded symbols
847859
label_table_map = self._create_label_map(inplace=False)
848860
external_syms = set(self.symbol) - set(label_table_map['symbol'].values)
@@ -880,7 +892,7 @@ def get_contained_labels(self):
880892
if self.custom_labels:
881893
self._custom_labels_to_df()
882894

883-
self.check_field_cohesion()
895+
# self.check_field_cohesion()
884896

885897
# Merge the standard wfdb labels with the custom labels.
886898
# custom labels values overwrite standard wfdb if overlap.
@@ -1013,18 +1025,15 @@ def convert_label_attribute(self, source_field, target_field):
10131025
The destination label attribute
10141026
10151027
"""
1016-
if inplace and not overwrite:
1017-
if getattr(self, target_field) is not None:
1018-
return
1028+
self._create_label_map()
10191029

1020-
label_map = self._create_label_map(inplace=False)
1021-
label_map.set_index(source_field, inplace=True)
10221030

1023-
target_item = label_map.loc[getattr(self, source_field), target_field].values
1031+
target_item = self.__label_map__.loc[getattr(self, source_field), target_field].values
10241032

1025-
if target_field != 'label_store':
1026-
# Should already be int64 dtype if target is label_store
1027-
target_item = list(target_item)
1033+
# Shouldn't need this? Just leave as list.
1034+
# if target_field != 'label_store':
1035+
# # Should already be int64 dtype if target is label_store
1036+
# target_item = list(target_item)
10281037

10291038
setattr(self, target_field, target_item)
10301039

@@ -1042,12 +1051,9 @@ def to_df(self, fields=None):
10421051
Create a pandas DataFrame from the Annotation object
10431052
10441053
"""
1045-
fields = fields or ANN_DATA_FIELDS[:6]
1054+
fields = fields or ANN_DATA_FIELDS
10461055

1047-
df = pd.DataFrame(data={'sample':self.sample, 'symbol':self.symbol,
1048-
'subtype':self.subtype, 'chan':self.chan, 'num':self.num,
1049-
'aux_note':self.aux_note}, columns=['sample', 'symbol', 'subtype',
1050-
'chan', 'aux_note'])
1056+
df = pd.DataFrame(data={field:getattr(self, field) for field in fields})
10511057
return df
10521058

10531059

@@ -1284,17 +1290,20 @@ def wrann(record_name, extension, sample, symbol=None, subtype=None, chan=None,
12841290
custom_labels=custom_labels)
12851291

12861292
# Find out which input field describes the labels
1287-
if symbol is None:
1288-
if label_store is None:
1289-
raise Exception("Either the 'symbol' field or the 'label_store' field must be set")
1290-
else:
1291-
if label_store is None:
1292-
annotation.sym_to_aux()
1293-
else:
1293+
if symbol:
1294+
if label_store:
12941295
raise Exception("Only one of the 'symbol' and 'label_store' fields may be input, for describing annotation labels")
1296+
else:
1297+
label_field = 'symbol'
1298+
annotation.sym_to_aux()
1299+
else:
1300+
if not label_store:
1301+
raise Exception("Either the 'symbol' field or the 'label_store' field must be set")
1302+
label_field = 'label_store'
12951303

12961304
# Perform field checks and write the annotation file
1297-
annotation.wrann(write_fs=True, write_dir=write_dir)
1305+
annotation.wrann(label_field=label_field, write_fs=True,
1306+
write_dir=write_dir)
12981307

12991308

13001309
def show_ann_labels():

0 commit comments

Comments
 (0)