diff --git a/sklearn/feature_extraction/tests/test_text.py b/sklearn/feature_extraction/tests/test_text.py
index c20c849689f02..0b3aa8f5ea7fb 100644
--- a/sklearn/feature_extraction/tests/test_text.py
+++ b/sklearn/feature_extraction/tests/test_text.py
@@ -18,6 +18,7 @@
 from sklearn.pipeline import Pipeline
 from sklearn.svm import LinearSVC
+from sklearn.base import clone
 
 import numpy as np
 from nose import SkipTest
@@ -283,7 +284,8 @@ def test_countvectorizer_stop_words():
 
 def test_countvectorizer_empty_vocabulary():
     try:
-        CountVectorizer(vocabulary=[])
+        vect = CountVectorizer(vocabulary=[])
+        vect.fit(["foo"])
         assert False, "we shouldn't get here"
     except ValueError as e:
         assert_in("empty vocabulary", str(e).lower())
@@ -440,10 +442,10 @@ def test_vectorizer():
 
     # (equivalent to term count vectorizer + tfidf transformer)
     train_data = iter(ALL_FOOD_DOCS[:-1])
     tv = TfidfVectorizer(norm='l1')
-    assert_false(tv.fixed_vocabulary)
     tv.max_df = v1.max_df
     tfidf2 = tv.fit_transform(train_data).toarray()
+    assert_false(tv.fixed_vocabulary_)
     assert_array_almost_equal(tfidf, tfidf2)
 
     # test the direct tfidf vectorizer with new data
@@ -767,7 +769,7 @@ def test_vectorizer_pipeline_grid_selection():
     best_vectorizer = grid_search.best_estimator_.named_steps['vect']
     assert_equal(best_vectorizer.ngram_range, (1, 1))
     assert_equal(best_vectorizer.norm, 'l2')
-    assert_false(best_vectorizer.fixed_vocabulary)
+    assert_false(best_vectorizer.fixed_vocabulary_)
 
 
 def test_vectorizer_pipeline_cross_validation():
@@ -777,7 +779,6 @@
 
     # label junk food as -1, the others as +1
     target = [-1] * len(JUNK_FOOD_DOCS) + [1] * len(NOTJUNK_FOOD_DOCS)
-
     pipeline = Pipeline([('vect', TfidfVectorizer()),
                          ('svc', LinearSVC())])
 
@@ -824,11 +825,10 @@ def test_tfidf_vectorizer_with_fixed_vocabulary():
     # non regression smoke test for inheritance issues
     vocabulary = ['pizza', 'celeri']
     vect = TfidfVectorizer(vocabulary=vocabulary)
-    assert_true(vect.fixed_vocabulary)
     X_1 = vect.fit_transform(ALL_FOOD_DOCS)
     X_2 = vect.transform(ALL_FOOD_DOCS)
     assert_array_almost_equal(X_1.toarray(), X_2.toarray())
-    assert_true(vect.fixed_vocabulary)
+    assert_true(vect.fixed_vocabulary_)
 
 
 def test_pickling_vectorizer():
@@ -870,7 +870,8 @@ def test_pickling_transformer():
 
 def test_non_unique_vocab():
     vocab = ['a', 'b', 'c', 'a', 'a']
-    assert_raises(ValueError, CountVectorizer, vocabulary=vocab)
+    vect = CountVectorizer(vocabulary=vocab)
+    assert_raises(ValueError, vect.fit, [])
 
 
 def test_hashingvectorizer_nan_in_docs():
@@ -901,3 +902,11 @@ def test_tfidfvectorizer_export_idf():
     vect = TfidfVectorizer(use_idf=True)
     vect.fit(JUNK_FOOD_DOCS)
     assert_array_almost_equal(vect.idf_, vect._tfidf.idf_)
+
+
+def test_vectorizer_vocab_clone():
+    vect_vocab = TfidfVectorizer(vocabulary=["the"])
+    vect_vocab_clone = clone(vect_vocab)
+    vect_vocab.fit(ALL_FOOD_DOCS)
+    vect_vocab_clone.fit(ALL_FOOD_DOCS)
+    assert_equal(vect_vocab_clone.vocabulary_, vect_vocab.vocabulary_)
diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py
index d5590f0604067..bce99f99eaf59 100644
--- a/sklearn/feature_extraction/text.py
+++ b/sklearn/feature_extraction/text.py
@@ -28,7 +28,8 @@
 from ..preprocessing import normalize
 from .hashing import FeatureHasher
 from .stop_words import ENGLISH_STOP_WORDS
-from sklearn.externals import six
+from ..utils import deprecated
+from ..externals import six
 
 __all__ = ['CountVectorizer',
            'ENGLISH_STOP_WORDS',
@@ -236,6 +237,38 @@ def build_analyzer(self):
         raise ValueError('%s is not a valid tokenization scheme/analyzer' %
                          self.analyzer)
 
+    def _check_vocabulary(self):
+        vocabulary = self.vocabulary
+        if vocabulary is not None:
+            if not isinstance(vocabulary, Mapping):
+                vocab = {}
+                for i, t in enumerate(vocabulary):
+                    if vocab.setdefault(t, i) != i:
+                        msg = "Duplicate term in vocabulary: %r" % t
+                        raise ValueError(msg)
+                vocabulary = vocab
+            else:
+                indices = set(six.itervalues(vocabulary))
+                if len(indices) != len(vocabulary):
+                    raise ValueError("Vocabulary contains repeated indices.")
+                for i in xrange(len(vocabulary)):
+                    if i not in indices:
+                        msg = ("Vocabulary of size %d doesn't contain index "
+                               "%d." % (len(vocabulary), i))
+                        raise ValueError(msg)
+            if not vocabulary:
+                raise ValueError("empty vocabulary passed to fit")
+            self.fixed_vocabulary_ = True
+            self.vocabulary_ = dict(vocabulary)
+        else:
+            self.fixed_vocabulary_ = False
+
+    @property
+    @deprecated("The `fixed_vocabulary` attribute is deprecated and will be "
+                "removed in 0.18. Please use `fixed_vocabulary_` instead.")
+    def fixed_vocabulary(self):
+        return self.fixed_vocabulary_
+
 
 class HashingVectorizer(BaseEstimator, VectorizerMixin):
     """Convert a collection of text documents to a matrix of token occurrences
@@ -616,29 +649,7 @@ def __init__(self, input='content', encoding='utf-8',
                 raise ValueError(
                     "max_features=%r, neither a positive integer nor None"
                     % max_features)
         self.ngram_range = ngram_range
-        if vocabulary is not None:
-            if not isinstance(vocabulary, Mapping):
-                vocab = {}
-                for i, t in enumerate(vocabulary):
-                    if vocab.setdefault(t, i) != i:
-                        msg = "Duplicate term in vocabulary: %r" % t
-                        raise ValueError(msg)
-                vocabulary = vocab
-            else:
-                indices = set(six.itervalues(vocabulary))
-                if len(indices) != len(vocabulary):
-                    raise ValueError("Vocabulary contains repeated indices.")
-                for i in xrange(len(vocabulary)):
-                    if i not in indices:
-                        msg = ("Vocabulary of size %d doesn't contain index "
-                               "%d." % (len(vocabulary), i))
-                        raise ValueError(msg)
-            if not vocabulary:
-                raise ValueError("empty vocabulary passed to fit")
-            self.fixed_vocabulary = True
-            self.vocabulary_ = dict(vocabulary)
-        else:
-            self.fixed_vocabulary = False
+        self.vocabulary = vocabulary
         self.binary = binary
         self.dtype = dtype
@@ -773,16 +784,18 @@ def fit_transform(self, raw_documents, y=None):
         # We intentionally don't call the transform method to make
         # fit_transform overridable without unwanted side effects in
         # TfidfVectorizer.
+        self._check_vocabulary()
         max_df = self.max_df
         min_df = self.min_df
         max_features = self.max_features
 
-        vocabulary, X = self._count_vocab(raw_documents, self.fixed_vocabulary)
+        vocabulary, X = self._count_vocab(raw_documents,
+                                          self.fixed_vocabulary_)
 
         if self.binary:
             X.data.fill(1)
 
-        if not self.fixed_vocabulary:
+        if not self.fixed_vocabulary_:
             X = self._sort_features(X, vocabulary)
 
         n_doc = X.shape[0]
@@ -820,6 +833,9 @@ def transform(self, raw_documents):
         X : sparse matrix, [n_samples, n_features]
             Document-term matrix.
         """
+        if not hasattr(self, 'vocabulary_'):
+            self._check_vocabulary()
+
         if not hasattr(self, 'vocabulary_') or len(self.vocabulary_) == 0:
             raise ValueError("Vocabulary wasn't fitted or is empty!")
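
Note for reviewers: a minimal usage sketch of the behaviour this patch enables, not part
of the patch itself. It assumes the changes above are applied; the toy documents and
vocabulary below are made up for illustration.

    from sklearn.base import clone
    from sklearn.feature_extraction.text import CountVectorizer

    docs = ["the pizza", "the celeri"]

    # __init__ now only stores the `vocabulary` parameter, so clone() can
    # rebuild the estimator from get_params() without losing it.
    vect = CountVectorizer(vocabulary=["pizza", "celeri"])
    vect_clone = clone(vect)

    # Validation moved from __init__ into _check_vocabulary(), which runs at
    # fit/transform time, so an invalid vocabulary only raises once data is seen.
    try:
        CountVectorizer(vocabulary=[]).fit(docs)  # previously raised in __init__
    except ValueError as e:
        assert "empty vocabulary" in str(e).lower()

    # The original and the clone fit to the same fixed vocabulary. The trailing
    # underscore marks `fixed_vocabulary_` as set during fit, per the scikit-learn
    # convention for fitted attributes; the old `fixed_vocabulary` name survives
    # as a deprecated property until 0.18.
    vect.fit(docs)
    vect_clone.fit(docs)
    assert vect.vocabulary_ == vect_clone.vocabulary_
    assert vect.fixed_vocabulary_ and vect_clone.fixed_vocabulary_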