Skip to content

Commit 9b045d4

Browse files
TwsThomasNicolasHug
authored andcommitted
MNT Add seed parameter to private FeatureHasher transform helper (scikit-learn#14605)
1 parent c6b5eb2 commit 9b045d4

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

sklearn/feature_extraction/_hashing.pyx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ from ..utils.fixes import sp_version
1717
np.import_array()
1818

1919

20-
def transform(raw_X, Py_ssize_t n_features, dtype, bint alternate_sign=1):
20+
def transform(raw_X, Py_ssize_t n_features, dtype,
21+
bint alternate_sign=1, unsigned int seed=0):
2122
"""Guts of FeatureHasher.transform.
2223
2324
Returns
@@ -65,7 +66,7 @@ def transform(raw_X, Py_ssize_t n_features, dtype, bint alternate_sign=1):
6566
elif not isinstance(f, bytes):
6667
raise TypeError("feature names must be strings")
6768

68-
h = murmurhash3_bytes_s32(<bytes>f, 0)
69+
h = murmurhash3_bytes_s32(<bytes>f, seed)
6970

7071
array.resize_smart(indices, len(indices) + 1)
7172
indices[len(indices) - 1] = abs(h) % n_features

sklearn/feature_extraction/hashing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def transform(self, raw_X):
150150
raw_X = (((f, 1) for f in x) for x in raw_X)
151151
indices, indptr, values = \
152152
_hashing_transform(raw_X, self.n_features, self.dtype,
153-
self.alternate_sign)
153+
self.alternate_sign, seed=0)
154154
n_samples = indptr.shape[0] - 1
155155

156156
if n_samples == 0:

sklearn/feature_extraction/tests/test_feature_hasher.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
from sklearn.feature_extraction import FeatureHasher
7+
from sklearn.feature_extraction._hashing import transform as _hashing_transform
78
from sklearn.utils.testing import (ignore_warnings,
89
fails_if_pypy)
910

@@ -45,6 +46,28 @@ def test_feature_hasher_strings():
4546
assert X.nnz == 6
4647

4748

49+
def test_hashing_transform_seed():
50+
# check the influence of the seed when computing the hashes
51+
raw_X = [["foo", "bar", "baz", "foo".encode("ascii")],
52+
["bar".encode("ascii"), "baz", "quux"]]
53+
54+
raw_X_ = (((f, 1) for f in x) for x in raw_X)
55+
indices, indptr, _ = _hashing_transform(raw_X_, 2 ** 7, str,
56+
False)
57+
58+
raw_X_ = (((f, 1) for f in x) for x in raw_X)
59+
indices_0, indptr_0, _ = _hashing_transform(raw_X_, 2 ** 7, str,
60+
False, seed=0)
61+
assert_array_equal(indices, indices_0)
62+
assert_array_equal(indptr, indptr_0)
63+
64+
raw_X_ = (((f, 1) for f in x) for x in raw_X)
65+
indices_1, _, _ = _hashing_transform(raw_X_, 2 ** 7, str,
66+
False, seed=1)
67+
with pytest.raises(AssertionError):
68+
assert_array_equal(indices, indices_1)
69+
70+
4871
def test_feature_hasher_pairs():
4972
raw_X = (iter(d.items()) for d in [{"foo": 1, "bar": 2},
5073
{"baz": 3, "quux": 4, "foo": -1}])

0 commit comments

Comments
 (0)