22
22
import re
23
23
import contextlib
24
24
from collections .abc import Iterable
25
+ from collections .abc import Sequence
25
26
26
27
import scipy as sp
27
28
from functools import wraps
60
61
check_is_fitted ,
61
62
check_X_y ,
62
63
)
64
+ from sklearn .utils .fixes import threadpool_info
63
65
64
66
65
67
__all__ = [
@@ -602,6 +604,38 @@ def __exit__(self, exc_type, exc_val, exc_tb):
602
604
_delete_folder (self .temp_folder )
603
605
604
606
607
def _create_memmap_backed_array(array, filename, mmap_mode):
    """Write ``array`` to ``filename`` and reopen it as a ``np.memmap``.

    Parameters
    ----------
    array : ndarray
        Array whose content, dtype and shape are copied to the file. 0-d
        arrays are supported.

    filename : str
        Path of the file backing the memmap. It is opened with mode ``"w+"``
        to write the data.

    mmap_mode : str
        Mode passed to ``np.memmap`` when the file is reopened for the
        returned array.

    Returns
    -------
    memmap_backed_array : np.memmap
        Memmap over ``filename`` with the dtype and shape of ``array``.
    """
    # https://numpy.org/doc/stable/reference/generated/numpy.memmap.html
    fp = np.memmap(filename, dtype=array.dtype, mode="w+", shape=array.shape)
    # Ellipsis assignment is equivalent to ``fp[:] = array[:]`` for arrays with
    # at least one dimension, and also works for 0-d arrays, on which slicing
    # with ``[:]`` raises an IndexError.
    fp[...] = array
    fp.flush()
    # The data is on disk: drop the writable mapping before reopening the file.
    del fp
    return np.memmap(filename, dtype=array.dtype, mode=mmap_mode, shape=array.shape)
616
+
617
+
618
def _create_aligned_memmap_backed_arrays(data, mmap_mode, folder):
    """Create memmap-backed copies of an array or a sequence of arrays.

    Parameters
    ----------
    data : ndarray or sequence of ndarray
        A single array is written to ``data.dat`` in ``folder``. The arrays of
        a sequence are written to ``data0.dat``, ``data1.dat``, ... in
        ``folder``, following their position in the sequence.

    mmap_mode : str
        Memmap mode forwarded to ``_create_memmap_backed_array``.

    folder : str
        Directory in which the backing files are created.

    Returns
    -------
    memmap_backed_data : np.memmap or list of np.memmap
        A single memmap when ``data`` is an array, otherwise a list with one
        memmap per input array.

    Raises
    ------
    ValueError
        If ``data`` is neither an array nor a sequence made only of arrays.
    """
    # Single array: one backing file.
    if isinstance(data, np.ndarray):
        return _create_memmap_backed_array(
            data, op.join(folder, "data.dat"), mmap_mode
        )

    # Validate the whole input before creating any file.
    is_sequence_of_arrays = isinstance(data, Sequence) and all(
        isinstance(item, np.ndarray) for item in data
    )
    if not is_sequence_of_arrays:
        raise ValueError(
            "When creating aligned memmap-backed arrays, input must be a single array or a"
            " sequence of arrays"
        )

    # Sequence of arrays: one backing file per array, numbered by position.
    memmap_backed_arrays = []
    for position, item in enumerate(data):
        target = op.join(folder, f"data{position}.dat")
        memmap_backed_arrays.append(
            _create_memmap_backed_array(item, target, mmap_mode)
        )
    return memmap_backed_arrays
637
+
638
+
605
639
def create_memmap_backed_data (data , mmap_mode = "r" , return_folder = False , aligned = False ):
606
640
"""
607
641
Parameters
@@ -616,18 +650,23 @@ def create_memmap_backed_data(data, mmap_mode="r", return_folder=False, aligned=
616
650
"""
617
651
temp_folder = tempfile .mkdtemp (prefix = "sklearn_testing_" )
618
652
atexit .register (functools .partial (_delete_folder , temp_folder , warn = True ))
653
+ # OpenBLAS is known to segfault with unaligned data on the Prescott
654
+ # architecture so force aligned=True on Prescott. For more details, see:
655
+ # https://github.com/scipy/scipy/issues/14886
656
+ has_prescott_openblas = any (
657
+ True
658
+ for info in threadpool_info ()
659
+ if info ["internal_api" ] == "openblas"
660
+ # Prudently assume Prescott might be the architecture if it is unknown.
661
+ and info .get ("architecture" , "prescott" ).lower () == "prescott"
662
+ )
663
+ if has_prescott_openblas :
664
+ aligned = True
665
+
619
666
if aligned :
620
- if isinstance (data , np .ndarray ) and data .flags .aligned :
621
- # https://numpy.org/doc/stable/reference/generated/numpy.memmap.html
622
- filename = op .join (temp_folder , "data.dat" )
623
- fp = np .memmap (filename , dtype = data .dtype , mode = "w+" , shape = data .shape )
624
- fp [:] = data [:] # write data to memmap array
625
- fp .flush ()
626
- memmap_backed_data = np .memmap (
627
- filename , dtype = data .dtype , mode = mmap_mode , shape = data .shape
628
- )
629
- else :
630
- raise ValueError ("If aligned=True, input must be a single numpy array." )
667
+ memmap_backed_data = _create_aligned_memmap_backed_arrays (
668
+ data , mmap_mode , temp_folder
669
+ )
631
670
else :
632
671
filename = op .join (temp_folder , "data.pkl" )
633
672
joblib .dump (data , filename )
0 commit comments