1
1
import os
2
+ import pdb
2
3
import shutil
3
4
import unittest
4
5
5
6
import numpy as np
6
-
7
7
import wfdb
8
8
9
9
@@ -259,15 +259,25 @@ def test_2f(self):
259
259
record_EDF = wfdb .rdrecord ('sample-data/n16.edf' ).__dict__
260
260
261
261
fields = list (record_MIT .keys ())
262
- # MIT format method of checksum is outdated, sometimes the same value though
262
+ # Original MIT format method of checksum is outdated, sometimes
263
+ # the same value though
263
264
fields .remove ('checksum' )
265
+ # Original MIT format units are less comprehensive since they
266
+ # default to mV if unknown.. therefore added more default labels
267
+ fields .remove ('units' )
264
268
265
269
test_results = []
266
270
for field in fields :
267
271
# Signal value will be slightly off due to C to Python type conversion
268
272
if field == 'p_signal' :
269
- signal_diff = record_MIT [field ] - record_EDF [field ]
270
- if abs (max (signal_diff .min (), signal_diff .max (), key = abs )) <= 2 :
273
+ true_array = np .array (record_MIT [field ])
274
+ pred_array = np .array (record_EDF [field ])
275
+ sig_diff = np .abs ((pred_array - true_array ) / true_array )
276
+ sig_diff [sig_diff == - np .inf ] = 0
277
+ sig_diff [sig_diff == np .inf ] = 0
278
+ sig_diff = np .nanmean (sig_diff ,0 )
279
+ # 5% tolerance
280
+ if np .max (sig_diff ) <= 5 :
271
281
test_results .append (True )
272
282
else :
273
283
test_results .append (False )
@@ -293,15 +303,31 @@ def test_2g(self):
293
303
record_EDF = wfdb .rdrecord ('sample-data/SC4001E0-PSG.edf' ).__dict__
294
304
295
305
fields = list (record_MIT .keys ())
296
- # MIT format method of checksum is outdated, sometimes the same value though
306
+ # Original MIT format method of checksum is outdated, sometimes
307
+ # the same value though
297
308
fields .remove ('checksum' )
309
+ # Original MIT format units are less comprehensive since they
310
+ # default to mV if unknown.. therefore added more default labels
311
+ fields .remove ('units' )
312
+ # Initial value of signal will be off due to resampling done by
313
+ # MNE in the EDF reading phase
314
+ fields .remove ('init_value' )
315
+ # Samples per frame will be off due to resampling done by MNE in
316
+ # the EDF reading phase... I should probably fix this later
317
+ fields .remove ('samps_per_frame' )
298
318
299
319
test_results = []
300
320
for field in fields :
301
321
# Signal value will be slightly off due to C to Python type conversion
302
322
if field == 'p_signal' :
303
- signal_diff = record_MIT [field ] - record_EDF [field ]
304
- if abs (max (signal_diff .min (), signal_diff .max (), key = abs )) <= 2 :
323
+ true_array = np .array (record_MIT [field ])
324
+ pred_array = np .array (record_EDF [field ])
325
+ sig_diff = np .abs ((pred_array - true_array ) / true_array )
326
+ sig_diff [sig_diff == - np .inf ] = 0
327
+ sig_diff [sig_diff == np .inf ] = 0
328
+ sig_diff = np .nanmean (sig_diff ,0 )
329
+ # 5% tolerance
330
+ if np .max (sig_diff ) <= 5 :
305
331
test_results .append (True )
306
332
else :
307
333
test_results .append (False )
@@ -478,6 +504,13 @@ def test_4d(self):
478
504
479
505
assert np .array_equal (sig_round , sig_target )
480
506
507
+ def test_header_with_non_utf8 (self ):
508
+ """
509
+ Ignores non-utf8 characters in the header part.
510
+ """
511
+ record = wfdb .rdrecord ("sample-data/test_generator_2" )
512
+ sig_units_target = ['uV' , 'uV' , 'uV' , 'uV' , 'uV' , 'uV' , 'uV' , 'uV' , 'mV' , 'mV' , 'uV' , 'mV' ]
513
+ assert record .units .__eq__ (sig_units_target )
481
514
482
515
@classmethod
483
516
def tearDownClass (cls ):
0 commit comments