diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst
index d148d0b40c74f..f0e570ac70413 100644
--- a/doc/whats_new/v0.20.rst
+++ b/doc/whats_new/v0.20.rst
@@ -85,6 +85,14 @@ Changelog
   where ``max_features`` was sometimes rounded down to zero.
   :issue:`12388` by :user:`Connor Tann`.
 
+:mod:`sklearn.feature_extraction`
+.................................
+
+- |Fix| Fixed a regression in v0.20.0 where
+  :class:`feature_extraction.text.CountVectorizer` and other text vectorizers
+  could error during stop words validation with custom preprocessors
+  or tokenizers. :issue:`12393` by `Roman Yurchak`_.
+
 :mod:`sklearn.linear_model`
 ...........................
 
diff --git a/sklearn/feature_extraction/tests/test_text.py b/sklearn/feature_extraction/tests/test_text.py
index 503c62b2b3de3..9798175e4d5bc 100644
--- a/sklearn/feature_extraction/tests/test_text.py
+++ b/sklearn/feature_extraction/tests/test_text.py
@@ -1,4 +1,5 @@
 from __future__ import unicode_literals
+import re
 import warnings
 
 import pytest
@@ -1121,6 +1122,14 @@ def test_vectorizers_invalid_ngram_range(vec):
         ValueError, message, vec.transform, ["good news everyone"])
 
 
+def _check_stop_words_consistency(estimator):
+    stop_words = estimator.get_stop_words()
+    tokenize = estimator.build_tokenizer()
+    preprocess = estimator.build_preprocessor()
+    return estimator._check_stop_words_consistency(stop_words, preprocess,
+                                                   tokenize)
+
+
 @fails_if_pypy
 def test_vectorizer_stop_words_inconsistent():
     if PY2:
@@ -1135,11 +1144,44 @@ def test_vectorizer_stop_words_inconsistent():
     vec.set_params(stop_words=["you've", "you", "you'll", 'AND'])
     assert_warns_message(UserWarning, message, vec.fit_transform,
                          ['hello world'])
+    # reset stop word validation
+    del vec._stop_words_id
+    assert _check_stop_words_consistency(vec) is False
 
     # Only one warning per stop list
     assert_no_warnings(vec.fit_transform, ['hello world'])
+    assert _check_stop_words_consistency(vec) is None
 
     # Test caching of inconsistency assessment
     vec.set_params(stop_words=["you've", "you", "you'll", 'blah', 'AND'])
     assert_warns_message(UserWarning, message, vec.fit_transform,
                          ['hello world'])
+
+
+@fails_if_pypy
+@pytest.mark.parametrize('Estimator',
+                         [CountVectorizer, TfidfVectorizer, HashingVectorizer])
+def test_stop_word_validation_custom_preprocessor(Estimator):
+    data = [{'text': 'some text'}]
+
+    vec = Estimator()
+    assert _check_stop_words_consistency(vec) is True
+
+    vec = Estimator(preprocessor=lambda x: x['text'],
+                    stop_words=['and'])
+    assert _check_stop_words_consistency(vec) == 'error'
+    # checks are cached
+    assert _check_stop_words_consistency(vec) is None
+    vec.fit_transform(data)
+
+    class CustomEstimator(Estimator):
+        def build_preprocessor(self):
+            return lambda x: x['text']
+
+    vec = CustomEstimator(stop_words=['and'])
+    assert _check_stop_words_consistency(vec) == 'error'
+
+    vec = Estimator(tokenizer=lambda doc: re.compile(r'\w{1,}')
+                    .findall(doc),
+                    stop_words=['and'])
+    assert _check_stop_words_consistency(vec) is True
diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py
index 9d8c9bc2aa8b5..6120c1d4a8f3a 100644
--- a/sklearn/feature_extraction/text.py
+++ b/sklearn/feature_extraction/text.py
@@ -270,8 +270,22 @@ def get_stop_words(self):
         return _check_stop_list(self.stop_words)
 
     def _check_stop_words_consistency(self, stop_words, preprocess, tokenize):
+        """Check if stop words are consistent
+
+        Returns
+        -------
+        is_consistent : True if stop words are consistent with the preprocessor
+             and tokenizer, False if they are not, None if the check
+             was previously performed, "error" if it could not be
+             performed (e.g. because of the use of a custom
+             preprocessor / tokenizer)
+        """
+        if id(self.stop_words) == getattr(self, '_stop_words_id', None):
+            # Stop words were previously validated
+            return None
+
         # NB: stop_words is validated, unlike self.stop_words
-        if id(self.stop_words) != getattr(self, '_stop_words_id', None):
+        try:
             inconsistent = set()
             for w in stop_words or ():
                 tokens = list(tokenize(preprocess(w)))
@@ -281,10 +295,16 @@ def _check_stop_words_consistency(self, stop_words, preprocess, tokenize):
             self._stop_words_id = id(self.stop_words)
 
             if inconsistent:
-                warnings.warn('Your stop_words may be inconsistent with your '
-                              'preprocessing. Tokenizing the stop words '
-                              'generated tokens %r not in stop_words.' %
-                              sorted(inconsistent))
+                warnings.warn('Your stop_words may be inconsistent with '
+                              'your preprocessing. Tokenizing the stop '
+                              'words generated tokens %r not in '
+                              'stop_words.' % sorted(inconsistent))
+            return not inconsistent
+        except Exception:
+            # Failed to check stop words consistency (e.g. because a custom
+            # preprocessor or tokenizer was used)
+            self._stop_words_id = id(self.stop_words)
+            return 'error'
 
     def build_analyzer(self):
         """Return a callable that handles preprocessing and tokenization"""
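
For reference, a minimal sketch (not part of the patch) of the user-facing behaviour this change restores. It assumes a scikit-learn build that includes the fix; the sample documents and stop lists below are made up for illustration:

    import warnings

    from sklearn.feature_extraction.text import CountVectorizer

    # Before the fix, fitting raised an exception: stop-word validation
    # applied the custom preprocessor to each stop word ('and' here), and
    # 'and'['text'] raises TypeError. With the fix, the check catches the
    # exception, records 'error' internally, and fitting proceeds; the stop
    # words are still applied when building the vocabulary.
    docs = [{'text': 'some text and more text'}]
    vec = CountVectorizer(preprocessor=lambda x: x['text'],
                          stop_words=['and'])
    X = vec.fit_transform(docs)
    print(X.shape)  # (1, 3): 'and' is excluded from the vocabulary

    # The consistency warning itself is unchanged for default preprocessing:
    # tokenizing the stop list yields tokens ('and', 've', 'you') that are
    # not themselves in the list, so one UserWarning is emitted per distinct
    # stop-word list.
    vec = CountVectorizer(stop_words=["you've", 'AND'])
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        vec.fit_transform(['hello world'])
    print([str(w.message) for w in caught])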