-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG+1] Support for 64 bit sparse array indices in text vectorizers #9147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d0787de
200fac0
1278bbc
564f8b7
f6a7d0d
b230cfd
7845f6d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
# Author: Lars Buitinck | ||
# License: BSD 3 clause | ||
|
||
import sys | ||
import array | ||
from cpython cimport array | ||
cimport cython | ||
|
@@ -9,6 +10,7 @@ cimport numpy as np | |
import numpy as np | ||
|
||
from sklearn.utils.murmurhash cimport murmurhash3_bytes_s32 | ||
from sklearn.utils.fixes import sp_version | ||
|
||
np.import_array() | ||
|
||
|
@@ -33,12 +35,20 @@ def transform(raw_X, Py_ssize_t n_features, dtype, bint alternate_sign=1): | |
cdef array.array indices | ||
cdef array.array indptr | ||
indices = array.array("i") | ||
indptr = array.array("i", [0]) | ||
if sys.version_info >= (3, 3): | ||
indices_array_dtype = "q" | ||
indices_np_dtype = np.longlong | ||
else: | ||
# On Windows with PY2.7 long int would still correspond to 32 bit. | ||
indices_array_dtype = "l" | ||
indices_np_dtype = np.int_ | ||
|
||
indptr = array.array(indices_array_dtype, [0]) | ||
|
||
# Since Python array does not understand Numpy dtypes, we grow the indices | ||
# and values arrays ourselves. Use a Py_ssize_t capacity for safety. | ||
cdef Py_ssize_t capacity = 8192 # arbitrary | ||
cdef np.int32_t size = 0 | ||
cdef np.int64_t size = 0 | ||
cdef np.ndarray values = np.empty(capacity, dtype=dtype) | ||
|
||
for x in raw_X: | ||
|
@@ -79,4 +89,18 @@ def transform(raw_X, Py_ssize_t n_features, dtype, bint alternate_sign=1): | |
indptr[len(indptr) - 1] = size | ||
|
||
indices_a = np.frombuffer(indices, dtype=np.int32) | ||
return (indices_a, np.frombuffer(indptr, dtype=np.int32), values[:size]) | ||
indptr_a = np.frombuffer(indptr, dtype=indices_np_dtype) | ||
|
||
if indptr[-1] > 2147483648: # = 2**31 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would sort of be nice if this were refactored somewhere, but I can't think of somewhere both pleasant and useful to keep it shared. |
||
if sp_version < (0, 14): | ||
raise ValueError(('sparse CSR array has {} non-zero ' | ||
'elements and requires 64 bit indexing, ' | ||
' which is unsupported with scipy {}. ' | ||
'Please upgrade to scipy >=0.14') | ||
.format(indptr[-1], '.'.join(sp_version))) | ||
# both indices and indptr have the same dtype in CSR arrays | ||
indices_a = indices_a.astype(np.int64) | ||
else: | ||
indptr_a = indptr_a.astype(np.int32) | ||
|
||
return (indices_a, indptr_a, values[:size]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
from .hashing import FeatureHasher | ||
from .stop_words import ENGLISH_STOP_WORDS | ||
from ..utils.validation import check_is_fitted | ||
from ..utils.fixes import sp_version | ||
|
||
__all__ = ['CountVectorizer', | ||
'ENGLISH_STOP_WORDS', | ||
|
@@ -784,7 +785,8 @@ def _count_vocab(self, raw_documents, fixed_vocab): | |
|
||
analyze = self.build_analyzer() | ||
j_indices = [] | ||
indptr = _make_int_array() | ||
indptr = [] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I checked (in the above-linked benchmark) that switching from an |
||
|
||
values = _make_int_array() | ||
indptr.append(0) | ||
for doc in raw_documents: | ||
|
@@ -811,8 +813,20 @@ def _count_vocab(self, raw_documents, fixed_vocab): | |
raise ValueError("empty vocabulary; perhaps the documents only" | ||
" contain stop words") | ||
|
||
j_indices = np.asarray(j_indices, dtype=np.intc) | ||
indptr = np.frombuffer(indptr, dtype=np.intc) | ||
if indptr[-1] > 2147483648: # = 2**31 - 1 | ||
if sp_version >= (0, 14): | ||
indices_dtype = np.int64 | ||
else: | ||
raise ValueError(('sparse CSR array has {} non-zero ' | ||
'elements and requires 64 bit indexing, ' | ||
' which is unsupported with scipy {}. ' | ||
'Please upgrade to scipy >=0.14') | ||
.format(indptr[-1], '.'.join(sp_version))) | ||
|
||
else: | ||
indices_dtype = np.int32 | ||
j_indices = np.asarray(j_indices, dtype=indices_dtype) | ||
indptr = np.asarray(indptr, dtype=indices_dtype) | ||
values = np.frombuffer(values, dtype=np.intc) | ||
|
||
X = sp.csr_matrix((values, j_indices, indptr), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you just use `np.intp` for all cases here?

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

The issue is that I don't know the `array.array` dtype corresponding to `np.intp`. Both need to match in all cases since we are using `np.frombuffer`.

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

`np.dtype(np.intp).char` will give you that.