Skip to content

FIX raise error for max_df and min_df greater than 1 in Vectorizer #20752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,21 @@ Changelog
:pr:`21032` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.ensemble`
...........................
.......................

- |Fix| Fixed a bug that could produce a segfault in rare cases for
:class:`ensemble.HistGradientBoostingClassifier` and
:class:`ensemble.HistGradientBoostingRegressor`.
:pr:`21130` by :user:`Christian Lorentzen <lorentzenchr>`.

:mod:`sklearn.feature_extraction`
.................................

- |Fix| Fixed a bug in :class:`feature_extraction.CountVectorizer` and
:class:`feature_extraction.TfidfVectorizer` by raising an
error when 'min_df' or 'max_df' are floating-point numbers greater than 1.
:pr:`20752` by :user:`Alek Lefebvre <AlekLefebvre>`.

:mod:`sklearn.linear_model`
...........................

Expand Down
25 changes: 25 additions & 0 deletions sklearn/feature_extraction/tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,31 @@ def test_vectorizer_min_df():
assert len(vect.stop_words_) == 5


@pytest.mark.parametrize(
    "params, err_type, message",
    (
        ({"max_df": 2.0}, ValueError, "max_df == 2.0, must be <= 1.0."),
        ({"min_df": 1.5}, ValueError, "min_df == 1.5, must be <= 1.0."),
        ({"max_df": -2}, ValueError, "max_df == -2, must be >= 0."),
        ({"min_df": -10}, ValueError, "min_df == -10, must be >= 0."),
        ({"min_df": 3, "max_df": 2.0}, ValueError, "max_df == 2.0, must be <= 1.0."),
        ({"min_df": 1.5, "max_df": 50}, ValueError, "min_df == 1.5, must be <= 1.0."),
        ({"max_features": -10}, ValueError, "max_features == -10, must be >= 0."),
        (
            {"max_features": 3.5},
            TypeError,
            "max_features must be an instance of <class 'numbers.Integral'>, not <class"
            " 'float'>",
        ),
    ),
)
def test_vectorizer_params_validation(params, err_type, message):
    """Check that out-of-range min_df/max_df/max_features raise a clear error."""
    # Construction must not raise: parameter validation is deferred to fit().
    # Keep only the statement expected to raise inside the pytest.raises
    # context so an unexpected error elsewhere cannot satisfy the check.
    test_data = ["abc", "dea", "eat"]
    vect = CountVectorizer(**params, analyzer="char")
    with pytest.raises(err_type, match=message):
        vect.fit(test_data)


# TODO: Remove in 1.2 when get_feature_names is removed.
@pytest.mark.filterwarnings("ignore::FutureWarning:sklearn")
@pytest.mark.parametrize("get_names", ["get_feature_names", "get_feature_names_out"])
Expand Down
27 changes: 18 additions & 9 deletions sklearn/feature_extraction/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..preprocessing import normalize
from ._hash import FeatureHasher
from ._stop_words import ENGLISH_STOP_WORDS
from ..utils.validation import check_is_fitted, check_array, FLOAT_DTYPES
from ..utils.validation import check_is_fitted, check_array, FLOAT_DTYPES, check_scalar
from ..utils.deprecation import deprecated
from ..utils import _IS_32BIT
from ..utils.fixes import _astype_copy_false
Expand Down Expand Up @@ -1120,15 +1120,7 @@ def __init__(
self.stop_words = stop_words
self.max_df = max_df
self.min_df = min_df
if max_df < 0 or min_df < 0:
raise ValueError("negative value for max_df or min_df")
self.max_features = max_features
if max_features is not None:
if not isinstance(max_features, numbers.Integral) or max_features <= 0:
raise ValueError(
"max_features=%r, neither a positive integer nor None"
% max_features
)
self.ngram_range = ngram_range
self.vocabulary = vocabulary
self.binary = binary
Expand Down Expand Up @@ -1265,6 +1257,23 @@ def _count_vocab(self, raw_documents, fixed_vocab):
X.sort_indices()
return vocabulary, X

def _validate_params(self):
    """Check scalar hyperparameters: max_features, min_df and max_df."""
    super()._validate_params()

    if self.max_features is not None:
        check_scalar(self.max_features, "max_features", numbers.Integral, min_val=0)

    # min_df and max_df share the same contract: a non-negative integer
    # document count, or a float proportion in the closed interval [0, 1].
    for name, value in (("min_df", self.min_df), ("max_df", self.max_df)):
        if isinstance(value, numbers.Integral):
            check_scalar(value, name, numbers.Integral, min_val=0)
        else:
            check_scalar(value, name, numbers.Real, min_val=0.0, max_val=1.0)

def fit(self, raw_documents, y=None):
"""Learn a vocabulary dictionary of all tokens in the raw documents.

Expand Down