diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 444895245bf6b..f02e91782e07d 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -1208,9 +1208,10 @@ Model validation :template: class.rst naive_bayes.BernoulliNB + naive_bayes.CategoricalNB + naive_bayes.ComplementNB naive_bayes.GaussianNB naive_bayes.MultinomialNB - naive_bayes.ComplementNB .. _neighbors_ref: diff --git a/doc/modules/naive_bayes.rst b/doc/modules/naive_bayes.rst index 1ba870c3b8bfc..457ec6c630b99 100644 --- a/doc/modules/naive_bayes.rst +++ b/doc/modules/naive_bayes.rst @@ -224,6 +224,40 @@ It is advisable to evaluate both models, if time permits. `_ 3rd Conf. on Email and Anti-Spam (CEAS). +.. _categorical_naive_bayes: + +Categorical Naive Bayes +----------------------- + +:class:`CategoricalNB` implements the categorical naive Bayes +algorithm for categorically distributed data. It assumes that each feature, +which is described by the index :math:`i`, has its own categorical +distribution. + +:class:`CategoricalNB` estimates, for each feature :math:`i` of the training +set :math:`X`, a categorical distribution conditioned on the class :math:`y`. +The index set of the samples is defined as +:math:`J = \{ 1, \dots, m \}`, with :math:`m` as the number of samples. + +The probability of category :math:`t` in feature :math:`i` given class +:math:`c` is estimated as: + +.. math:: + + P(x_i = t \mid y = c \: ;\, \alpha) = \frac{ N_{tic} + \alpha}{N_{c} + + \alpha n_i}, + +where :math:`N_{tic} = |\{j \in J \mid x_{ij} = t, y_j = c\}|` is the number +of times category :math:`t` appears in the samples :math:`x_{i}`, which belong +to class :math:`c`, :math:`N_{c} = |\{ j \in J\mid y_j = c\}|` is the number +of samples with class :math:`c`, :math:`\alpha` is a smoothing parameter and +:math:`n_i` is the number of available categories of feature :math:`i`. + +:class:`CategoricalNB` assumes that the sample matrix :math:`X` is encoded +(for instance with the help of :class:`OrdinalEncoder`) such that all +categories for each feature :math:`i` are represented with numbers +:math:`0, \dots, n_i - 1` where :math:`n_i` is the number of available categories +of feature :math:`i`. Out-of-core naive Bayes model fitting ------------------------------------- diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index b3b301d6dece4..84cf954fbdffc 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -419,6 +419,14 @@ Changelog - |Fix| :class:`multioutput.MultiOutputClassifier` now has attribute ``classes_``. :pr:`14629` by :user:`Agamemnon Krasoulis `. +:mod:`sklearn.naive_bayes` .......................... + +- |MajorFeature| Added :class:`naive_bayes.CategoricalNB` that implements the + Categorical Naive Bayes classifier. + :pr:`12569` by :user:`Tim Bicker ` and + :user:`Florian Wilhelm `. + :mod:`sklearn.neighbors` ........................
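Not part of the patch: a minimal sketch of the workflow the new user-guide section describes, assuming this branch is installed. The colour/size toy data and the ``enc``/``clf`` names are illustrative only; ``OrdinalEncoder`` maps each feature's categories to ``0, ..., n_i - 1`` before fitting, as the text recommends::

    >>> import numpy as np
    >>> from sklearn.preprocessing import OrdinalEncoder
    >>> from sklearn.naive_bayes import CategoricalNB
    >>> X_raw = np.array([['red', 'small'], ['red', 'large'],
    ...                   ['blue', 'small'], ['blue', 'large']])
    >>> y = np.array([0, 0, 1, 1])
    >>> enc = OrdinalEncoder()
    >>> X = enc.fit_transform(X_raw)   # categories become 0, ..., n_i - 1
    >>> clf = CategoricalNB(alpha=1.0).fit(X, y)
    >>> clf.predict(enc.transform([['red', 'small']]))
    array([0])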
diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index d1bb360986c22..6b5c6fdcacf3a 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -19,6 +19,7 @@ from abc import ABCMeta, abstractmethod + import numpy as np from scipy.sparse import issparse @@ -32,7 +33,8 @@ from .utils.multiclass import _check_partial_fit_first_call from .utils.validation import check_is_fitted, check_non_negative -__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB', 'ComplementNB'] +__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB', 'ComplementNB', + 'CategoricalNB'] class BaseNB(ClassifierMixin, BaseEstimator, metaclass=ABCMeta): @@ -49,6 +51,12 @@ def _joint_log_likelihood(self, X): predict_proba and predict_log_proba. """ + @abstractmethod + def _check_X(self, X): + """Validate input X + """ + pass + def predict(self, X): """ Perform classification on an array of test vectors X. @@ -62,6 +70,8 @@ def predict(self, X): C : array, shape = [n_samples] Predicted target values for X """ + check_is_fitted(self) + X = self._check_X(X) jll = self._joint_log_likelihood(X) return self.classes_[np.argmax(jll, axis=1)] @@ -80,6 +90,8 @@ def predict_log_proba(self, X): the model. The columns correspond to the classes in sorted order, as they appear in the attribute :term:`classes_`. """ + check_is_fitted(self) + X = self._check_X(X) jll = self._joint_log_likelihood(X) # normalize by P(x) = P(f_1, ..., f_n) log_prob_x = logsumexp(jll, axis=1) @@ -192,10 +204,12 @@ def fit(self, X, y, sample_weight=None): ------- self : object """ - X, y = check_X_y(X, y) return self._partial_fit(X, y, np.unique(y), _refit=True, sample_weight=sample_weight) + def _check_X(self, X): + return check_array(X) + @staticmethod def _update_mean_variance(n_past, mu, var, X, sample_weight=None): """Compute online update of Gaussian mean and variance. @@ -431,9 +445,6 @@ def _partial_fit(self, X, y, classes=None, _refit=False, return self def _joint_log_likelihood(self, X): - check_is_fitted(self) - - X = check_array(X) joint_log_likelihood = [] for i in range(np.size(self.classes_)): jointi = np.log(self.class_prior_[i]) @@ -458,6 +469,12 @@ class BaseDiscreteNB(BaseNB): _joint_log_likelihood(X) as per BaseNB """ + def _check_X(self, X): + return check_array(X, accept_sparse='csr') + + def _check_X_y(self, X, y): + return check_X_y(X, y, accept_sparse='csr') + def _update_class_log_prior(self, class_prior=None): n_classes = len(self.classes_) if class_prior is not None: @@ -483,7 +500,7 @@ def _check_alpha(self): raise ValueError('Smoothing parameter alpha = %.1e. ' 'alpha should be > 0.' 
% np.min(self.alpha)) if isinstance(self.alpha, np.ndarray): - if not self.alpha.shape[0] == self.feature_count_.shape[1]: + if not self.alpha.shape[0] == self.n_features_: raise ValueError("alpha should be a scalar or a numpy array " "with shape [n_features]") if np.min(self.alpha) < _ALPHA_MIN: @@ -528,19 +545,18 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): ------- self : object """ - X = check_array(X, accept_sparse='csr', dtype=np.float64) + X, y = self._check_X_y(X, y) _, n_features = X.shape if _check_partial_fit_first_call(self, classes): # This is the first call to partial_fit: # initialize various cumulative counters n_effective_classes = len(classes) if len(classes) > 1 else 2 - self.class_count_ = np.zeros(n_effective_classes, dtype=np.float64) - self.feature_count_ = np.zeros((n_effective_classes, n_features), - dtype=np.float64) - elif n_features != self.coef_.shape[1]: + self._init_counters(n_effective_classes, n_features) + self.n_features_ = n_features + elif n_features != self.n_features_: msg = "Number of features %d does not match previous data %d." - raise ValueError(msg % (n_features, self.coef_.shape[-1])) + raise ValueError(msg % (n_features, self.n_features_)) Y = label_binarize(y, classes=self.classes_) if Y.shape[1] == 1: @@ -591,8 +607,9 @@ def fit(self, X, y, sample_weight=None): ------- self : object """ - X, y = check_X_y(X, y, 'csr') + X, y = self._check_X_y(X, y) _, n_features = X.shape + self.n_features_ = n_features labelbin = LabelBinarizer() Y = labelbin.fit_transform(y) @@ -603,8 +620,8 @@ def fit(self, X, y, sample_weight=None): # LabelBinarizer().fit_transform() returns arrays with dtype=np.int64. # We convert it to np.float64 to support sample_weight consistently; # this means we also don't have to cast X to floating point - Y = Y.astype(np.float64, copy=False) if sample_weight is not None: + Y = Y.astype(np.float64, copy=False) sample_weight = np.atleast_2d(sample_weight) Y *= check_array(sample_weight).T @@ -613,15 +630,19 @@ def fit(self, X, y, sample_weight=None): # Count raw events from data before updating the class log prior # and feature log probas n_effective_classes = Y.shape[1] - self.class_count_ = np.zeros(n_effective_classes, dtype=np.float64) - self.feature_count_ = np.zeros((n_effective_classes, n_features), - dtype=np.float64) + + self._init_counters(n_effective_classes, n_features) self._count(X, Y) alpha = self._check_alpha() self._update_feature_log_prob(alpha) self._update_class_log_prior(class_prior=class_prior) return self + def _init_counters(self, n_effective_classes, n_features): + self.class_count_ = np.zeros(n_effective_classes, dtype=np.float64) + self.feature_count_ = np.zeros((n_effective_classes, n_features), + dtype=np.float64) + # XXX The following is a stopgap measure; we need to set the dimensions # of class_log_prior_ and feature_log_prob_ correctly. def _get_coef(self): @@ -693,13 +714,17 @@ class MultinomialNB(BaseDiscreteNB): during fitting. This value is weighted by the sample weight when provided. + n_features_ : int + Number of features of each sample. + classes_ : array-like, shape (n_classes,) Unique class labels. 
Examples -------- >>> import numpy as np - >>> X = np.random.randint(5, size=(6, 100)) + >>> rng = np.random.RandomState(1) + >>> X = rng.randint(5, size=(6, 100)) >>> y = np.array([1, 2, 3, 4, 5, 6]) >>> from sklearn.naive_bayes import MultinomialNB >>> clf = MultinomialNB() @@ -745,9 +770,6 @@ def _update_feature_log_prob(self, alpha): def _joint_log_likelihood(self, X): """Calculate the posterior log probability of the samples X""" - check_is_fitted(self) - - X = check_array(X, accept_sparse='csr') return (safe_sparse_dot(X, self.feature_log_prob_.T) + self.class_log_prior_) @@ -798,6 +820,9 @@ class ComplementNB(BaseDiscreteNB): Number of samples encountered for each (class, feature) during fitting. This value is weighted by the sample weight when provided. + n_features_ : int + Number of features of each sample. + feature_all_ : array, shape (n_features,) Number of samples encountered for each feature during fitting. This value is weighted by the sample weight when provided. @@ -808,7 +833,8 @@ class ComplementNB(BaseDiscreteNB): Examples -------- >>> import numpy as np - >>> X = np.random.randint(5, size=(6, 100)) + >>> rng = np.random.RandomState(1) + >>> X = rng.randint(5, size=(6, 100)) >>> y = np.array([1, 2, 3, 4, 5, 6]) >>> from sklearn.naive_bayes import ComplementNB >>> clf = ComplementNB() @@ -856,9 +882,6 @@ def _update_feature_log_prob(self, alpha): def _joint_log_likelihood(self, X): """Calculate the class scores for the samples in X.""" - check_is_fitted(self) - - X = check_array(X, accept_sparse="csr") jll = safe_sparse_dot(X, self.feature_log_prob_.T) if len(self.classes_) == 1: jll += self.class_log_prior_ @@ -912,14 +935,17 @@ class BernoulliNB(BaseDiscreteNB): during fitting. This value is weighted by the sample weight when provided. + n_features_ : int + Number of features of each sample. + classes_ : array of shape = [n_classes] The classes labels. 
- Examples -------- >>> import numpy as np - >>> X = np.random.randint(2, size=(6, 100)) + >>> rng = np.random.RandomState(1) + >>> X = rng.randint(5, size=(6, 100)) >>> Y = np.array([1, 2, 3, 4, 4, 5]) >>> from sklearn.naive_bayes import BernoulliNB >>> clf = BernoulliNB() @@ -950,10 +976,20 @@ def __init__(self, alpha=1.0, binarize=.0, fit_prior=True, self.fit_prior = fit_prior self.class_prior = class_prior - def _count(self, X, Y): - """Count and smooth feature occurrences.""" + def _check_X(self, X): + X = super()._check_X(X) if self.binarize is not None: X = binarize(X, threshold=self.binarize) + return X + + def _check_X_y(self, X, y): + X, y = super()._check_X_y(X, y) + if self.binarize is not None: + X = binarize(X, threshold=self.binarize) + return X, y + + def _count(self, X, Y): + """Count and smooth feature occurrences.""" self.feature_count_ += safe_sparse_dot(Y.T, X) self.class_count_ += Y.sum(axis=0) @@ -967,13 +1003,6 @@ def _update_feature_log_prob(self, alpha): def _joint_log_likelihood(self, X): """Calculate the posterior log probability of the samples X""" - check_is_fitted(self) - - X = check_array(X, accept_sparse='csr') - - if self.binarize is not None: - X = binarize(X, threshold=self.binarize) - n_classes, n_features = self.feature_log_prob_.shape n_samples, n_features_X = X.shape @@ -987,3 +1016,212 @@ def _joint_log_likelihood(self, X): jll += self.class_log_prior_ + neg_prob.sum(axis=1) return jll + + +class CategoricalNB(BaseDiscreteNB): + """Naive Bayes classifier for categorical features + + The categorical Naive Bayes classifier is suitable for classification with + discrete features that are categorically distributed. The categories of + each feature are drawn from a categorical distribution. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + alpha : float, optional (default=1.0) + Additive (Laplace/Lidstone) smoothing parameter + (0 for no smoothing). + + fit_prior : boolean, optional (default=True) + Whether to learn class prior probabilities or not. + If false, a uniform prior will be used. + + class_prior : array-like, size (n_classes,), optional (default=None) + Prior probabilities of the classes. If specified the priors are not + adjusted according to the data. + + Attributes + ---------- + class_log_prior_ : array, shape (n_classes, ) + Smoothed empirical log probability for each class. + + feature_log_prob_ : list of arrays, len n_features + Holds arrays of shape (n_classes, n_categories of respective feature) + for each feature. Each array provides the empirical log probability + of categories given the respective feature and class, ``P(x_i|y)``. + + class_count_ : array, shape (n_classes,) + Number of samples encountered for each class during fitting. This + value is weighted by the sample weight when provided. + + category_count_ : list of arrays, len n_features + Holds arrays of shape (n_classes, n_categories of respective feature) + for each feature. Each array provides the number of samples + encountered for each class and category of the specific feature. + + n_features_ : int + Number of features of each sample. 
+ + Examples + -------- + >>> import numpy as np + >>> rng = np.random.RandomState(1) + >>> X = rng.randint(5, size=(6, 100)) + >>> y = np.array([1, 2, 3, 4, 5, 6]) + >>> from sklearn.naive_bayes import CategoricalNB + >>> clf = CategoricalNB() + >>> clf.fit(X, y) + CategoricalNB() + >>> print(clf.predict(X[2:3])) + [3] + """ + + def __init__(self, alpha=1.0, fit_prior=True, class_prior=None): + self.alpha = alpha + self.fit_prior = fit_prior + self.class_prior = class_prior + + def fit(self, X, y, sample_weight=None): + """Fit Naive Bayes classifier according to X, y + + Parameters + ---------- + X : {array-like, sparse matrix}, shape = [n_samples, n_features] + Training vectors, where n_samples is the number of samples and + n_features is the number of features. Here, each feature of X is + assumed to be from a different categorical distribution. + It is further assumed that all categories of each feature are + represented by the numbers 0, ..., n - 1, where n refers to the + total number of categories for the given feature. This can, for + instance, be achieved with the help of OrdinalEncoder. + + y : array-like, shape = [n_samples] + Target values. + + sample_weight : array-like, shape = [n_samples], (default=None) + Weights applied to individual samples (1. for unweighted). + + Returns + ------- + self : object + """ + return super().fit(X, y, sample_weight=sample_weight) + + def partial_fit(self, X, y, classes=None, sample_weight=None): + """Incremental fit on a batch of samples. + + This method is expected to be called several times consecutively + on different chunks of a dataset so as to implement out-of-core + or online learning. + + This is especially useful when the whole dataset is too big to fit in + memory at once. + + This method has some performance overhead hence it is better to call + partial_fit on chunks of data that are as large as possible + (as long as fitting in the memory budget) to hide the overhead. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape = [n_samples, n_features] + Training vectors, where n_samples is the number of samples and + n_features is the number of features. Here, each feature of X is + assumed to be from a different categorical distribution. + It is further assumed that all categories of each feature are + represented by the numbers 0, ..., n - 1, where n refers to the + total number of categories for the given feature. This can, for + instance, be achieved with the help of OrdinalEncoder. + + y : array-like, shape = [n_samples] + Target values. + + classes : array-like, shape = [n_classes] (default=None) + List of all the classes that can possibly appear in the y vector. + + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + + sample_weight : array-like, shape = [n_samples], (default=None) + Weights applied to individual samples (1. for unweighted). + + Returns + ------- + self : object + """ + return super().partial_fit(X, y, classes, + sample_weight=sample_weight) + + def _check_X(self, X): + # FIXME: we can avoid calling check_array twice after #14872 is merged. + # X = check_array(X, y, dtype='int', accept_sparse=False, + # force_all_finite=True) + X = check_array(X, accept_sparse=False, force_all_finite=True) + X = check_array(X, dtype='int') + if np.any(X < 0): + raise ValueError("X must not contain negative values.") + return X + + def _check_X_y(self, X, y): + # FIXME: we can avoid calling check_array twice after #14872 is merged. 
+ # X, y = check_array(X, y, dtype='int', accept_sparse=False, + # force_all_finite=True) + X, y = check_X_y(X, y, accept_sparse=False, force_all_finite=True) + X, y = check_X_y(X, y, dtype='int') + if np.any(X < 0): + raise ValueError("X must not contain negative values.") + return X, y + + def _init_counters(self, n_effective_classes, n_features): + self.class_count_ = np.zeros(n_effective_classes, dtype=np.float64) + self.category_count_ = [np.zeros((n_effective_classes, 0)) + for _ in range(n_features)] + + def _count(self, X, Y): + def _update_cat_count_dims(cat_count, highest_feature): + diff = highest_feature + 1 - cat_count.shape[1] + if diff > 0: + # we append a column full of zeros for each new category + return np.pad(cat_count, [(0, 0), (0, diff)], 'constant') + return cat_count + + def _update_cat_count(X_feature, Y, cat_count, n_classes): + for j in range(n_classes): + mask = Y[:, j].astype(bool) + if Y.dtype.type == np.int64: + weights = None + else: + weights = Y[mask, j] + counts = np.bincount(X_feature[mask], weights=weights) + indices = np.nonzero(counts)[0] + cat_count[j, indices] += counts[indices] + + self.class_count_ += Y.sum(axis=0) + for i in range(self.n_features_): + X_feature = X[:, i] + self.category_count_[i] = _update_cat_count_dims( + self.category_count_[i], X_feature.max()) + _update_cat_count(X_feature, Y, + self.category_count_[i], + self.class_count_.shape[0]) + + def _update_feature_log_prob(self, alpha): + feature_log_prob = [] + for i in range(self.n_features_): + smoothed_cat_count = self.category_count_[i] + alpha + smoothed_class_count = smoothed_cat_count.sum(axis=1) + feature_log_prob.append( + np.log(smoothed_cat_count) - + np.log(smoothed_class_count.reshape(-1, 1))) + self.feature_log_prob_ = feature_log_prob + + def _joint_log_likelihood(self, X): + if X.shape[1] != self.n_features_: + raise ValueError("Expected input with %d features, got %d instead" + % (self.n_features_, X.shape[1])) + jll = np.zeros((X.shape[0], self.class_count_.shape[0])) + for i in range(self.n_features_): + indices = X[:, i] + jll += self.feature_log_prob_[i][:, indices].T + total_ll = jll + self.class_log_prior_ + return total_ll diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 60770041e0dfd..0eee076bd91e5 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -20,6 +20,7 @@ from sklearn.naive_bayes import GaussianNB, BernoulliNB from sklearn.naive_bayes import MultinomialNB, ComplementNB +from sklearn.naive_bayes import CategoricalNB # Data is just 6 separable points in the plane X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]) @@ -191,7 +192,7 @@ def test_gnb_naive_bayes_scale_invariance(): assert_array_equal(labels[1], labels[2]) -@pytest.mark.parametrize("cls", [MultinomialNB, BernoulliNB]) +@pytest.mark.parametrize("cls", [MultinomialNB, BernoulliNB, CategoricalNB]) def test_discretenb_prior(cls): # Test whether class priors are properly set.
clf = cls().fit(X2, y2) @@ -199,7 +200,7 @@ def test_discretenb_prior(cls): clf.class_log_prior_, 8) -@pytest.mark.parametrize("cls", [MultinomialNB, BernoulliNB]) +@pytest.mark.parametrize("cls", [MultinomialNB, BernoulliNB, CategoricalNB]) def test_discretenb_partial_fit(cls): clf1 = cls() clf1.fit([[0, 1], [1, 0], [1, 1]], [0, 1, 1]) @@ -207,15 +208,48 @@ def test_discretenb_partial_fit(cls): clf2 = cls() clf2.partial_fit([[0, 1], [1, 0], [1, 1]], [0, 1, 1], classes=[0, 1]) assert_array_equal(clf1.class_count_, clf2.class_count_) + if cls is CategoricalNB: + for i in range(len(clf1.category_count_)): + assert_array_equal(clf1.category_count_[i], + clf2.category_count_[i]) + else: + assert_array_equal(clf1.feature_count_, clf2.feature_count_) clf3 = cls() + # all categories have to appear in the first partial fit clf3.partial_fit([[0, 1]], [0], classes=[0, 1]) clf3.partial_fit([[1, 0]], [1]) clf3.partial_fit([[1, 1]], [1]) assert_array_equal(clf1.class_count_, clf3.class_count_) - - -@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB, GaussianNB]) + if cls is CategoricalNB: + # the categories for each feature of CategoricalNB are mapped to an + # index chronologically with each call of partial fit and therefore + # the category_count matrices cannot be compared for equality + for i in range(len(clf1.category_count_)): + assert_array_equal(clf1.category_count_[i].shape, + clf3.category_count_[i].shape) + assert_array_equal(np.sum(clf1.category_count_[i], axis=1), + np.sum(clf3.category_count_[i], axis=1)) + + # assert category 0 occurs 1x in the first class and 0x in the 2nd + # class + assert_array_equal(clf1.category_count_[0][0], np.array([1, 0])) + # assert category 1 occurs 0x in the first class and 2x in the 2nd + # class + assert_array_equal(clf1.category_count_[0][1], np.array([0, 2])) + + # assert category 0 occurs 0x in the first class and 1x in the 2nd + # class + assert_array_equal(clf1.category_count_[1][0], np.array([0, 1])) + # assert category 1 occurs 1x in the first class and 1x in the 2nd + # class + assert_array_equal(clf1.category_count_[1][1], np.array([1, 1])) + else: + assert_array_equal(clf1.feature_count_, clf3.feature_count_) + + +@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB, GaussianNB, + CategoricalNB]) def test_discretenb_pickle(cls): # Test picklability of discrete naive Bayes classifiers @@ -237,7 +271,8 @@ def test_discretenb_pickle(cls): assert_array_equal(y_pred, clf2.predict(X2)) -@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB, GaussianNB]) +@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB, GaussianNB, + CategoricalNB]) def test_discretenb_input_check_fit(cls): # Test input checks for the fit method @@ -249,7 +284,7 @@ def test_discretenb_input_check_fit(cls): assert_raises(ValueError, clf.predict, X2[:, :-1]) -@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB]) +@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB, CategoricalNB]) def test_discretenb_input_check_partial_fit(cls): # check shape consistency assert_raises(ValueError, cls().partial_fit, X2, y2[:-1], @@ -302,7 +337,7 @@ def test_discretenb_predict_proba(): assert_almost_equal(np.sum(np.exp(clf.intercept_)), 1) -@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB]) +@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB, CategoricalNB]) def test_discretenb_uniform_prior(cls): # Test whether discrete NB classes fit a uniform prior # when fit_prior=False and class_prior=None @@ -314,7 +349,7 @@ def 
test_discretenb_uniform_prior(cls): assert_array_almost_equal(prior, np.array([.5, .5])) -@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB]) +@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB, CategoricalNB]) def test_discretenb_provide_prior(cls): # Test whether discrete NB classes use provided prior @@ -329,7 +364,7 @@ def test_discretenb_provide_prior(cls): classes=[0, 1, 1]) -@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB]) +@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB, CategoricalNB]) def test_discretenb_provide_prior_with_partial_fit(cls): # Test whether discrete NB classes use provided prior # when using partial_fit @@ -349,7 +384,7 @@ def test_discretenb_provide_prior_with_partial_fit(cls): clf_partial.class_log_prior_) -@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB]) +@pytest.mark.parametrize('cls', [BernoulliNB, MultinomialNB, CategoricalNB]) def test_discretenb_sample_weight_multiclass(cls): # check shape consistency for number of samples at fit time X = [ @@ -611,6 +646,52 @@ def test_cnb(): assert_array_almost_equal(clf.feature_log_prob_, normed_weights) +def test_categoricalnb(): + # Check the ability to predict the training set. + clf = CategoricalNB() + y_pred = clf.fit(X2, y2).predict(X2) + assert_array_equal(y_pred, y2) + + X3 = np.array([[1, 4], [2, 5]]) + y3 = np.array([1, 2]) + clf = CategoricalNB(alpha=1, fit_prior=False) + + clf.fit(X3, y3) + + # Check error is raised for X with negative entries + X = np.array([[0, -1]]) + y = np.array([1]) + error_msg = "X must not contain negative values." + assert_raise_message(ValueError, error_msg, clf.predict, X) + assert_raise_message(ValueError, error_msg, clf.fit, X, y) + + # Test alpha + X3_test = np.array([[2, 5]]) + # alpha=1 increases the count of all categories by one so the final + # probability for each category is not 50/50 but 1/3 to 2/3 + bayes_numerator = np.array([[1/3*1/3, 2/3*2/3]]) + bayes_denominator = bayes_numerator.sum() + assert_array_almost_equal(clf.predict_proba(X3_test), + bayes_numerator / bayes_denominator) + + # Assert category_count has counted all features + assert len(clf.category_count_) == X3.shape[1] + + # Check sample_weight + X = np.array([[0, 0], [0, 1], [0, 0], [1, 1]]) + y = np.array([1, 1, 2, 2]) + clf = CategoricalNB(alpha=1, fit_prior=False) + clf.fit(X, y) + assert_array_equal(clf.predict(np.array([[0, 0]])), np.array([1])) + + for factor in [1., 0.3, 5, 0.0001]: + X = np.array([[0, 0], [0, 1], [0, 0], [1, 1]]) + y = np.array([1, 1, 2, 2]) + sample_weight = np.array([1, 1, 10, 0.1]) * factor + clf = CategoricalNB(alpha=1, fit_prior=False) + clf.fit(X, y, sample_weight=sample_weight) + assert_array_equal(clf.predict(np.array([[0, 0]])), np.array([2])) + def test_alpha(): # Setting alpha=0 should not output nan results when p(x_i|y_j)=0 is a case @@ -628,6 +709,11 @@ def test_alpha(): prob = np.array([[2. / 3, 1. / 3], [0, 1]]) assert_array_almost_equal(nb.predict_proba(X), prob) + nb = CategoricalNB(alpha=0.) + assert_warns(UserWarning, nb.fit, X, y) + prob = np.array([[1., 0.], [0., 1.]]) + assert_array_almost_equal(nb.predict_proba(X), prob) + # Test sparse X X = scipy.sparse.csr_matrix(X) nb = BernoulliNB(alpha=0.) 
@@ -647,8 +733,10 @@ def test_alpha(): 'alpha should be > 0.') b_nb = BernoulliNB(alpha=-0.1) m_nb = MultinomialNB(alpha=-0.1) + c_nb = CategoricalNB(alpha=-0.1) assert_raise_message(ValueError, expected_msg, b_nb.fit, X, y) assert_raise_message(ValueError, expected_msg, m_nb.fit, X, y) + assert_raise_message(ValueError, expected_msg, c_nb.fit, X, y) b_nb = BernoulliNB(alpha=-0.1) m_nb = MultinomialNB(alpha=-0.1) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 5a96a4260ceb9..40ccceb6a24c7 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1587,7 +1587,8 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False): y_b = y_m[y_m != 2] X_b = X_m[y_m != 2] - if name in ['BernoulliNB', 'MultinomialNB', 'ComplementNB']: + if name in ['BernoulliNB', 'MultinomialNB', 'ComplementNB', + 'CategoricalNB']: X_m -= X_m.min() X_b -= X_b.min() diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 465acf48e8293..dc921989cf6ca 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -538,6 +538,7 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True, if not allow_nd and array.ndim >= 3: raise ValueError("Found array with dim %d. %s expected <= 2." % (array.ndim, estimator_name)) + if force_all_finite: _assert_all_finite(array, allow_nan=force_all_finite == 'allow-nan')
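Also not part of the patch: a small standalone check, under the same assumption that this branch is installed, which re-derives the smoothed per-feature log probabilities produced by ``_count`` and ``_update_feature_log_prob`` above from the public ``category_count_`` attribute. The toy array mirrors the one used in ``test_discretenb_partial_fit`` and is illustrative only::

    >>> import numpy as np
    >>> from sklearn.naive_bayes import CategoricalNB
    >>> X = np.array([[0, 1], [1, 0], [1, 1]])
    >>> y = np.array([0, 1, 1])
    >>> clf = CategoricalNB(alpha=1.0).fit(X, y)
    >>> # per class, how often each category of feature 0 occurred:
    >>> # class 0 saw category 0 once, class 1 saw category 1 twice
    >>> clf.category_count_[0]
    array([[1., 0.],
           [0., 2.]])
    >>> # feature_log_prob_[i] is log((N_tic + alpha) / (N_c + alpha * n_i))
    >>> smoothed = clf.category_count_[0] + clf.alpha
    >>> np.allclose(clf.feature_log_prob_[0],
    ...             np.log(smoothed / smoothed.sum(axis=1, keepdims=True)))
    True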