From 75ce0dece569cf4c4651b4f3160dbd9beb5194b1 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 21 Feb 2022 04:31:14 -0500 Subject: [PATCH 001/102] Add abstract methods to _BaseDiscreteNB and minor corrections in comments --- sklearn/naive_bayes.py | 50 +++++++++++++++++++++++++++++++++++------- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 72eeb708e849d..3b0b0299257c5 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -53,8 +53,8 @@ def _joint_log_likelihood(self, X): I.e. ``log P(c) + log P(x|c)`` for all rows x of X, as an array-like of shape (n_samples, n_classes). - Input is passed to _joint_log_likelihood as-is by predict, - predict_proba and predict_log_proba. + Input is handed over to _joint_log_likelihood by predict, predict_proba + and predict_log_proba after being passed through _check_X. """ @abstractmethod @@ -140,7 +140,7 @@ class GaussianNB(_BaseNB): Parameters ---------- priors : array-like of shape (n_classes,) - Prior probabilities of the classes. If specified the priors are not + Prior probabilities of the classes. If specified, the priors are not adjusted according to the data. 
var_smoothing : float, default=1e-9 @@ -423,13 +423,13 @@ def _partial_fit(self, X, y, classes=None, _refit=False, sample_weight=None): # Take into account the priors if self.priors is not None: priors = np.asarray(self.priors) - # Check that the provide prior match the number of classes + # Check that the provided prior matches the number of classes if len(priors) != n_classes: raise ValueError("Number of priors must match number of classes.") # Check that the sum is 1 if not np.isclose(priors.sum(), 1.0): raise ValueError("The sum of the priors should be 1.") - # Check that the prior are non-negative + # Check that the priors are non-negative if (priors < 0).any(): raise ValueError("Priors must be non-negative.") self.class_prior_ = priors @@ -512,8 +512,37 @@ class _BaseDiscreteNB(_BaseNB): __init__ _joint_log_likelihood(X) as per _BaseNB + _update_feature_log_prob(alpha) + _count(X, Y) """ + @abstractmethod + def _count(self, X, Y): + """Update counts that are used to calculate probabilities. + + The counts make up a sufficient statistic extracted from the data. + Accordingly, this method is called each time ``fit`` or ``partial_fit`` + update the model. The number and composition of counts depend on a + concrete model, but ``self.class_count`` must be updated in any case. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The input samples. + Y : array-like of shape (n_samples, n_classes) + Binarized class labels.""" + + @abstractmethod + def _update_feature_log_prob(self, alpha): + """Update feature log probabilities based on counts. + + This method is called each time ``fit`` or ``partial_fit`` update the + model. + + Parameters + ---------- + alpha : smoothing parameter. 
See :meth:`_check_alpha`.""" + def _check_X(self, X): """Validate X, used only in predict* methods.""" return self._validate_data(X, accept_sparse="csr", reset=False) @@ -523,6 +552,11 @@ def _check_X_y(self, X, y, reset=True): return self._validate_data(X, y, accept_sparse="csr", reset=reset) def _update_class_log_prior(self, class_prior=None): + """Update class log priors based on `class_prior` (when provided) or class + counts. + + This method is called each time `fit` or `partial_fit` updates the model. + """ n_classes = len(self.classes_) if class_prior is not None: if len(class_prior) != n_classes: @@ -733,7 +767,7 @@ class MultinomialNB(_BaseDiscreteNB): If false, a uniform prior will be used. class_prior : array-like of shape (n_classes,), default=None - Prior probabilities of the classes. If specified the priors are not + Prior probabilities of the classes. If specified, the priors are not adjusted according to the data. Attributes @@ -988,7 +1022,7 @@ class BernoulliNB(_BaseDiscreteNB): If false, a uniform prior will be used. class_prior : array-like of shape (n_classes,), default=None - Prior probabilities of the classes. If specified the priors are not + Prior probabilities of the classes. If specified, the priors are not adjusted according to the data. Attributes @@ -1136,7 +1170,7 @@ class CategoricalNB(_BaseDiscreteNB): If false, a uniform prior will be used. class_prior : array-like of shape (n_classes,), default=None - Prior probabilities of the classes. If specified the priors are not + Prior probabilities of the classes. If specified, the priors are not adjusted according to the data. 
min_categories : int or array-like of shape (n_features,), default=None From 62ebbe0fabe8475263835ae8f8b37836d41ac4ce Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 21 Feb 2022 23:32:17 -0500 Subject: [PATCH 002/102] Implemented ColumnwiseNB and tests --- sklearn/naive_bayes.py | 546 ++++++++++++++++++++++++++++++ sklearn/tests/test_naive_bayes.py | 430 +++++++++++++++++++++++ 2 files changed, 976 insertions(+) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 3b0b0299257c5..56a349c22848e 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -22,8 +22,10 @@ import numpy as np from scipy.special import logsumexp +from joblib import Parallel from .base import BaseEstimator, ClassifierMixin +from .base import clone from .preprocessing import binarize from .preprocessing import LabelBinarizer from .preprocessing import label_binarize @@ -32,6 +34,13 @@ from .utils.multiclass import _check_partial_fit_first_call from .utils.validation import check_is_fitted, check_non_negative from .utils.validation import _check_sample_weight +from .utils.validation import column_or_1d +from .utils.metaestimators import _BaseComposition +from .utils import _safe_indexing, _get_column_indices +from .utils import _print_elapsed_time +from .utils import Bunch +from .utils.fixes import delayed +from .compose._column_transformer import _is_empty_column_selection __all__ = [ @@ -40,6 +49,7 @@ "MultinomialNB", "ComplementNB", "CategoricalNB", + "ColumnwiseNB", ] @@ -1423,3 +1433,539 @@ def _joint_log_likelihood(self, X): jll += self.feature_log_prob_[i][:, indices].T total_ll = jll + self.class_log_prior_ return total_ll + + +def _fit_one(estimator, X, y, message_clsname="", message=None, **fit_params): + """Call ``estimator.fit`` and print elapsed time message. + + See :func:`sklearn.pipeline._fit_one`. 
+ """ + with _print_elapsed_time(message_clsname, message): + return estimator.fit(X, y, **fit_params) + + +def _partial_fit_one(estimator, X, y, message_clsname="", message=None, **fit_params): + """Call ``estimator.partial_fit`` and print elapsed time message. + + See :func:`sklearn.pipeline._fit_one`. + """ + with _print_elapsed_time(message_clsname, message): + return estimator.partial_fit(X, y, **fit_params) + + +def _jll_one(estimator, X): + """Call ``estimator._joint_log_likelihood``. + + See :func:`sklearn.pipeline._transform_one`. + """ + return estimator._joint_log_likelihood(estimator._check_X(X)) + + +class ColumnwiseNB(_BaseNB, _BaseComposition): + """ + Column-wise Naive Bayes estimator. + + Parameters + ---------- + estimators : list of tuples + List of (name, estimatorNB, columns) tuples specifying the naive Bayes + estimators to be combined into a single naive Bayes meta-estimator. + + name : str + Name of the naive Bayes estimator. Like in Pipeline, FeatureUnion, + and ColumnTransformer, this allows the subestimator and its + parameters to be set using ``set_params`` and searched in grid + search. + estimatorNB : estimator + The estimator must support :term:`fit` or :term:`partial_fit`, + depending on how the meta-estimator is fitted. In addition, the + estimator must support ``_joint_log_likelihood`` method, which + takes :term:`X` of shape (n_samples, n_features) and returns a + numpy array of shape (n_samples, n_classes) containing joint + log-likelihoods, ``log P(x,c)`` for each sample point and class. + columns : str, array-like of str, int, array-like of int, \ + array-like of bool, slice or callable + Indexes the data on its second axis. Integers are interpreted as + positional columns, while strings can reference DataFrame columns + by name. A scalar string or int should be used where + ``estimatorNB`` expects X to be a 1d array-like (vector), + otherwise a 2d array will be passed to the transformer. 
+ A callable is passed the input data `X` and can return any of the + above. To select multiple columns by name or dtype, you can use + :obj:`make_column_selector`. + + priors : array-like of shape (n_classes,) or str, default=None + Prior probabilities of classes. If unspecified, the priors are + calculated as relative frequencies of classes in the training data. + If str, the priors are taken from the estimator with the given name. + + n_jobs : int, default=None + Number of jobs to run in parallel. + ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. + ``-1`` means using all processors. See :term:`Glossary <n_jobs>` + for more details. + + verbose : bool, default=False + If True, the time elapsed while fitting each estimator will be + printed as it is completed. + + Attributes + ---------- + estimators_ : list of tuples + List of ``(name, fitted_estimatorNB, columns)`` tuples, which follow + the order of `estimators`. ``fitted_estimatorNB`` is a fitted naive + Bayes estimator, except when ``columns`` presents an empty selection of + columns, in which case it is the original unfitted ``estimatorNB``. + Here ``columns`` is converted to a list of column indices, if the + original specification in `estimators` was a callable. + + named_estimators_ : :class:`~sklearn.utils.Bunch` + Read-only attribute to access any subestimator by given name. + Keys are estimator names and values are the fitted estimators, except + when a subestimator does not require fitting (i.e., when ``columns`` is + an empty set of indices). + + class_prior_ : ndarray of shape (n_classes,) + Prior probabilities of classes used in the naive Bayes meta-estimator, + which are calculated as relative frequencies, extracted from + subestimators, or provided, according to the value of `priors` + at initialization. + + class_count_ : ndarray of shape (n_classes,) + Number of samples encountered for each class during fitting. This + value is weighted by the sample weight when provided. 
+ + n_classes_ : int + The number of classes known to the naive Bayes classifier, `n_classes`. + + classes_ : ndarray of shape (n_classes,) + Class labels known to the classifier. + + feature_names_in_ : ndarray of shape (n_features_in_,) + Names of features seen during :term:`fit`. Only defined if `X` has + feature names that are all strings. + + Notes + ----- + ColumnwiseNB combines multiple naive Bayes estimators by expressing the + overall joint probability ``P(x,y)`` through ``P(x_i,y)``, the joint + probabilities of the subestimators: + ``Log P(x,y) = Log P(x_1,y) + ... + Log P(x_N,y) - (N - 1) Log P(y)``, + where ``N`` denotes ``n_estimators``, the number of estimators. + It is implicitly assumed that the class log priors are finite and agree + between the meta-estimator and the subestimators: + ``- inf < Log P(y) = Log P(y|1) = ... = Log P(y|N)``. + The meta-estimator does not check if this condition holds. Meaningless + results, including ``NaN``, may be produced by ColumnwiseNB if the class + priors differ or contain a zero probability. + """ + def _log_message(self, name, idx, total): + if not self.verbose: + return None + return "(%d of %d) Processing %s" % (idx, total, name) + + def __init__(self, estimators, priors=None, n_jobs=None, verbose=False): + self.estimators = estimators + self.priors = priors + self.n_jobs = n_jobs + self.verbose = verbose + + def _check_X(self, X): + """Validate X, used only in predict* methods.""" + # The meta-estimator checks for feature names only. Other checks + # and conversion to numpy array are performed by subestimators. + # It is important that X is not modified by the meta-estimator, and X's + # columns are passed to an estimator as they are. Note that estimators + # may modify (a copy of) X. E.g., BernoulliNB._check_X binarises the + # input. + self._check_feature_names(X, reset=False) + return X + + def _joint_log_likelihood(self, X): + """Calculate the meta-estimator's joint log likelihood ``log P(x,c)``. 
+ """ + # Because data must follow the same path as it would in subestimators, + # _jll_one(estimatorNB, X) passes it through estimatorNB._check_X to + # match the implementation of _BaseNB.predict_log_proba. + # Changes therein must be reflected in _jll_one or here. + estimators = self._iter(fitted=True, replace_strings=True) + all_jlls = Parallel(n_jobs=self.n_jobs)( + delayed(_jll_one)( + estimator=estimatorNB, + X=_safe_indexing(X, cols, axis=1) + ) + for (_, estimatorNB, cols) in estimators + ) + n_estimators = len(all_jlls) + log_prior = np.log(self.class_prior_) + return np.sum(all_jlls, axis=0) - (n_estimators - 1) * log_prior + + def _validate_estimators(self, check_partial=False): + # Check if estimators have fit/partial_fit and jll methods + # Validate estimator names via _BaseComposition._validate_names(self, names) + if not self.estimators: + raise ValueError( + "A list of naive Bayes estimators must be provided " + "in the form [(name, estimatorNB, columns), ... ]." + ) + names, estimators, _ = zip(*self.estimators) + for e in estimators: + if (not check_partial) and ( + not (hasattr(e, "fit") and hasattr(e, "_joint_log_likelihood")) + ): + raise TypeError( + "Estimators must be naive Bayes estimators implementing " + "`fit` and `_joint_log_likelihood` methods." + ) + if check_partial and ( + not (hasattr(e, "partial_fit") + and hasattr(e, "_joint_log_likelihood")) + ): + raise TypeError( + "Estimators must be Naive Bayes estimators implementing " + "`partial_fit` and `_joint_log_likelihood` methods." + ) + self._validate_names(names) + + def _validate_column_callables(self, X): + """ + Convert callable column specifications and store into self._columns. + + Empty-set columns do not enjoy any special treatment. + """ + # Almost a verbatim copy of ColumnTransformer._validate_column_callables(). + # Consider refactoring in the future. 
+ # Unlike ColumnTransformer, this estimator does not need to output a + # dataframe or validate the remainder, so _estimator_to_input_indices + # is not really needed, but retained for consistency with + # ColumnTransformer code. + all_columns = [] + estimator_to_input_indices = {} + for name, _, columns in self.estimators: + if callable(columns): + columns = columns(X) + all_columns.append(columns) + estimator_to_input_indices[name] = _get_column_indices(X, columns) + self._columns = all_columns + self._estimator_to_input_indices = estimator_to_input_indices + + @property + def named_estimators_(self): + """Access the fitted naive Bayes subestimators by name. + + Read-only attribute to access any estimator by given name. + Keys are estimator names and values are the fitted estimator + objects. + """ + # Almost a verbatim copy of ColumnTransformer.named_transformers_ + # Use Bunch object to improve autocomplete + return Bunch(**{name: e for name, e, _ in self.estimators_}) + + def _iter(self, *, fitted=False, replace_strings=False): + """Generate ``(name, estimatorNB, columns)`` tuples. + + This is a private method, similar to ColumnTransformer._iter. + Must not be called before _validate_column_callables. + + Parameters + ---------- + fitted : bool, default=False + If False, returns tuples from self.estimators (user-specified), but + callable columns are replaced with a list of column names or indices. + If True, returns tuples from self.estimators_ (fitted), where + columns are processed as well. + + replace_strings : bool, default=False + If True, omits the estimators that do not require fitting, i.e. those + with empty-set columns. The name `replace_strings` is a relic of + ColumnTransformer implementation, where `passthrough` and `drop` + required replacement and omission, respectively. + + Yields + ------ + tuple + of the form ``(name, estimatorNB, columns)``. 
+ + Notes + ----- + Loop through estimators from this generator with the following + parameters, depending on the purpose: + + self._iter(fitted=False, replace_strings=True) : + fit, 1st partial_fit + self._iter(fitted=True, replace_strings=True) : + further partial_fit, predict + self._iter(fitted=False, replace_strings=False) : + update fitted estimators. Note that special treatment is required + for unfitted estimators (those with empty-set columns)! + self._iter(fitted=True, replace_strings=False) : + not used here. The use case in ColumnTransformer would be sorting + out the transformed output and its column names. + do not use in : + a Bunch accessor named_estimators_; + input validation _validate_estimators, _validate_column_callables; + parameter management: get_params, set_params, _estimators. + """ + if fitted: + for (name, estimator, cols) in self.estimators_: + if replace_strings and _is_empty_column_selection(cols): + continue + else: + yield (name, estimator, cols) + else: # fitted=False + for (name, estimator, _), cols in (zip(self.estimators, + self._columns)): + if replace_strings and _is_empty_column_selection(cols): + continue + else: + yield (name, estimator, cols) + + def _update_class_prior(self): + """Update class prior after most of the fitting is done.""" + if self.priors is None: # calculate empirical prior from counts + priors = self.class_count_ / self.class_count_.sum() + elif isinstance(self.priors, str): # extract prior from estimator + name = self.priors + e = self.named_estimators_[name] + if getattr(e, 'class_prior_', None) is not None: + priors = e.class_prior_ + elif getattr(e, 'class_log_prior_', None) is not None: + priors = np.exp(e.class_log_prior_) + else: + raise AttributeError( + f"Unable to extract class prior from estimator {name}, as " + "it does not have class_prior_ or class_log_prior_ " + "attributes.") + else: # check the provided prior + priors = np.asarray(self.priors) + # Check the prior in any case. 
+ if len(priors) != self.n_classes_: + raise ValueError("Number of priors must match number of classes.") + if not np.isclose(priors.sum(), 1.0): + raise ValueError("The sum of the priors should be 1.") + if (priors < 0).any(): + raise ValueError("Priors must be non-negative.") + self.class_prior_ = priors + + def _update_fitted_estimators(self, fitted_estimators): + """Update tuples in self.estimators_ with fitted_estimators provided. + + Callable columns are replaced with sets of actual str or int indices. + Estimators that don't require fitting are passed as they were, + without cloning. + """ + estimators_ = [] + fitted_estimators = iter(fitted_estimators) + + for name, estimatorNB, cols in self._iter(): + if not _is_empty_column_selection(cols): + updated_estimatorNB = next(fitted_estimators) + else: # don't advance fitted_estimators; use original + updated_estimatorNB = estimatorNB + estimators_.append((name, updated_estimatorNB, cols)) + self.estimators_ = estimators_ + + def fit(self, X, y, sample_weight=None): + """Fit the naive Bayes meta-estimator. + + Calls `fit` of each subestimator ``estimatorNB``. Only a corresponding + subset of columns of `X` is passed to each subestimator; `sample_weight` + and `y` are passed to the subestimators as they are. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training vectors, where `n_samples` is the number of samples + and `n_features` is the number of features. + y : array-like of shape (n_samples,) + Target values. + sample_weight : array-like of shape (n_samples,), default=None + Weights applied to individual samples (1. for unweighted). + + Returns + ------- + self : object + Returns the instance itself. + """ + self._check_feature_names(X, reset=True) + # TODO: Consider overriding BaseEstimator._check_feature_names + # Currently, when X has all str feature names, all features are + # registered in self.feature_names_in no matter if they are used or not. 
+ self._validate_estimators() + self._validate_column_callables(X) + # Consistency checks for X, y are delegated to subestimators + + # Subestimators get original sample_weight. This is for class counts: + if sample_weight is not None: + weights = _check_sample_weight(sample_weight, X=y, copy=True) + + # We would use sklearn.utils.multiclass.class_distribution, but it does + # not return class_count, which we want as well. + if sample_weight is None: + self.classes_, self.class_count_ = np.unique(column_or_1d(y), + return_counts=True) + else: + self.classes_ = np.unique(column_or_1d(y)) + counts = np.zeros(len(self.classes_)) + for i, c in enumerate(self.classes_): + counts[i] = (weights * (column_or_1d(y) == c)).sum() + self.class_count_ = counts + self.n_classes_ = len(self.classes_) + self._update_class_prior() + + estimators = list(self._iter(fitted=False, replace_strings=True)) + fitted_estimators = Parallel(n_jobs=self.n_jobs)( + delayed(_fit_one)( + estimator=clone(estimatorNB), + X=_safe_indexing(X, cols, axis=1), + y=y, + message_clsname="ColumnwiseNB", + message=self._log_message(name, idx, len(estimators)), + sample_weight=sample_weight + ) + for idx, (name, estimatorNB, cols) in enumerate(estimators, 1) + ) + self._update_fitted_estimators(fitted_estimators) + return self + + def partial_fit(self, X, y, classes=None, sample_weight=None): + """Fit incrementally the naive Bayes meta-estimator on a batch of samples. + + Calls `partial_fit` of each subestimator. Only a corresponding + subset of columns of `X` is passed to each subestimator. `classes`, + `sample_weight` and 'y' are passed to the subestimators as they are. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training vectors, where `n_samples` is the number of samples and + `n_features` is the number of features. + + y : array-like of shape (n_samples,) + Target values. 
+ + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + + sample_weight : array-like of shape (n_samples,), default=None + Weights applied to individual samples (1. for unweighted). + + Returns + ------- + self : object + Returns the instance itself. + """ + first_call = not hasattr(self, "classes_") + if first_call: + self._check_feature_names(X, reset=True) + self._validate_estimators(check_partial=True) + self._validate_column_callables(X) + else: + self._check_feature_names(X, reset=False) + # Consistency checks for X, y are delegated to subestimators + + # Subestimators get original sample_weight. This is for class counts: + if sample_weight is not None: + weights = _check_sample_weight(sample_weight, X=y, copy=True) + + # Subestimators should've checked classes. We set classes_ for counts + # and so that first_call becomes False at next partial_fit call. + _check_partial_fit_first_call(self, classes) + + # We don't use sklearn.utils.multiclass.class_distribution, because it + # neither returns class_count, nor is suitable for partial_fit. 
+ if sample_weight is None: + counts = np.zeros(len(self.classes_)) + for i, c in enumerate(self.classes_): + counts[i] = (column_or_1d(y) == c).sum() + else: + counts = np.zeros(len(self.classes_)) + for i, c in enumerate(self.classes_): + counts[i] = (weights * (column_or_1d(y) == c)).sum() + + if first_call: + self.n_classes_ = len(self.classes_) + self.class_count_ = counts + else: + self.class_count_ += counts + self._update_class_prior() + + estimators = list(self._iter(fitted=not first_call, replace_strings=True)) + fitted_estimators = Parallel(n_jobs=self.n_jobs)( + delayed(_partial_fit_one)( + estimator=clone(estimatorNB) if first_call else estimatorNB, + X=_safe_indexing(X, cols, axis=1), + y=y, + message_clsname="ColumnwiseNB", + message=self._log_message(name, idx, len(estimators)), + classes=classes, + sample_weight=sample_weight + ) + for idx, (name, estimatorNB, cols) in enumerate(estimators, 1) + ) + self._update_fitted_estimators(fitted_estimators) + return self + + @property + def _estimators(self): + """Internal list of subestimators. + + This is for the implementation of get_params via BaseComposition._get_params, + which expects lists of tuples of len 2. + """ + # Implemented in the image and likeness of ColumnTranformer._transformers + return [(name, e) for name, e, _ in self.estimators] + + @_estimators.setter + def _estimators(self, value): + # Implemented in the image and likeness of ColumnTranformer._transformers + # TODO: Is renaming or changing the order legal? Swap `name` and `_`? + self.estimators = [ + (name, e, col) + for ((name, e), (_, _, col)) in zip(value, self.estimators) + ] + + def get_params(self, deep=True): + """Get parameters for this estimator. + + Returns the parameters listed in the constructor as well as the + subestimators contained within the `estimators` of the `ColumnwiseNB` + instance. 
+ + Parameters + ---------- + deep : bool, default=True + If True, will return the parameters for this estimator and + contained subobjects that are estimators. + + Returns + ------- + params : dict + Parameter names mapped to their values. + """ + # Implemented in the image and likeness of ColumnTranformer.get_params + return self._get_params("_estimators", deep=deep) + + def set_params(self, **kwargs): + """Set the parameters of this estimator. + + Valid parameter keys can be listed with ``get_params()``. Note that you + can directly set the parameters of the estimators contained in + `estimators` of `ColumnwiseNB`. + + Parameters + ---------- + **kwargs : dict + Estimator parameters. + + Returns + ------- + self : ColumnwiseNB + This estimator. + """ + # Implemented in the image and likeness of ColumnTranformer.set_params + self._set_params("_estimators", **kwargs) + return self diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 36b1c29b36c1d..ae4d00c2d0c27 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -3,6 +3,8 @@ import numpy as np import scipy.sparse import pytest +import pandas as pd +from itertools import chain from sklearn.datasets import load_digits, load_iris @@ -18,6 +20,12 @@ from sklearn.naive_bayes import MultinomialNB, ComplementNB from sklearn.naive_bayes import CategoricalNB +from sklearn.base import BaseEstimator +from sklearn.base import clone +from sklearn.compose import make_column_selector +from sklearn.exceptions import DataConversionWarning +from sklearn.naive_bayes import ColumnwiseNB + DISCRETE_NAIVE_BAYES_CLASSES = [BernoulliNB, CategoricalNB, ComplementNB, MultinomialNB] ALL_NAIVE_BAYES_CLASSES = DISCRETE_NAIVE_BAYES_CLASSES + [GaussianNB] @@ -26,6 +34,10 @@ X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]) y = np.array([1, 1, 1, 2, 2, 2]) +# Same as above, but a dataframe +Xdf = pd.DataFrame(data=X, columns=['col0', 'col1']) +ydf = 
pd.DataFrame({'target': y}) + # A bit more random tests rng = np.random.RandomState(0) X1 = rng.normal(size=(10, 3)) @@ -945,3 +957,421 @@ def test_n_features_deprecation(Estimator): with pytest.warns(FutureWarning, match="`n_features_` was deprecated"): est.n_features_ + + +def test_cwnb_union(): + # A union of GaussianNB's yields the same prediction a single GaussianNB (fit) + clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [0]), + ('g2', GaussianNB(), [1])]) + clf2 = GaussianNB() + clf1.fit(X, y) + clf2.fit(X, y) + assert_array_almost_equal(clf1.predict(X), clf2.predict(X), 8) + assert_array_almost_equal(clf1.predict_proba(X), clf2.predict_proba(X), 8) + assert_array_almost_equal(clf1.predict_log_proba(X), + clf2.predict_log_proba(X), 8) + + # A union of BernoulliNB's yields the same prediction a single BernoulliNB (fit) + clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(), [0]), + ('b2', BernoulliNB(), [1, 2])]) + clf2 = BernoulliNB() + clf1.fit(X1, y1) + clf2.fit(X1, y1) + assert_array_almost_equal(clf1.predict_proba(X1), clf2.predict_proba(X1), 8) + assert_array_almost_equal(clf1.predict_log_proba(X1), + clf2.predict_log_proba(X1), 8) + assert_array_almost_equal(clf1.predict(X1), clf2.predict(X1), 8) + + # A union of BernoulliNB's yields the same prediction a single BernoulliNB + # (partial_fit) + clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(), [0]), + ('b2', BernoulliNB(), [1, 2])]) + clf2 = BernoulliNB() + clf1.partial_fit(X1[:5], y1[:5], classes=[0, 1]) + clf1.partial_fit(X1[5:], y1[5:]) + clf2.fit(X1, y1) + assert_array_almost_equal(clf1.predict_proba(X1), clf2.predict_proba(X1), 8) + assert_array_almost_equal(clf1.predict_log_proba(X1), + clf2.predict_log_proba(X1), 8) + assert_array_almost_equal(clf1.predict(X1), clf2.predict(X1), 8) + + # A union of several different NB's is permutation-invariant + clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [3]), + ('g1', GaussianNB(), [0]), + ('m1', MultinomialNB(), [0, 2]), + ('b2', 
BernoulliNB(), [1]) + ]) + # permute (0, 1, 2, 3, 4) -> (1, 2, 0, 3, 4) both estimator specs and column numbers + clf2 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [3]), + ('g1', GaussianNB(), [1]), + ('m1', MultinomialNB(), [1, 0]), + ('b2', BernoulliNB(), [2]) + ]) + clf1.fit(X2[:, [0, 1, 2, 3, 4]], y2) # (0, 1, 2, 3, 4) -> (1, 2, 0, 3, 4) + clf2.fit(X2[:, [2, 0, 1, 3, 4]], y2) # (0, 1, 2, 3, 4) <- (2, 0, 1, 3, 4) + assert_array_almost_equal(clf1.predict_proba(X2), + clf2.predict_proba(X2[:, [2, 0, 1, 3, 4]]), 8) + assert_array_almost_equal(clf1.predict_log_proba(X2), + clf2.predict_log_proba(X2[:, [2, 0, 1, 3, 4]]), 8) + assert_array_almost_equal(clf1.predict(X2), + clf2.predict(X2[:, [2, 0, 1, 3, 4]]), 8) + + +def test_cwnb_estimators_1(): + # Subestimators spec: cols can be lists of int or lists of str, if DataFrame + clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), + ('g2', GaussianNB(), [0, 1])]) + clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(), ['col1']), + ('g2', GaussianNB(), ['col0', 'col1'])]) + clf1.fit(X, y) + clf2.fit(Xdf, y) + assert_array_almost_equal(clf1.predict_log_proba(X), + clf2.predict_log_proba(Xdf), 8) + msg = "A column-vector y was passed when a 1d array was expected" + with pytest.warns(DataConversionWarning, match=msg): + clf2.fit(Xdf, ydf) + assert_array_almost_equal(clf1.predict_log_proba(X), + clf2.predict_log_proba(Xdf), 8) + + # Subestimators spec: repeated col ints have the same effect as repeating data + clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1, 1]), + ('b1', BernoulliNB(), [0, 0, 1, 1])]) + clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [0, 1]), + ('b1', BernoulliNB(), [2, 3, 4, 5])]) + clf1.fit(X1, y1) + clf2.fit(X1[:, [1, 1, 0, 0, 1, 1]], y1) + assert_array_almost_equal(clf1.predict_log_proba(X1), + clf2.predict_log_proba(X1[:, [1, 1, 0, 0, 1, 1]]), 8) + + # Subestimators spec: empty cols have the same effect as an absent estimator + clf1 = ColumnwiseNB(estimators=[('g1', 
GaussianNB(), [1]), + ('g2', GaussianNB(), []), + ('g3', GaussianNB(), [0, 1])]) + clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), + ('g3', GaussianNB(), [0, 1])]) + clf1.fit(X1, y1) + clf2.fit(X1, y1) + assert_array_almost_equal(clf1.predict_log_proba(X1), + clf2.predict_log_proba(X1), 8) + # Empty-columns estimators are passed to estimators_ and the numbers match + assert len(clf1.estimators) == len(clf1.estimators_) == 3 + assert len(clf2.estimators) == len(clf2.estimators_) == 2 + # No cloning of the empty-columns estimators took place: + assert id(clf1.estimators[1][1]) == id(clf1.named_estimators_['g2']) + + # Subestimators spec: empty cols have the same effect as an absent estimator + # when callable columns produce the empty set. + + select_none = make_column_selector(pattern="qwerasdf") + clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), + ('g2', GaussianNB(), select_none), + ('g3', GaussianNB(), [0, 1])]) + clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), + ('g3', GaussianNB(), [0, 1])]) + clf1.fit(Xdf, y) + clf2.fit(Xdf, y) + assert_array_almost_equal(clf1.predict_log_proba(Xdf), + clf2.predict_log_proba(Xdf), 8) + # Empty-columns estimators are passed to estimators_ and the numbers match + assert len(clf1.estimators) == len(clf1.estimators_) == 3 + assert len(clf2.estimators) == len(clf2.estimators_) == 2 + # No cloning of the empty-columns estimators took place: + assert id(clf1.estimators[1][1]) == id(clf1.named_estimators_['g2']) + + # Subestimators spec: test callable columns + select_int = make_column_selector(dtype_include=np.int_) + select_float = make_column_selector(dtype_include=np.float_) + Xdf2 = Xdf + Xdf2['col3'] = np.exp(Xdf['col0']) - 0.5 * Xdf['col1'] + clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), ['col3']), + ('m1', BernoulliNB(), ['col0', 'col1'])]) + clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(), select_float), + ('g2', BernoulliNB(), select_int)]) + clf1.fit(Xdf, y) + clf2.fit(Xdf, y) + 
assert_array_almost_equal(clf1.predict_log_proba(Xdf), + clf2.predict_log_proba(Xdf), 8) + + # Subestimators spec: error on repeated names + clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), + ('g1', GaussianNB(), [0, 1])]) + msg = "Names provided are not unique" + with pytest.raises(ValueError, match=msg): + clf1.fit(X, y) + + clf1 = ColumnwiseNB(estimators=[['g1', GaussianNB(), [1]], + ['g2', GaussianNB(), [0, 1]]]) + clf1.fit(X, y) + + +def test_cwnb_estimators_2(): + # Subestimators spec: error when some don't support _joint_log_likelihood + class notNB(BaseEstimator): + def __init__(self): pass + def fit(self, X, y): pass + def partial_fit(self, X, y): pass + # def _joint_log_likelihood(self, X): pass + def predict(self, X): pass + clf1 = ColumnwiseNB(estimators=[['g1', notNB(), [1]], + ['g2', GaussianNB(), [0]]]) + msg = "Estimators must be .aive Bayes estimators implementing *" + with pytest.raises(TypeError, match=msg): + clf1.partial_fit(X, y) + + # Subestimators spec: error when some don't support fit + class notNB(BaseEstimator): + def __init__(self): pass + # def fit(self, X, y): pass + def partial_fit(self, X, y): pass + def _joint_log_likelihood(self, X): pass + def predict(self, X): pass + clf1 = ColumnwiseNB(estimators=[['g1', notNB(), [1]], + ['g2', GaussianNB(), [0]]]) + msg = "Estimators must be .aive Bayes estimators implementing *" + with pytest.raises(TypeError, match=msg): + clf1.fit(X, y) + + # Subestimators spec: error when some don't support partial_fit + class notNB(BaseEstimator): + def __init__(self): pass + def fit(self, X, y): pass + # def partial_fit(self, X, y): pass + def _joint_log_likelihood(self, X): pass + def predict(self, X): pass + clf1 = ColumnwiseNB(estimators=[['g1', notNB(), [1]], + ['g2', GaussianNB(), [0]]]) + msg = "Estimators must be .aive Bayes estimators implementing *" + with pytest.raises(TypeError, match=msg): + clf1.partial_fit(X, y) + + +def test_cwnb_prior(): + # prior spec: error when negative, 
sum!=1 or bad length + clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), + ('g2', GaussianNB(), [0, 1])], + priors=np.array([-0.25, 1.25])) + msg = "Priors must be non-negative." + with pytest.raises(ValueError, match=msg): + clf1.fit(X, y) + + clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), + ('g2', GaussianNB(), [0, 1])], + priors=np.array([0.25, .7])) + msg = "The sum of the priors should be 1." + with pytest.raises(ValueError, match=msg): + clf1.fit(X, y) + + clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), + ('g2', GaussianNB(), [0, 1])], + priors=np.array([0.25, 0.25, 0.25, 0.25])) + msg = "Number of priors must match number of classes." + with pytest.raises(ValueError, match=msg): + clf1.fit(X, y) + + # prior spec: specified prior equals calculated and subestimators' priors + clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), + ('g2', GaussianNB(), [0, 1])], + priors=np.array([.5, .5])) + clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), + ('g2', GaussianNB(), [0, 1])]) + clf1.fit(X, y) + clf2.fit(X, y) + assert clf2.priors is None + assert_array_almost_equal(clf1.class_prior_, + clf1.named_estimators_['g1'].class_prior_, + 8) + assert_array_almost_equal(clf1.class_prior_, + clf1.named_estimators_['g2'].class_prior_, + 8) + assert_array_almost_equal(clf1.class_prior_, clf2.class_prior_, 8) + + +def test_cwnb_zero_prior(): + # P(y)=0 in a subestimator results in P(y|x)=0 of meta-estimator + clf1 = ColumnwiseNB(estimators=[ + ('g1', GaussianNB(), [1, 3, 5]), + ('g2', GaussianNB(priors=np.array([.5, 0, .5])), [0, 1]) + ]) + clf1.fit(X2, y2) + msg = "divide by zero encountered in log" + with pytest.warns(RuntimeWarning, match=msg): + p = clf1.predict_proba(X2)[:, 1] + assert_almost_equal(np.abs(p).sum(), 0) + assert np.isfinite(p).all() + Xt = rng.randint(5, size=(6, 100)) + with pytest.warns(RuntimeWarning, match=msg): + p = clf1.predict_proba(Xt)[:, 1] + assert_almost_equal(np.abs(p).sum(), 0) + assert 
np.isfinite(p).all() + + # P(y)=0 in the meta-estimator, as well as class priors that differ across + # subestimators may produce meaningless results, including NaNs. This case + # is not tested here. + + # P(y)=0 in two subestimators results in P(y|x)=0 of meta-estimator + clf1 = ColumnwiseNB(estimators=[ + ('g1', GaussianNB(priors=np.array([.6, 0, .4])), [1, 3, 5]), + ('g2', GaussianNB(priors=np.array([.5, 0, .5])), [0, 1]) + ]) + clf1.fit(X2, y2) + with pytest.warns(RuntimeWarning, match=msg): + p = clf1.predict_proba(X2)[:, 1] + assert_almost_equal(np.abs(p).sum(), 0) + assert np.isfinite(p).all() + Xt = rng.randint(5, size=(6, 100)) + with pytest.warns(RuntimeWarning, match=msg): + p = clf1.predict_proba(Xt)[:, 1] + assert_almost_equal(np.abs(p).sum(), 0) + assert np.isfinite(p).all() + + +def test_cwnb_sample_weight(): + # weights in fit have no effect if all ones + weights = [1, 1, 1, 1, 1, 1] + clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), + ('g2', GaussianNB(), [0, 1])]) + clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), + ('g2', GaussianNB(), [0, 1])]) + clf1.fit(X, y, sample_weight=weights) + clf2.fit(X, y) + assert_array_almost_equal(clf1._joint_log_likelihood(X), + clf2._joint_log_likelihood(X), 8) + assert_array_almost_equal(clf1.predict_log_proba(X), + clf2.predict_log_proba(X), 8) + assert_array_equal(clf1.predict(X), + clf2.predict(X)) + + # weights in partial_fit have no effect if all ones + clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), + ('m1', MultinomialNB(), [0, 2, 3])]) + clf2 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), + ('m1', MultinomialNB(), [0, 2, 3])]) + clf1.partial_fit(X2, y2, sample_weight=weights, classes=np.unique(y2)) + clf2.partial_fit(X2, y2, classes=np.unique(y2)) + assert_array_almost_equal(clf1._joint_log_likelihood(X2), + clf2._joint_log_likelihood(X2), 8) + assert_array_almost_equal(clf1.predict_log_proba(X2), + clf2.predict_log_proba(X2), 8) + 
assert_array_equal(clf1.predict(X2), + clf2.predict(X2)) + + # weights in fit have the same effect as repeating data + weights = [1, 2, 3, 1, 4, 2] + idx = list(chain(*([i] * w for i, w in enumerate(weights)))) + # var_smoothing=0.0 is for maximum precision in dealing with a small sample + clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(var_smoothing=0.0), [1]), + ('g2', GaussianNB(var_smoothing=0.0), [0, 1])]) + clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(var_smoothing=0.0), [1]), + ('g2', GaussianNB(var_smoothing=0.0), [0, 1])]) + clf1.fit(X, y, sample_weight=weights) + clf2.fit(X[idx], y[idx]) + assert_array_almost_equal(clf1._joint_log_likelihood(X), + clf2._joint_log_likelihood(X), 8) + assert_array_almost_equal(clf1.predict_log_proba(X), + clf2.predict_log_proba(X), 8) + assert_array_equal(clf1.predict(X), + clf2.predict(X), 8) + for attr_name in ('class_count_', 'class_prior_', 'classes_'): + assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) + + # weights in partial_fit have the same effect as repeating data + clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), + ('m1', MultinomialNB(), [0, 2, 3])]) + clf2 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), + ('m1', MultinomialNB(), [0, 2, 3])]) + clf1.partial_fit(X2, y2, sample_weight=weights, classes=np.unique(y2)) + clf2.partial_fit(X2[idx], y2[idx], classes=np.unique(y2)) + assert_array_equal(clf1._joint_log_likelihood(X2), + clf2._joint_log_likelihood(X2)) + assert_array_equal(clf1.predict_log_proba(X2), + clf2.predict_log_proba(X2)) + assert_array_equal(clf1.predict(X2), + clf2.predict(X2)) + for attr_name in ('class_count_', 'class_prior_', 'classes_'): + assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) + + +def test_cwnb_partial_fit(): + # partial_fit: consecutive calls yield the same prediction as a single call + clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(), [1]), + ('m1', MultinomialNB(), [0, 2, 3])]) + clf2 = 
ColumnwiseNB(estimators=[('b1', BernoulliNB(), [1]), + ('m1', MultinomialNB(), [0, 2, 3])]) + clf1.partial_fit(X2, y2, classes=np.unique(y2)) + clf2.partial_fit(X2[:4], y2[:4], classes=np.unique(y2)) + clf2.partial_fit(X2[4:], y2[4:]) + assert_array_almost_equal(clf1._joint_log_likelihood(X2), + clf2._joint_log_likelihood(X2), 8) + assert_array_almost_equal(clf1.predict_log_proba(X2), + clf2.predict_log_proba(X2), 8) + assert_array_equal(clf1.predict(X2), + clf2.predict(X2)) + for attr_name in ('class_count_', 'class_prior_', 'classes_'): + assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) + + # partial_fit: error when classes are not provided at the first call + clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(), [1]), + ('m1', MultinomialNB(), [0, 2, 3])]) + msg = ".lasses must be passed on the first call to partial_fit" + with pytest.raises(ValueError, match=msg): + clf1.partial_fit(X2, y2) + + +def test_cwnb_consistency(): + # class_count_, classes_, class_prior_ are consistent in meta-, sub-estimators + clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), + ('m1', MultinomialNB(), [0, 2, 3])]) + clf1.fit(X2, y2) + for se in clf1.named_estimators_: + assert_array_almost_equal(clf1.class_count_, + clf1.named_estimators_[se].class_count_, 8) + assert_array_almost_equal(clf1.classes_, + clf1.named_estimators_[se].classes_, 8) + assert_array_almost_equal(np.log(clf1.class_prior_), + clf1.named_estimators_[se].class_log_prior_, 8) + + +def test_cwnb_params(): + # Can get and set subestimators' parameters through name__paramname + # clone() works on ColumnwiseNB + clf1 = ColumnwiseNB(estimators=[ + ('b1', BernoulliNB(alpha=.2, binarize=2), [1]), + ('m1', MultinomialNB(class_prior=[.2, .2, .6]), [0, 2, 3]) + ]) + clf1.fit(X2, y2) + p = clf1.get_params(deep=True) + assert p['b1__alpha'] == .2 + assert p['b1__binarize'] == 2 + assert p['m1__class_prior'] == [.2, .2, .6] + clf1.set_params(b1__alpha=123, m1__class_prior=[.3, .3, 
.4]) + assert clf1.estimators[0][1].alpha == 123 + assert_array_equal(clf1.estimators[1][1].class_prior, [.3, .3, .4]) + # After cloning and fitting, we can check through named_estimators, which + # maps to fitted estimators_: + clf2 = clone(clf1).fit(X2, y2) + assert clf2.named_estimators_['b1'].alpha == 123 + assert_array_equal(clf2.named_estimators_['m1'].class_prior, [.3, .3, .4]) + assert (id(clf2.named_estimators_['b1']) != + id(clf1.named_estimators_['b1'])) + + +def test_cwnb_n_jobs(): + # n_jobs: same results wether with it or without + clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), + ('b2', BernoulliNB(binarize=2), [1]), + ('m1', MultinomialNB(), [0, 2, 3]), + ('m3', MultinomialNB(), slice(10, None))], + n_jobs=4) + clf2 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), + ('b2', BernoulliNB(binarize=2), [1]), + ('m1', MultinomialNB(), [0, 2, 3]), + ('m3', MultinomialNB(), slice(10, None))]) + clf1.partial_fit(X2, y2, classes=np.unique(y2)) + clf2.partial_fit(X2, y2, classes=np.unique(y2)) + + assert_array_almost_equal(clf1._joint_log_likelihood(X2), + clf2._joint_log_likelihood(X2), 8) + assert_array_almost_equal(clf1.predict_log_proba(X2), + clf2.predict_log_proba(X2), 8) + assert_array_equal(clf1.predict(X2), + clf2.predict(X2)) From 4ca9ac5c5a6ec41b894afb164ace785005647be3 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 21 Feb 2022 23:46:13 -0500 Subject: [PATCH 003/102] Added my name to module authors. --- sklearn/naive_bayes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 56a349c22848e..5aa9b08a170f1 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -13,6 +13,7 @@ # Lars Buitinck # Jan Hendrik Metzen # (parts based on earlier work by Mathieu Blondel) +# Andrey V. 
Melnik # # License: BSD 3 clause import warnings From d7a9bf4519acca5ed851c377b41239c57b98c660 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 22 Feb 2022 02:54:10 -0500 Subject: [PATCH 004/102] ColumnwiseNB docstring correction, See Also, Example. Added example to the tests --- sklearn/naive_bayes.py | 28 +++++++++++++++++++++++++++- sklearn/tests/test_naive_bayes.py | 14 ++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 5aa9b08a170f1..a8bfa0bf70860 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1546,10 +1546,19 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Names of features seen during :term:`fit`. Only defined if `X` has feature names that are all strings. + See Also + -------- + BernoulliNB : Naive Bayes classifier for multivariate Bernoulli models. + CategoricalNB : Naive Bayes classifier for categorical features. + ComplementNB : Complement Naive Bayes classifier. + MultinomialNB : Naive Bayes classifier for multinomial models. + GaussianNB : Gaussian Naive Bayes. + ColumnTransformer : Applies transformers to columns. + Notes ----- ColumnwiseNB combines multiple naive Bayes estimators by expressing the - overall joint probability ``P(x,y)`` through ``P(x_i|y)``, the joint + overall joint probability ``P(x,y)`` through ``P(x_i,y)``, the joint probabilities of the subestimators: ``Log P(x,y) = Log P(x_1,y) + ... + Log P(x_N,y) - (N - 1) Log P(y)``, where ``N`` denotes ``n_estimators``, the number of estimators. @@ -1559,6 +1568,23 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): The meta-estimators does not check if this condition holds. Meaningless results, including ``NaN``, may be produced by ColumnwiseNB if the class priors differ or contain a zero probability. 
+ + Examples + -------- + >>> import numpy as np + >>> rng = np.random.RandomState(1) + >>> X = rng.randint(5, size=(6, 100)) + >>> y = np.array([0, 0, 1, 1, 2, 2]) + >>> from sklearn.naive_bayes import MultinomialNB, GaussianNB, ColumnwiseNB + >>> clf = ColumnwiseNB(estimators=[('mnb1', MultinomialNB(), [0, 1]), + ... ('mnb2', MultinomialNB(), [3, 4]), + ... ('gnb1', GaussianNB(), [5])]) + >>> clf.fit(X, y) + ColumnwiseNB(estimators=[('mnb1', MultinomialNB(), [0, 1]), + ('mnb2', MultinomialNB(), [3, 4]), + ('gnb1', GaussianNB(), [5])]) + >>> print(clf.predict(X)) + [0 0 1 0 2 2] """ def _log_message(self, name, idx, total): if not self.verbose: diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index ae4d00c2d0c27..11cc852518ca0 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1375,3 +1375,17 @@ def test_cwnb_n_jobs(): clf2.predict_log_proba(X2), 8) assert_array_equal(clf1.predict(X2), clf2.predict(X2)) + + +def test_cwnb_example(): + # Test the Example from ColumnwiseNB docstring in naive_bayes.py + import numpy as np + rng = np.random.RandomState(1) + X = rng.randint(5, size=(6, 100)) + y = np.array([0, 0, 1, 1, 2, 2]) + from sklearn.naive_bayes import MultinomialNB, GaussianNB, ColumnwiseNB + clf = ColumnwiseNB(estimators=[('mnb1', MultinomialNB(), [0, 1]), + ('mnb2', MultinomialNB(), [3, 4]), + ('gnb1', GaussianNB(), [5])]) + clf.fit(X, y) + clf.predict(X) From 2dfcd4093dcc8178590d8b69fc027dd86e1834e6 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 22 Feb 2022 04:10:30 -0500 Subject: [PATCH 005/102] black formatting compliance. 
--- sklearn/naive_bayes.py | 41 ++- sklearn/tests/test_naive_bayes.py | 533 ++++++++++++++++++------------ 2 files changed, 338 insertions(+), 236 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index a8bfa0bf70860..8e13b9d308771 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1586,6 +1586,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): >>> print(clf.predict(X)) [0 0 1 0 2 2] """ + def _log_message(self, name, idx, total): if not self.verbose: return None @@ -1617,12 +1618,9 @@ def _joint_log_likelihood(self, X): # Changes therein must be reflected in _jll_one or here. estimators = self._iter(fitted=True, replace_strings=True) all_jlls = Parallel(n_jobs=self.n_jobs)( - delayed(_jll_one)( - estimator=estimatorNB, - X=_safe_indexing(X, cols, axis=1) - ) + delayed(_jll_one)(estimator=estimatorNB, X=_safe_indexing(X, cols, axis=1)) for (_, estimatorNB, cols) in estimators - ) + ) n_estimators = len(all_jlls) log_prior = np.log(self.class_prior_) return np.sum(all_jlls, axis=0) - (n_estimators - 1) * log_prior @@ -1645,8 +1643,7 @@ def _validate_estimators(self, check_partial=False): "`fit` and `_joint_log_likelihood` methods." 
) if check_partial and ( - not (hasattr(e, "partial_fit") - and hasattr(e, "_joint_log_likelihood")) + not (hasattr(e, "partial_fit") and hasattr(e, "_joint_log_likelihood")) ): raise TypeError( "Estimators must be Naive Bayes estimators implementing " @@ -1740,8 +1737,7 @@ def _iter(self, *, fitted=False, replace_strings=False): else: yield (name, estimator, cols) else: # fitted=False - for (name, estimator, _), cols in (zip(self.estimators, - self._columns)): + for (name, estimator, _), cols in zip(self.estimators, self._columns): if replace_strings and _is_empty_column_selection(cols): continue else: @@ -1754,15 +1750,16 @@ def _update_class_prior(self): elif isinstance(self.priors, str): # extract prior from estimator name = self.priors e = self.named_estimators_[name] - if getattr(e, 'class_prior_', None) is not None: + if getattr(e, "class_prior_", None) is not None: priors = e.class_prior_ - elif getattr(e, 'class_log_prior_', None) is not None: + elif getattr(e, "class_log_prior_", None) is not None: priors = np.exp(e.class_log_prior_) else: raise AttributeError( f"Unable to extract class prior from estimator {name}, as " "it does not have class_prior_ or class_log_prior_ " - "attributes.") + "attributes." + ) else: # check the provided prior priors = np.asarray(self.priors) # Check the prior in any case. @@ -1829,8 +1826,9 @@ def fit(self, X, y, sample_weight=None): # We would use sklearn.utils.multiclass.class_distribution, but it does # not return class_count, which we want as well. 
if sample_weight is None: - self.classes_, self.class_count_ = np.unique(column_or_1d(y), - return_counts=True) + self.classes_, self.class_count_ = np.unique( + column_or_1d(y), return_counts=True + ) else: self.classes_ = np.unique(column_or_1d(y)) counts = np.zeros(len(self.classes_)) @@ -1848,10 +1846,10 @@ def fit(self, X, y, sample_weight=None): y=y, message_clsname="ColumnwiseNB", message=self._log_message(name, idx, len(estimators)), - sample_weight=sample_weight - ) - for idx, (name, estimatorNB, cols) in enumerate(estimators, 1) + sample_weight=sample_weight, ) + for idx, (name, estimatorNB, cols) in enumerate(estimators, 1) + ) self._update_fitted_estimators(fitted_estimators) return self @@ -1929,10 +1927,10 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): message_clsname="ColumnwiseNB", message=self._log_message(name, idx, len(estimators)), classes=classes, - sample_weight=sample_weight - ) - for idx, (name, estimatorNB, cols) in enumerate(estimators, 1) + sample_weight=sample_weight, ) + for idx, (name, estimatorNB, cols) in enumerate(estimators, 1) + ) self._update_fitted_estimators(fitted_estimators) return self @@ -1951,8 +1949,7 @@ def _estimators(self, value): # Implemented in the image and likeness of ColumnTranformer._transformers # TODO: Is renaming or changing the order legal? Swap `name` and `_`? 
self.estimators = [ - (name, e, col) - for ((name, e), (_, _, col)) in zip(value, self.estimators) + (name, e, col) for ((name, e), (_, _, col)) in zip(value, self.estimators) ] def get_params(self, deep=True): diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 11cc852518ca0..edac66fa80f6f 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -35,8 +35,8 @@ y = np.array([1, 1, 1, 2, 2, 2]) # Same as above, but a dataframe -Xdf = pd.DataFrame(data=X, columns=['col0', 'col1']) -ydf = pd.DataFrame({'target': y}) +Xdf = pd.DataFrame(data=X, columns=["col0", "col1"]) +ydf = pd.DataFrame({"target": y}) # A bit more random tests rng = np.random.RandomState(0) @@ -415,9 +415,7 @@ def test_discretenb_sample_weight_multiclass(DiscreteNaiveBayes): @pytest.mark.parametrize("use_partial_fit", [False, True]) @pytest.mark.parametrize("train_on_single_class_y", [False, True]) def test_discretenb_degenerate_one_class_case( - DiscreteNaiveBayes, - use_partial_fit, - train_on_single_class_y, + DiscreteNaiveBayes, use_partial_fit, train_on_single_class_y, ): # Most array attributes of a discrete naive Bayes classifier should have a # first-axis length equal to the number of classes. 
Exceptions include: @@ -961,185 +959,242 @@ def test_n_features_deprecation(Estimator): def test_cwnb_union(): # A union of GaussianNB's yields the same prediction a single GaussianNB (fit) - clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [0]), - ('g2', GaussianNB(), [1])]) + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [0]), ("g2", GaussianNB(), [1])] + ) clf2 = GaussianNB() clf1.fit(X, y) clf2.fit(X, y) assert_array_almost_equal(clf1.predict(X), clf2.predict(X), 8) assert_array_almost_equal(clf1.predict_proba(X), clf2.predict_proba(X), 8) - assert_array_almost_equal(clf1.predict_log_proba(X), - clf2.predict_log_proba(X), 8) + assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) # A union of BernoulliNB's yields the same prediction a single BernoulliNB (fit) - clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(), [0]), - ('b2', BernoulliNB(), [1, 2])]) + clf1 = ColumnwiseNB( + estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] + ) clf2 = BernoulliNB() clf1.fit(X1, y1) clf2.fit(X1, y1) assert_array_almost_equal(clf1.predict_proba(X1), clf2.predict_proba(X1), 8) - assert_array_almost_equal(clf1.predict_log_proba(X1), - clf2.predict_log_proba(X1), 8) + assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) assert_array_almost_equal(clf1.predict(X1), clf2.predict(X1), 8) # A union of BernoulliNB's yields the same prediction a single BernoulliNB # (partial_fit) - clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(), [0]), - ('b2', BernoulliNB(), [1, 2])]) + clf1 = ColumnwiseNB( + estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] + ) clf2 = BernoulliNB() clf1.partial_fit(X1[:5], y1[:5], classes=[0, 1]) clf1.partial_fit(X1[5:], y1[5:]) clf2.fit(X1, y1) assert_array_almost_equal(clf1.predict_proba(X1), clf2.predict_proba(X1), 8) - assert_array_almost_equal(clf1.predict_log_proba(X1), - clf2.predict_log_proba(X1), 8) + 
assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) assert_array_almost_equal(clf1.predict(X1), clf2.predict(X1), 8) # A union of several different NB's is permutation-invariant - clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [3]), - ('g1', GaussianNB(), [0]), - ('m1', MultinomialNB(), [0, 2]), - ('b2', BernoulliNB(), [1]) - ]) + clf1 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [3]), + ("g1", GaussianNB(), [0]), + ("m1", MultinomialNB(), [0, 2]), + ("b2", BernoulliNB(), [1]), + ] + ) # permute (0, 1, 2, 3, 4) -> (1, 2, 0, 3, 4) both estimator specs and column numbers - clf2 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [3]), - ('g1', GaussianNB(), [1]), - ('m1', MultinomialNB(), [1, 0]), - ('b2', BernoulliNB(), [2]) - ]) + clf2 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [3]), + ("g1", GaussianNB(), [1]), + ("m1", MultinomialNB(), [1, 0]), + ("b2", BernoulliNB(), [2]), + ] + ) clf1.fit(X2[:, [0, 1, 2, 3, 4]], y2) # (0, 1, 2, 3, 4) -> (1, 2, 0, 3, 4) clf2.fit(X2[:, [2, 0, 1, 3, 4]], y2) # (0, 1, 2, 3, 4) <- (2, 0, 1, 3, 4) - assert_array_almost_equal(clf1.predict_proba(X2), - clf2.predict_proba(X2[:, [2, 0, 1, 3, 4]]), 8) - assert_array_almost_equal(clf1.predict_log_proba(X2), - clf2.predict_log_proba(X2[:, [2, 0, 1, 3, 4]]), 8) - assert_array_almost_equal(clf1.predict(X2), - clf2.predict(X2[:, [2, 0, 1, 3, 4]]), 8) + assert_array_almost_equal( + clf1.predict_proba(X2), clf2.predict_proba(X2[:, [2, 0, 1, 3, 4]]), 8 + ) + assert_array_almost_equal( + clf1.predict_log_proba(X2), clf2.predict_log_proba(X2[:, [2, 0, 1, 3, 4]]), 8 + ) + assert_array_almost_equal(clf1.predict(X2), clf2.predict(X2[:, [2, 0, 1, 3, 4]]), 8) def test_cwnb_estimators_1(): # Subestimators spec: cols can be lists of int or lists of str, if DataFrame - clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), - ('g2', GaussianNB(), [0, 1])]) - clf2 = ColumnwiseNB(estimators=[('g1', 
GaussianNB(), ['col1']), - ('g2', GaussianNB(), ['col0', 'col1'])]) + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + ) + clf2 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), ["col1"]), + ("g2", GaussianNB(), ["col0", "col1"]), + ] + ) clf1.fit(X, y) clf2.fit(Xdf, y) - assert_array_almost_equal(clf1.predict_log_proba(X), - clf2.predict_log_proba(Xdf), 8) + assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(Xdf), 8) msg = "A column-vector y was passed when a 1d array was expected" with pytest.warns(DataConversionWarning, match=msg): clf2.fit(Xdf, ydf) - assert_array_almost_equal(clf1.predict_log_proba(X), - clf2.predict_log_proba(Xdf), 8) + assert_array_almost_equal( + clf1.predict_log_proba(X), clf2.predict_log_proba(Xdf), 8 + ) # Subestimators spec: repeated col ints have the same effect as repeating data - clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1, 1]), - ('b1', BernoulliNB(), [0, 0, 1, 1])]) - clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [0, 1]), - ('b1', BernoulliNB(), [2, 3, 4, 5])]) + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1, 1]), ("b1", BernoulliNB(), [0, 0, 1, 1])] + ) + clf2 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [0, 1]), ("b1", BernoulliNB(), [2, 3, 4, 5])] + ) clf1.fit(X1, y1) clf2.fit(X1[:, [1, 1, 0, 0, 1, 1]], y1) - assert_array_almost_equal(clf1.predict_log_proba(X1), - clf2.predict_log_proba(X1[:, [1, 1, 0, 0, 1, 1]]), 8) + assert_array_almost_equal( + clf1.predict_log_proba(X1), clf2.predict_log_proba(X1[:, [1, 1, 0, 0, 1, 1]]), 8 + ) # Subestimators spec: empty cols have the same effect as an absent estimator - clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), - ('g2', GaussianNB(), []), - ('g3', GaussianNB(), [0, 1])]) - clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), - ('g3', GaussianNB(), [0, 1])]) + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [1]), + ("g2", GaussianNB(), []), + 
("g3", GaussianNB(), [0, 1]), + ] + ) + clf2 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] + ) clf1.fit(X1, y1) clf2.fit(X1, y1) - assert_array_almost_equal(clf1.predict_log_proba(X1), - clf2.predict_log_proba(X1), 8) + assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) # Empty-columns estimators are passed to estimators_ and the numbers match assert len(clf1.estimators) == len(clf1.estimators_) == 3 assert len(clf2.estimators) == len(clf2.estimators_) == 2 # No cloning of the empty-columns estimators took place: - assert id(clf1.estimators[1][1]) == id(clf1.named_estimators_['g2']) + assert id(clf1.estimators[1][1]) == id(clf1.named_estimators_["g2"]) # Subestimators spec: empty cols have the same effect as an absent estimator # when callable columns produce the empty set. select_none = make_column_selector(pattern="qwerasdf") - clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), - ('g2', GaussianNB(), select_none), - ('g3', GaussianNB(), [0, 1])]) - clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), - ('g3', GaussianNB(), [0, 1])]) + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [1]), + ("g2", GaussianNB(), select_none), + ("g3", GaussianNB(), [0, 1]), + ] + ) + clf2 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] + ) clf1.fit(Xdf, y) clf2.fit(Xdf, y) - assert_array_almost_equal(clf1.predict_log_proba(Xdf), - clf2.predict_log_proba(Xdf), 8) + assert_array_almost_equal( + clf1.predict_log_proba(Xdf), clf2.predict_log_proba(Xdf), 8 + ) # Empty-columns estimators are passed to estimators_ and the numbers match assert len(clf1.estimators) == len(clf1.estimators_) == 3 assert len(clf2.estimators) == len(clf2.estimators_) == 2 # No cloning of the empty-columns estimators took place: - assert id(clf1.estimators[1][1]) == id(clf1.named_estimators_['g2']) + assert id(clf1.estimators[1][1]) == id(clf1.named_estimators_["g2"]) # 
Subestimators spec: test callable columns select_int = make_column_selector(dtype_include=np.int_) select_float = make_column_selector(dtype_include=np.float_) Xdf2 = Xdf - Xdf2['col3'] = np.exp(Xdf['col0']) - 0.5 * Xdf['col1'] - clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), ['col3']), - ('m1', BernoulliNB(), ['col0', 'col1'])]) - clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(), select_float), - ('g2', BernoulliNB(), select_int)]) + Xdf2["col3"] = np.exp(Xdf["col0"]) - 0.5 * Xdf["col1"] + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), ["col3"]), + ("m1", BernoulliNB(), ["col0", "col1"]), + ] + ) + clf2 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), select_float), + ("g2", BernoulliNB(), select_int), + ] + ) clf1.fit(Xdf, y) clf2.fit(Xdf, y) - assert_array_almost_equal(clf1.predict_log_proba(Xdf), - clf2.predict_log_proba(Xdf), 8) + assert_array_almost_equal( + clf1.predict_log_proba(Xdf), clf2.predict_log_proba(Xdf), 8 + ) # Subestimators spec: error on repeated names - clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), - ('g1', GaussianNB(), [0, 1])]) + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g1", GaussianNB(), [0, 1])] + ) msg = "Names provided are not unique" with pytest.raises(ValueError, match=msg): clf1.fit(X, y) - clf1 = ColumnwiseNB(estimators=[['g1', GaussianNB(), [1]], - ['g2', GaussianNB(), [0, 1]]]) + clf1 = ColumnwiseNB( + estimators=[["g1", GaussianNB(), [1]], ["g2", GaussianNB(), [0, 1]]] + ) clf1.fit(X, y) def test_cwnb_estimators_2(): # Subestimators spec: error when some don't support _joint_log_likelihood class notNB(BaseEstimator): - def __init__(self): pass - def fit(self, X, y): pass - def partial_fit(self, X, y): pass + def __init__(self): + pass + + def fit(self, X, y): + pass + + def partial_fit(self, X, y): + pass + # def _joint_log_likelihood(self, X): pass - def predict(self, X): pass - clf1 = ColumnwiseNB(estimators=[['g1', notNB(), [1]], - ['g2', GaussianNB(), [0]]]) + def 
predict(self, X): + pass + + clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "Estimators must be .aive Bayes estimators implementing *" with pytest.raises(TypeError, match=msg): clf1.partial_fit(X, y) # Subestimators spec: error when some don't support fit class notNB(BaseEstimator): - def __init__(self): pass + def __init__(self): + pass + # def fit(self, X, y): pass - def partial_fit(self, X, y): pass - def _joint_log_likelihood(self, X): pass - def predict(self, X): pass - clf1 = ColumnwiseNB(estimators=[['g1', notNB(), [1]], - ['g2', GaussianNB(), [0]]]) + def partial_fit(self, X, y): + pass + + def _joint_log_likelihood(self, X): + pass + + def predict(self, X): + pass + + clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "Estimators must be .aive Bayes estimators implementing *" with pytest.raises(TypeError, match=msg): clf1.fit(X, y) # Subestimators spec: error when some don't support partial_fit class notNB(BaseEstimator): - def __init__(self): pass - def fit(self, X, y): pass + def __init__(self): + pass + + def fit(self, X, y): + pass + # def partial_fit(self, X, y): pass - def _joint_log_likelihood(self, X): pass - def predict(self, X): pass - clf1 = ColumnwiseNB(estimators=[['g1', notNB(), [1]], - ['g2', GaussianNB(), [0]]]) + def _joint_log_likelihood(self, X): + pass + + def predict(self, X): + pass + + clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "Estimators must be .aive Bayes estimators implementing *" with pytest.raises(TypeError, match=msg): clf1.partial_fit(X, y) @@ -1147,51 +1202,58 @@ def predict(self, X): pass def test_cwnb_prior(): # prior spec: error when negative, sum!=1 or bad length - clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), - ('g2', GaussianNB(), [0, 1])], - priors=np.array([-0.25, 1.25])) + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + 
priors=np.array([-0.25, 1.25]), + ) msg = "Priors must be non-negative." with pytest.raises(ValueError, match=msg): clf1.fit(X, y) - clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), - ('g2', GaussianNB(), [0, 1])], - priors=np.array([0.25, .7])) + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + priors=np.array([0.25, 0.7]), + ) msg = "The sum of the priors should be 1." with pytest.raises(ValueError, match=msg): clf1.fit(X, y) - clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), - ('g2', GaussianNB(), [0, 1])], - priors=np.array([0.25, 0.25, 0.25, 0.25])) + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + priors=np.array([0.25, 0.25, 0.25, 0.25]), + ) msg = "Number of priors must match number of classes." with pytest.raises(ValueError, match=msg): clf1.fit(X, y) # prior spec: specified prior equals calculated and subestimators' priors - clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), - ('g2', GaussianNB(), [0, 1])], - priors=np.array([.5, .5])) - clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), - ('g2', GaussianNB(), [0, 1])]) + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + priors=np.array([0.5, 0.5]), + ) + clf2 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + ) clf1.fit(X, y) clf2.fit(X, y) assert clf2.priors is None - assert_array_almost_equal(clf1.class_prior_, - clf1.named_estimators_['g1'].class_prior_, - 8) - assert_array_almost_equal(clf1.class_prior_, - clf1.named_estimators_['g2'].class_prior_, - 8) + assert_array_almost_equal( + clf1.class_prior_, clf1.named_estimators_["g1"].class_prior_, 8 + ) + assert_array_almost_equal( + clf1.class_prior_, clf1.named_estimators_["g2"].class_prior_, 8 + ) assert_array_almost_equal(clf1.class_prior_, clf2.class_prior_, 8) def test_cwnb_zero_prior(): # P(y)=0 in a subestimator results in P(y|x)=0 of 
meta-estimator - clf1 = ColumnwiseNB(estimators=[ - ('g1', GaussianNB(), [1, 3, 5]), - ('g2', GaussianNB(priors=np.array([.5, 0, .5])), [0, 1]) - ]) + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [1, 3, 5]), + ("g2", GaussianNB(priors=np.array([0.5, 0, 0.5])), [0, 1]), + ] + ) clf1.fit(X2, y2) msg = "divide by zero encountered in log" with pytest.warns(RuntimeWarning, match=msg): @@ -1209,10 +1271,12 @@ def test_cwnb_zero_prior(): # is not tested here. # P(y)=0 in two subestimators results in P(y|x)=0 of meta-estimator - clf1 = ColumnwiseNB(estimators=[ - ('g1', GaussianNB(priors=np.array([.6, 0, .4])), [1, 3, 5]), - ('g2', GaussianNB(priors=np.array([.5, 0, .5])), [0, 1]) - ]) + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(priors=np.array([0.6, 0, 0.4])), [1, 3, 5]), + ("g2", GaussianNB(priors=np.array([0.5, 0, 0.5])), [0, 1]), + ] + ) clf1.fit(X2, y2) with pytest.warns(RuntimeWarning, match=msg): p = clf1.predict_proba(X2)[:, 1] @@ -1228,90 +1292,112 @@ def test_cwnb_zero_prior(): def test_cwnb_sample_weight(): # weights in fit have no effect if all ones weights = [1, 1, 1, 1, 1, 1] - clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), - ('g2', GaussianNB(), [0, 1])]) - clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(), [1]), - ('g2', GaussianNB(), [0, 1])]) + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + ) + clf2 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + ) clf1.fit(X, y, sample_weight=weights) clf2.fit(X, y) - assert_array_almost_equal(clf1._joint_log_likelihood(X), - clf2._joint_log_likelihood(X), 8) - assert_array_almost_equal(clf1.predict_log_proba(X), - clf2.predict_log_proba(X), 8) - assert_array_equal(clf1.predict(X), - clf2.predict(X)) + assert_array_almost_equal( + clf1._joint_log_likelihood(X), clf2._joint_log_likelihood(X), 8 + ) + assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) + 
assert_array_equal(clf1.predict(X), clf2.predict(X)) # weights in partial_fit have no effect if all ones - clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), - ('m1', MultinomialNB(), [0, 2, 3])]) - clf2 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), - ('m1', MultinomialNB(), [0, 2, 3])]) + clf1 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ] + ) + clf2 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ] + ) clf1.partial_fit(X2, y2, sample_weight=weights, classes=np.unique(y2)) clf2.partial_fit(X2, y2, classes=np.unique(y2)) - assert_array_almost_equal(clf1._joint_log_likelihood(X2), - clf2._joint_log_likelihood(X2), 8) - assert_array_almost_equal(clf1.predict_log_proba(X2), - clf2.predict_log_proba(X2), 8) - assert_array_equal(clf1.predict(X2), - clf2.predict(X2)) + assert_array_almost_equal( + clf1._joint_log_likelihood(X2), clf2._joint_log_likelihood(X2), 8 + ) + assert_array_almost_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2), 8) + assert_array_equal(clf1.predict(X2), clf2.predict(X2)) # weights in fit have the same effect as repeating data weights = [1, 2, 3, 1, 4, 2] idx = list(chain(*([i] * w for i, w in enumerate(weights)))) # var_smoothing=0.0 is for maximum precision in dealing with a small sample - clf1 = ColumnwiseNB(estimators=[('g1', GaussianNB(var_smoothing=0.0), [1]), - ('g2', GaussianNB(var_smoothing=0.0), [0, 1])]) - clf2 = ColumnwiseNB(estimators=[('g1', GaussianNB(var_smoothing=0.0), [1]), - ('g2', GaussianNB(var_smoothing=0.0), [0, 1])]) + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(var_smoothing=0.0), [1]), + ("g2", GaussianNB(var_smoothing=0.0), [0, 1]), + ] + ) + clf2 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(var_smoothing=0.0), [1]), + ("g2", GaussianNB(var_smoothing=0.0), [0, 1]), + ] + ) clf1.fit(X, y, sample_weight=weights) clf2.fit(X[idx], 
y[idx]) - assert_array_almost_equal(clf1._joint_log_likelihood(X), - clf2._joint_log_likelihood(X), 8) - assert_array_almost_equal(clf1.predict_log_proba(X), - clf2.predict_log_proba(X), 8) - assert_array_equal(clf1.predict(X), - clf2.predict(X), 8) - for attr_name in ('class_count_', 'class_prior_', 'classes_'): + assert_array_almost_equal( + clf1._joint_log_likelihood(X), clf2._joint_log_likelihood(X), 8 + ) + assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) + assert_array_equal(clf1.predict(X), clf2.predict(X), 8) + for attr_name in ("class_count_", "class_prior_", "classes_"): assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) # weights in partial_fit have the same effect as repeating data - clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), - ('m1', MultinomialNB(), [0, 2, 3])]) - clf2 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), - ('m1', MultinomialNB(), [0, 2, 3])]) + clf1 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ] + ) + clf2 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ] + ) clf1.partial_fit(X2, y2, sample_weight=weights, classes=np.unique(y2)) clf2.partial_fit(X2[idx], y2[idx], classes=np.unique(y2)) - assert_array_equal(clf1._joint_log_likelihood(X2), - clf2._joint_log_likelihood(X2)) - assert_array_equal(clf1.predict_log_proba(X2), - clf2.predict_log_proba(X2)) - assert_array_equal(clf1.predict(X2), - clf2.predict(X2)) - for attr_name in ('class_count_', 'class_prior_', 'classes_'): + assert_array_equal(clf1._joint_log_likelihood(X2), clf2._joint_log_likelihood(X2)) + assert_array_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2)) + assert_array_equal(clf1.predict(X2), clf2.predict(X2)) + for attr_name in ("class_count_", "class_prior_", "classes_"): assert_array_equal(getattr(clf1, attr_name), getattr(clf1, 
attr_name)) def test_cwnb_partial_fit(): # partial_fit: consecutive calls yield the same prediction as a single call - clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(), [1]), - ('m1', MultinomialNB(), [0, 2, 3])]) - clf2 = ColumnwiseNB(estimators=[('b1', BernoulliNB(), [1]), - ('m1', MultinomialNB(), [0, 2, 3])]) + clf1 = ColumnwiseNB( + estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + ) + clf2 = ColumnwiseNB( + estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + ) clf1.partial_fit(X2, y2, classes=np.unique(y2)) clf2.partial_fit(X2[:4], y2[:4], classes=np.unique(y2)) clf2.partial_fit(X2[4:], y2[4:]) - assert_array_almost_equal(clf1._joint_log_likelihood(X2), - clf2._joint_log_likelihood(X2), 8) - assert_array_almost_equal(clf1.predict_log_proba(X2), - clf2.predict_log_proba(X2), 8) - assert_array_equal(clf1.predict(X2), - clf2.predict(X2)) - for attr_name in ('class_count_', 'class_prior_', 'classes_'): + assert_array_almost_equal( + clf1._joint_log_likelihood(X2), clf2._joint_log_likelihood(X2), 8 + ) + assert_array_almost_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2), 8) + assert_array_equal(clf1.predict(X2), clf2.predict(X2)) + for attr_name in ("class_count_", "class_prior_", "classes_"): assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) # partial_fit: error when classes are not provided at the first call - clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(), [1]), - ('m1', MultinomialNB(), [0, 2, 3])]) + clf1 = ColumnwiseNB( + estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + ) msg = ".lasses must be passed on the first call to partial_fit" with pytest.raises(ValueError, match=msg): clf1.partial_fit(X2, y2) @@ -1319,73 +1405,92 @@ def test_cwnb_partial_fit(): def test_cwnb_consistency(): # class_count_, classes_, class_prior_ are consistent in meta-, sub-estimators - clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), - 
('m1', MultinomialNB(), [0, 2, 3])]) + clf1 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ] + ) clf1.fit(X2, y2) for se in clf1.named_estimators_: - assert_array_almost_equal(clf1.class_count_, - clf1.named_estimators_[se].class_count_, 8) - assert_array_almost_equal(clf1.classes_, - clf1.named_estimators_[se].classes_, 8) - assert_array_almost_equal(np.log(clf1.class_prior_), - clf1.named_estimators_[se].class_log_prior_, 8) + assert_array_almost_equal( + clf1.class_count_, clf1.named_estimators_[se].class_count_, 8 + ) + assert_array_almost_equal(clf1.classes_, clf1.named_estimators_[se].classes_, 8) + assert_array_almost_equal( + np.log(clf1.class_prior_), clf1.named_estimators_[se].class_log_prior_, 8 + ) def test_cwnb_params(): # Can get and set subestimators' parameters through name__paramname # clone() works on ColumnwiseNB - clf1 = ColumnwiseNB(estimators=[ - ('b1', BernoulliNB(alpha=.2, binarize=2), [1]), - ('m1', MultinomialNB(class_prior=[.2, .2, .6]), [0, 2, 3]) - ]) + clf1 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(alpha=0.2, binarize=2), [1]), + ("m1", MultinomialNB(class_prior=[0.2, 0.2, 0.6]), [0, 2, 3]), + ] + ) clf1.fit(X2, y2) p = clf1.get_params(deep=True) - assert p['b1__alpha'] == .2 - assert p['b1__binarize'] == 2 - assert p['m1__class_prior'] == [.2, .2, .6] - clf1.set_params(b1__alpha=123, m1__class_prior=[.3, .3, .4]) + assert p["b1__alpha"] == 0.2 + assert p["b1__binarize"] == 2 + assert p["m1__class_prior"] == [0.2, 0.2, 0.6] + clf1.set_params(b1__alpha=123, m1__class_prior=[0.3, 0.3, 0.4]) assert clf1.estimators[0][1].alpha == 123 - assert_array_equal(clf1.estimators[1][1].class_prior, [.3, .3, .4]) + assert_array_equal(clf1.estimators[1][1].class_prior, [0.3, 0.3, 0.4]) # After cloning and fitting, we can check through named_estimators, which # maps to fitted estimators_: clf2 = clone(clf1).fit(X2, y2) - assert clf2.named_estimators_['b1'].alpha == 123 - 
assert_array_equal(clf2.named_estimators_['m1'].class_prior, [.3, .3, .4]) - assert (id(clf2.named_estimators_['b1']) != - id(clf1.named_estimators_['b1'])) + assert clf2.named_estimators_["b1"].alpha == 123 + assert_array_equal(clf2.named_estimators_["m1"].class_prior, [0.3, 0.3, 0.4]) + assert id(clf2.named_estimators_["b1"]) != id(clf1.named_estimators_["b1"]) def test_cwnb_n_jobs(): # n_jobs: same results wether with it or without - clf1 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), - ('b2', BernoulliNB(binarize=2), [1]), - ('m1', MultinomialNB(), [0, 2, 3]), - ('m3', MultinomialNB(), slice(10, None))], - n_jobs=4) - clf2 = ColumnwiseNB(estimators=[('b1', BernoulliNB(binarize=2), [1]), - ('b2', BernoulliNB(binarize=2), [1]), - ('m1', MultinomialNB(), [0, 2, 3]), - ('m3', MultinomialNB(), slice(10, None))]) + clf1 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("b2", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ("m3", MultinomialNB(), slice(10, None)), + ], + n_jobs=4, + ) + clf2 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("b2", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ("m3", MultinomialNB(), slice(10, None)), + ] + ) clf1.partial_fit(X2, y2, classes=np.unique(y2)) clf2.partial_fit(X2, y2, classes=np.unique(y2)) - assert_array_almost_equal(clf1._joint_log_likelihood(X2), - clf2._joint_log_likelihood(X2), 8) - assert_array_almost_equal(clf1.predict_log_proba(X2), - clf2.predict_log_proba(X2), 8) - assert_array_equal(clf1.predict(X2), - clf2.predict(X2)) + assert_array_almost_equal( + clf1._joint_log_likelihood(X2), clf2._joint_log_likelihood(X2), 8 + ) + assert_array_almost_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2), 8) + assert_array_equal(clf1.predict(X2), clf2.predict(X2)) def test_cwnb_example(): # Test the Example from ColumnwiseNB docstring in naive_bayes.py import numpy as np + rng = np.random.RandomState(1) X 
= rng.randint(5, size=(6, 100)) y = np.array([0, 0, 1, 1, 2, 2]) from sklearn.naive_bayes import MultinomialNB, GaussianNB, ColumnwiseNB - clf = ColumnwiseNB(estimators=[('mnb1', MultinomialNB(), [0, 1]), - ('mnb2', MultinomialNB(), [3, 4]), - ('gnb1', GaussianNB(), [5])]) + + clf = ColumnwiseNB( + estimators=[ + ("mnb1", MultinomialNB(), [0, 1]), + ("mnb2", MultinomialNB(), [3, 4]), + ("gnb1", GaussianNB(), [5]), + ] + ) clf.fit(X, y) clf.predict(X) From 861a573725f8c10d8b954d31881a0df18e4e3451 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 22 Feb 2022 04:29:47 -0500 Subject: [PATCH 006/102] black formatting compliance. --- sklearn/naive_bayes.py | 3 +-- sklearn/tests/test_naive_bayes.py | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 8e13b9d308771..4645ca9883ae6 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1610,8 +1610,7 @@ def _check_X(self, X): return X def _joint_log_likelihood(self, X): - """Calculate the meta-estimator's joint log likelihood ``log P(x,c)``. - """ + """Calculate the meta-estimator's joint log likelihood ``log P(x,c)``.""" # Because data must follow the same path as it would in subestimators, # _jll_one(estimatorNB, X) passes it through estimatorNB._check_X to # match the implementation of _BaseNB.predict_log_proba. 
diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index edac66fa80f6f..c728f99b253e0 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -415,7 +415,9 @@ def test_discretenb_sample_weight_multiclass(DiscreteNaiveBayes): @pytest.mark.parametrize("use_partial_fit", [False, True]) @pytest.mark.parametrize("train_on_single_class_y", [False, True]) def test_discretenb_degenerate_one_class_case( - DiscreteNaiveBayes, use_partial_fit, train_on_single_class_y, + DiscreteNaiveBayes, + use_partial_fit, + train_on_single_class_y, ): # Most array attributes of a discrete naive Bayes classifier should have a # first-axis length equal to the number of classes. Exceptions include: From 9b73275c673c8b1c89e53fca270971a5e68e5db4 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 22 Feb 2022 15:54:21 -0500 Subject: [PATCH 007/102] ColumnwiseNB: added _required_parameters = [estimators] --- sklearn/naive_bayes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 4645ca9883ae6..c4128b1ac700f 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1587,6 +1587,8 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): [0 0 1 0 2 2] """ + _required_parameters = ["estimators"] + def _log_message(self, name, idx, total): if not self.verbose: return None From fa665756ac5fa39c270dd1fda13c90590b45c275 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 22 Feb 2022 17:01:38 -0500 Subject: [PATCH 008/102] Dirty trick with ColumnwiseNB._required_parameters to pass tests --- sklearn/naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index c4128b1ac700f..190e607ae1f28 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1587,7 +1587,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): [0 0 1 0 2 2] """ 
- _required_parameters = ["estimators"] + _required_parameters = ["estimatorsNB"] def _log_message(self, name, idx, total): if not self.verbose: From 12677e7f3081fd7ad610caf95aa3374774d2a3f3 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 22 Feb 2022 20:45:10 -0500 Subject: [PATCH 009/102] ColumnwiseNB docstring: added extended summary --- sklearn/naive_bayes.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 190e607ae1f28..628002f9ffcc7 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1464,7 +1464,12 @@ def _jll_one(estimator, X): class ColumnwiseNB(_BaseNB, _BaseComposition): """ - Column-wise Naive Bayes estimator. + Column-wise Naive Bayes meta-estimator. + + This estimator combines various naive Bayes estimators by applying them + to different column subsets of the input and joining their predictions + according to the naive Bayes assumption. This is useful when features are + heterogeneous and follow different kinds of distributions. 
Parameters ---------- From e748cf1e9ddf79457681e4b03b3d3fa593df906b Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 22 Feb 2022 20:46:24 -0500 Subject: [PATCH 010/102] ColumnwiseNB test issue: added to VALIDATE_ESTIMATOR_INIT exclusion list --- sklearn/tests/test_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 350e1e95d9882..39b2dfa6ecee6 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -419,6 +419,7 @@ def test_transformers_get_feature_names_out(transformer): "SGDOneClassSVM", "TheilSenRegressor", "TweedieRegressor", + "ColumnwiseNB", ] VALIDATE_ESTIMATOR_INIT = set(VALIDATE_ESTIMATOR_INIT) From 8b7cfc8157b4fb9e370f327df5d47b0f80fa445d Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 22 Feb 2022 22:56:33 -0500 Subject: [PATCH 011/102] ColumnwiseNB: rename 'estimators' into 'estimatorNBs' --- sklearn/naive_bayes.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 628002f9ffcc7..9c4284146c0a0 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1473,7 +1473,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Parameters ---------- - estimators : list of tuples + estimatorNBs : list of tuples List of (name, estimatorNB, columns) tuples specifying the naive Bayes estimators to be combined into a single naive Bayes meta-estimator. 
@@ -1592,15 +1592,15 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): [0 0 1 0 2 2] """ - _required_parameters = ["estimatorsNB"] + _required_parameters = ["estimatorNBs"] def _log_message(self, name, idx, total): if not self.verbose: return None return "(%d of %d) Processing %s" % (idx, total, name) - def __init__(self, estimators, priors=None, n_jobs=None, verbose=False): - self.estimators = estimators + def __init__(self, estimatorNBs, priors=None, n_jobs=None, verbose=False): + self.estimatorNBs = estimatorNBs self.priors = priors self.n_jobs = n_jobs self.verbose = verbose @@ -1634,12 +1634,12 @@ def _joint_log_likelihood(self, X): def _validate_estimators(self, check_partial=False): # Check if estimators have fit/partial_fit and jll methods # Validate estimator names via _BaseComposition._validate_names(self, names) - if not self.estimators: + if not self.estimatorNBs: raise ValueError( "A list of naive Bayes estimators must be provided " "in the form [(name, estimatorNB, columns), ... ]." ) - names, estimators, _ = zip(*self.estimators) + names, estimators, _ = zip(*self.estimatorNBs) for e in estimators: if (not check_partial) and ( not (hasattr(e, "fit") and hasattr(e, "_joint_log_likelihood")) @@ -1671,7 +1671,7 @@ def _validate_column_callables(self, X): # ColumnTransformer code. all_columns = [] estimator_to_input_indices = {} - for name, _, columns in self.estimators: + for name, _, columns in self.estimatorNBs: if callable(columns): columns = columns(X) all_columns.append(columns) @@ -1743,7 +1743,7 @@ def _iter(self, *, fitted=False, replace_strings=False): else: yield (name, estimator, cols) else: # fitted=False - for (name, estimator, _), cols in zip(self.estimators, self._columns): + for (name, estimator, _), cols in zip(self.estimatorNBs, self._columns): if replace_strings and _is_empty_column_selection(cols): continue else: @@ -1948,14 +1948,14 @@ def _estimators(self): which expects lists of tuples of len 2. 
""" # Implemented in the image and likeness of ColumnTranformer._transformers - return [(name, e) for name, e, _ in self.estimators] + return [(name, e) for name, e, _ in self.estimatorNBs] @_estimators.setter def _estimators(self, value): # Implemented in the image and likeness of ColumnTranformer._transformers # TODO: Is renaming or changing the order legal? Swap `name` and `_`? - self.estimators = [ - (name, e, col) for ((name, e), (_, _, col)) in zip(value, self.estimators) + self.estimatorNBs = [ + (name, e, col) for ((name, e), (_, _, col)) in zip(value, self.estimatorNBs) ] def get_params(self, deep=True): From 7cb843086297a62af11acf8d08ac6cc73d4d7b6f Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 22 Feb 2022 23:15:29 -0500 Subject: [PATCH 012/102] Rename 'estimators' into 'estimatorNBs' in test_naive_bayes.py --- sklearn/tests/test_naive_bayes.py | 102 +++++++++++++++--------------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index c728f99b253e0..f78447c6d4d03 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -962,7 +962,7 @@ def test_n_features_deprecation(Estimator): def test_cwnb_union(): # A union of GaussianNB's yields the same prediction a single GaussianNB (fit) clf1 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [0]), ("g2", GaussianNB(), [1])] + estimatorNBs=[("g1", GaussianNB(), [0]), ("g2", GaussianNB(), [1])] ) clf2 = GaussianNB() clf1.fit(X, y) @@ -973,7 +973,7 @@ def test_cwnb_union(): # A union of BernoulliNB's yields the same prediction a single BernoulliNB (fit) clf1 = ColumnwiseNB( - estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] + estimatorNBs=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] ) clf2 = BernoulliNB() clf1.fit(X1, y1) @@ -985,7 +985,7 @@ def test_cwnb_union(): # A union of BernoulliNB's yields the same prediction a 
single BernoulliNB # (partial_fit) clf1 = ColumnwiseNB( - estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] + estimatorNBs=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] ) clf2 = BernoulliNB() clf1.partial_fit(X1[:5], y1[:5], classes=[0, 1]) @@ -997,7 +997,7 @@ def test_cwnb_union(): # A union of several different NB's is permutation-invariant clf1 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("b1", BernoulliNB(binarize=2), [3]), ("g1", GaussianNB(), [0]), ("m1", MultinomialNB(), [0, 2]), @@ -1006,7 +1006,7 @@ def test_cwnb_union(): ) # permute (0, 1, 2, 3, 4) -> (1, 2, 0, 3, 4) both estimator specs and column numbers clf2 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("b1", BernoulliNB(binarize=2), [3]), ("g1", GaussianNB(), [1]), ("m1", MultinomialNB(), [1, 0]), @@ -1027,10 +1027,10 @@ def test_cwnb_union(): def test_cwnb_estimators_1(): # Subestimators spec: cols can be lists of int or lists of str, if DataFrame clf1 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] ) clf2 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("g1", GaussianNB(), ["col1"]), ("g2", GaussianNB(), ["col0", "col1"]), ] @@ -1047,10 +1047,10 @@ def test_cwnb_estimators_1(): # Subestimators spec: repeated col ints have the same effect as repeating data clf1 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [1, 1]), ("b1", BernoulliNB(), [0, 0, 1, 1])] + estimatorNBs=[("g1", GaussianNB(), [1, 1]), ("b1", BernoulliNB(), [0, 0, 1, 1])] ) clf2 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [0, 1]), ("b1", BernoulliNB(), [2, 3, 4, 5])] + estimatorNBs=[("g1", GaussianNB(), [0, 1]), ("b1", BernoulliNB(), [2, 3, 4, 5])] ) clf1.fit(X1, y1) clf2.fit(X1[:, [1, 1, 0, 0, 1, 1]], y1) @@ -1060,37 +1060,37 @@ def test_cwnb_estimators_1(): # Subestimators spec: empty cols have the same effect as an absent estimator clf1 = ColumnwiseNB( - estimators=[ + 
estimatorNBs=[ ("g1", GaussianNB(), [1]), ("g2", GaussianNB(), []), ("g3", GaussianNB(), [0, 1]), ] ) clf2 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] + estimatorNBs=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] ) clf1.fit(X1, y1) clf2.fit(X1, y1) assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) # Empty-columns estimators are passed to estimators_ and the numbers match - assert len(clf1.estimators) == len(clf1.estimators_) == 3 - assert len(clf2.estimators) == len(clf2.estimators_) == 2 + assert len(clf1.estimatorNBs) == len(clf1.estimators_) == 3 + assert len(clf2.estimatorNBs) == len(clf2.estimators_) == 2 # No cloning of the empty-columns estimators took place: - assert id(clf1.estimators[1][1]) == id(clf1.named_estimators_["g2"]) + assert id(clf1.estimatorNBs[1][1]) == id(clf1.named_estimators_["g2"]) # Subestimators spec: empty cols have the same effect as an absent estimator # when callable columns produce the empty set. 
select_none = make_column_selector(pattern="qwerasdf") clf1 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("g1", GaussianNB(), [1]), ("g2", GaussianNB(), select_none), ("g3", GaussianNB(), [0, 1]), ] ) clf2 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] + estimatorNBs=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] ) clf1.fit(Xdf, y) clf2.fit(Xdf, y) @@ -1098,10 +1098,10 @@ def test_cwnb_estimators_1(): clf1.predict_log_proba(Xdf), clf2.predict_log_proba(Xdf), 8 ) # Empty-columns estimators are passed to estimators_ and the numbers match - assert len(clf1.estimators) == len(clf1.estimators_) == 3 - assert len(clf2.estimators) == len(clf2.estimators_) == 2 + assert len(clf1.estimatorNBs) == len(clf1.estimators_) == 3 + assert len(clf2.estimatorNBs) == len(clf2.estimators_) == 2 # No cloning of the empty-columns estimators took place: - assert id(clf1.estimators[1][1]) == id(clf1.named_estimators_["g2"]) + assert id(clf1.estimatorNBs[1][1]) == id(clf1.named_estimators_["g2"]) # Subestimators spec: test callable columns select_int = make_column_selector(dtype_include=np.int_) @@ -1109,13 +1109,13 @@ def test_cwnb_estimators_1(): Xdf2 = Xdf Xdf2["col3"] = np.exp(Xdf["col0"]) - 0.5 * Xdf["col1"] clf1 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("g1", GaussianNB(), ["col3"]), ("m1", BernoulliNB(), ["col0", "col1"]), ] ) clf2 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("g1", GaussianNB(), select_float), ("g2", BernoulliNB(), select_int), ] @@ -1128,14 +1128,14 @@ def test_cwnb_estimators_1(): # Subestimators spec: error on repeated names clf1 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [1]), ("g1", GaussianNB(), [0, 1])] + estimatorNBs=[("g1", GaussianNB(), [1]), ("g1", GaussianNB(), [0, 1])] ) msg = "Names provided are not unique" with pytest.raises(ValueError, match=msg): clf1.fit(X, y) clf1 = ColumnwiseNB( - estimators=[["g1", GaussianNB(), [1]], ["g2", GaussianNB(), [0, 1]]] + estimatorNBs=[["g1", 
GaussianNB(), [1]], ["g2", GaussianNB(), [0, 1]]] ) clf1.fit(X, y) @@ -1156,7 +1156,7 @@ def partial_fit(self, X, y): def predict(self, X): pass - clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + clf1 = ColumnwiseNB(estimatorNBs=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "Estimators must be .aive Bayes estimators implementing *" with pytest.raises(TypeError, match=msg): clf1.partial_fit(X, y) @@ -1176,7 +1176,7 @@ def _joint_log_likelihood(self, X): def predict(self, X): pass - clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + clf1 = ColumnwiseNB(estimatorNBs=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "Estimators must be .aive Bayes estimators implementing *" with pytest.raises(TypeError, match=msg): clf1.fit(X, y) @@ -1196,7 +1196,7 @@ def _joint_log_likelihood(self, X): def predict(self, X): pass - clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + clf1 = ColumnwiseNB(estimatorNBs=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "Estimators must be .aive Bayes estimators implementing *" with pytest.raises(TypeError, match=msg): clf1.partial_fit(X, y) @@ -1205,7 +1205,7 @@ def predict(self, X): def test_cwnb_prior(): # prior spec: error when negative, sum!=1 or bad length clf1 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], priors=np.array([-0.25, 1.25]), ) msg = "Priors must be non-negative." @@ -1213,7 +1213,7 @@ def test_cwnb_prior(): clf1.fit(X, y) clf1 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], priors=np.array([0.25, 0.7]), ) msg = "The sum of the priors should be 1." 
@@ -1221,7 +1221,7 @@ def test_cwnb_prior(): clf1.fit(X, y) clf1 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], priors=np.array([0.25, 0.25, 0.25, 0.25]), ) msg = "Number of priors must match number of classes." @@ -1230,11 +1230,11 @@ def test_cwnb_prior(): # prior spec: specified prior equals calculated and subestimators' priors clf1 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], priors=np.array([0.5, 0.5]), ) clf2 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] ) clf1.fit(X, y) clf2.fit(X, y) @@ -1251,7 +1251,7 @@ def test_cwnb_prior(): def test_cwnb_zero_prior(): # P(y)=0 in a subestimator results in P(y|x)=0 of meta-estimator clf1 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("g1", GaussianNB(), [1, 3, 5]), ("g2", GaussianNB(priors=np.array([0.5, 0, 0.5])), [0, 1]), ] @@ -1274,7 +1274,7 @@ def test_cwnb_zero_prior(): # P(y)=0 in two subestimators results in P(y|x)=0 of meta-estimator clf1 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("g1", GaussianNB(priors=np.array([0.6, 0, 0.4])), [1, 3, 5]), ("g2", GaussianNB(priors=np.array([0.5, 0, 0.5])), [0, 1]), ] @@ -1295,10 +1295,10 @@ def test_cwnb_sample_weight(): # weights in fit have no effect if all ones weights = [1, 1, 1, 1, 1, 1] clf1 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] ) clf2 = ColumnwiseNB( - estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] ) clf1.fit(X, y, sample_weight=weights) clf2.fit(X, y) @@ -1310,13 +1310,13 @@ def test_cwnb_sample_weight(): # 
weights in partial_fit have no effect if all ones clf1 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] ) clf2 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] @@ -1334,13 +1334,13 @@ def test_cwnb_sample_weight(): idx = list(chain(*([i] * w for i, w in enumerate(weights)))) # var_smoothing=0.0 is for maximum precision in dealing with a small sample clf1 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("g1", GaussianNB(var_smoothing=0.0), [1]), ("g2", GaussianNB(var_smoothing=0.0), [0, 1]), ] ) clf2 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("g1", GaussianNB(var_smoothing=0.0), [1]), ("g2", GaussianNB(var_smoothing=0.0), [0, 1]), ] @@ -1357,13 +1357,13 @@ def test_cwnb_sample_weight(): # weights in partial_fit have the same effect as repeating data clf1 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] ) clf2 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] @@ -1380,10 +1380,10 @@ def test_cwnb_sample_weight(): def test_cwnb_partial_fit(): # partial_fit: consecutive calls yield the same prediction as a single call clf1 = ColumnwiseNB( - estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + estimatorNBs=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] ) clf2 = ColumnwiseNB( - estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + estimatorNBs=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] ) clf1.partial_fit(X2, y2, classes=np.unique(y2)) clf2.partial_fit(X2[:4], y2[:4], classes=np.unique(y2)) @@ -1398,7 +1398,7 @@ def test_cwnb_partial_fit(): # partial_fit: error when classes are not provided at the first call clf1 = ColumnwiseNB( - estimators=[("b1", BernoulliNB(), [1]), ("m1", 
MultinomialNB(), [0, 2, 3])] + estimatorNBs=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] ) msg = ".lasses must be passed on the first call to partial_fit" with pytest.raises(ValueError, match=msg): @@ -1408,7 +1408,7 @@ def test_cwnb_partial_fit(): def test_cwnb_consistency(): # class_count_, classes_, class_prior_ are consistent in meta-, sub-estimators clf1 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] @@ -1428,7 +1428,7 @@ def test_cwnb_params(): # Can get and set subestimators' parameters through name__paramname # clone() works on ColumnwiseNB clf1 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("b1", BernoulliNB(alpha=0.2, binarize=2), [1]), ("m1", MultinomialNB(class_prior=[0.2, 0.2, 0.6]), [0, 2, 3]), ] @@ -1439,8 +1439,8 @@ def test_cwnb_params(): assert p["b1__binarize"] == 2 assert p["m1__class_prior"] == [0.2, 0.2, 0.6] clf1.set_params(b1__alpha=123, m1__class_prior=[0.3, 0.3, 0.4]) - assert clf1.estimators[0][1].alpha == 123 - assert_array_equal(clf1.estimators[1][1].class_prior, [0.3, 0.3, 0.4]) + assert clf1.estimatorNBs[0][1].alpha == 123 + assert_array_equal(clf1.estimatorNBs[1][1].class_prior, [0.3, 0.3, 0.4]) # After cloning and fitting, we can check through named_estimators, which # maps to fitted estimators_: clf2 = clone(clf1).fit(X2, y2) @@ -1452,7 +1452,7 @@ def test_cwnb_params(): def test_cwnb_n_jobs(): # n_jobs: same results wether with it or without clf1 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("b1", BernoulliNB(binarize=2), [1]), ("b2", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), @@ -1461,7 +1461,7 @@ def test_cwnb_n_jobs(): n_jobs=4, ) clf2 = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("b1", BernoulliNB(binarize=2), [1]), ("b2", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), @@ -1488,7 +1488,7 @@ def test_cwnb_example(): from sklearn.naive_bayes import MultinomialNB, GaussianNB, 
ColumnwiseNB clf = ColumnwiseNB( - estimators=[ + estimatorNBs=[ ("mnb1", MultinomialNB(), [0, 1]), ("mnb2", MultinomialNB(), [3, 4]), ("gnb1", GaussianNB(), [5]), From 2abac32214d7a5f1ff5b3f6df8633b12aca9afb5 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 22 Feb 2022 23:16:55 -0500 Subject: [PATCH 013/102] ColumnwiseNB: rename 'estimators' in the example too --- sklearn/naive_bayes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 9c4284146c0a0..507c4b2e81b83 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1581,11 +1581,11 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): >>> X = rng.randint(5, size=(6, 100)) >>> y = np.array([0, 0, 1, 1, 2, 2]) >>> from sklearn.naive_bayes import MultinomialNB, GaussianNB, ColumnwiseNB - >>> clf = ColumnwiseNB(estimators=[('mnb1', MultinomialNB(), [0, 1]), + >>> clf = ColumnwiseNB(estimatorNBs=[('mnb1', MultinomialNB(), [0, 1]), ... ('mnb2', MultinomialNB(), [3, 4]), ... 
('gnb1', GaussianNB(), [5])]) >>> clf.fit(X, y) - ColumnwiseNB(estimators=[('mnb1', MultinomialNB(), [0, 1]), + ColumnwiseNB(estimatorNBs=[('mnb1', MultinomialNB(), [0, 1]), ('mnb2', MultinomialNB(), [3, 4]), ('gnb1', GaussianNB(), [5])]) >>> print(clf.predict(X)) From c9a5e1d1aad774fff628e467096673a757b68c90 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Wed, 23 Feb 2022 00:10:01 -0500 Subject: [PATCH 014/102] Added pytest skip when no pandas to test_naive_bayes.py --- sklearn/tests/test_naive_bayes.py | 80 +++++++++++++++---------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index f78447c6d4d03..ffc75d2d5f06c 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -3,7 +3,6 @@ import numpy as np import scipy.sparse import pytest -import pandas as pd from itertools import chain from sklearn.datasets import load_digits, load_iris @@ -34,10 +33,6 @@ X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]) y = np.array([1, 1, 1, 2, 2, 2]) -# Same as above, but a dataframe -Xdf = pd.DataFrame(data=X, columns=["col0", "col1"]) -ydf = pd.DataFrame({"target": y}) - # A bit more random tests rng = np.random.RandomState(0) X1 = rng.normal(size=(10, 3)) @@ -1024,7 +1019,11 @@ def test_cwnb_union(): assert_array_almost_equal(clf1.predict(X2), clf2.predict(X2[:, [2, 0, 1, 3, 4]]), 8) -def test_cwnb_estimators_1(): +def test_cwnb_estimators_pandas(): + pd = pytest.importorskip("pandas") + Xdf = pd.DataFrame(data=X, columns=["col0", "col1"]) + ydf = pd.DataFrame({"target": y}) + # Subestimators spec: cols can be lists of int or lists of str, if DataFrame clf1 = ColumnwiseNB( estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] @@ -1045,42 +1044,8 @@ def test_cwnb_estimators_1(): clf1.predict_log_proba(X), clf2.predict_log_proba(Xdf), 8 ) - # Subestimators spec: repeated col ints have 
the same effect as repeating data - clf1 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1, 1]), ("b1", BernoulliNB(), [0, 0, 1, 1])] - ) - clf2 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [0, 1]), ("b1", BernoulliNB(), [2, 3, 4, 5])] - ) - clf1.fit(X1, y1) - clf2.fit(X1[:, [1, 1, 0, 0, 1, 1]], y1) - assert_array_almost_equal( - clf1.predict_log_proba(X1), clf2.predict_log_proba(X1[:, [1, 1, 0, 0, 1, 1]]), 8 - ) - - # Subestimators spec: empty cols have the same effect as an absent estimator - clf1 = ColumnwiseNB( - estimatorNBs=[ - ("g1", GaussianNB(), [1]), - ("g2", GaussianNB(), []), - ("g3", GaussianNB(), [0, 1]), - ] - ) - clf2 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] - ) - clf1.fit(X1, y1) - clf2.fit(X1, y1) - assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) - # Empty-columns estimators are passed to estimators_ and the numbers match - assert len(clf1.estimatorNBs) == len(clf1.estimators_) == 3 - assert len(clf2.estimatorNBs) == len(clf2.estimators_) == 2 - # No cloning of the empty-columns estimators took place: - assert id(clf1.estimatorNBs[1][1]) == id(clf1.named_estimators_["g2"]) - # Subestimators spec: empty cols have the same effect as an absent estimator # when callable columns produce the empty set. 
- select_none = make_column_selector(pattern="qwerasdf") clf1 = ColumnwiseNB( estimatorNBs=[ @@ -1126,6 +1091,41 @@ def test_cwnb_estimators_1(): clf1.predict_log_proba(Xdf), clf2.predict_log_proba(Xdf), 8 ) + +def test_cwnb_estimators_1(): + # Subestimators spec: repeated col ints have the same effect as repeating data + clf1 = ColumnwiseNB( + estimatorNBs=[("g1", GaussianNB(), [1, 1]), ("b1", BernoulliNB(), [0, 0, 1, 1])] + ) + clf2 = ColumnwiseNB( + estimatorNBs=[("g1", GaussianNB(), [0, 1]), ("b1", BernoulliNB(), [2, 3, 4, 5])] + ) + clf1.fit(X1, y1) + clf2.fit(X1[:, [1, 1, 0, 0, 1, 1]], y1) + assert_array_almost_equal( + clf1.predict_log_proba(X1), clf2.predict_log_proba(X1[:, [1, 1, 0, 0, 1, 1]]), 8 + ) + + # Subestimators spec: empty cols have the same effect as an absent estimator + clf1 = ColumnwiseNB( + estimatorNBs=[ + ("g1", GaussianNB(), [1]), + ("g2", GaussianNB(), []), + ("g3", GaussianNB(), [0, 1]), + ] + ) + clf2 = ColumnwiseNB( + estimatorNBs=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] + ) + clf1.fit(X1, y1) + clf2.fit(X1, y1) + assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) + # Empty-columns estimators are passed to estimators_ and the numbers match + assert len(clf1.estimatorNBs) == len(clf1.estimators_) == 3 + assert len(clf2.estimatorNBs) == len(clf2.estimators_) == 2 + # No cloning of the empty-columns estimators took place: + assert id(clf1.estimatorNBs[1][1]) == id(clf1.named_estimators_["g2"]) + # Subestimators spec: error on repeated names clf1 = ColumnwiseNB( estimatorNBs=[("g1", GaussianNB(), [1]), ("g1", GaussianNB(), [0, 1])] From 79f980fc6257353bf678dcdd5491986e7f1f2ea6 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Wed, 23 Feb 2022 02:34:24 -0500 Subject: [PATCH 015/102] ColumnwiseNB: update class prior AFTER update fitted estimators --- sklearn/naive_bayes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 507c4b2e81b83..4496720be8579 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1842,7 +1842,6 @@ def fit(self, X, y, sample_weight=None): counts[i] = (weights * (column_or_1d(y) == c)).sum() self.class_count_ = counts self.n_classes_ = len(self.classes_) - self._update_class_prior() estimators = list(self._iter(fitted=False, replace_strings=True)) fitted_estimators = Parallel(n_jobs=self.n_jobs)( @@ -1857,6 +1856,7 @@ def fit(self, X, y, sample_weight=None): for idx, (name, estimatorNB, cols) in enumerate(estimators, 1) ) self._update_fitted_estimators(fitted_estimators) + self._update_class_prior() return self def partial_fit(self, X, y, classes=None, sample_weight=None): @@ -1922,7 +1922,6 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): self.class_count_ = counts else: self.class_count_ += counts - self._update_class_prior() estimators = list(self._iter(fitted=not first_call, replace_strings=True)) fitted_estimators = Parallel(n_jobs=self.n_jobs)( @@ -1938,6 +1937,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): for idx, (name, estimatorNB, cols) in enumerate(estimators, 1) ) self._update_fitted_estimators(fitted_estimators) + self._update_class_prior() return self @property From c2cd353519edf7aff99593b5df1c4ef10875548e Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Wed, 23 Feb 2022 03:04:09 -0500 Subject: [PATCH 016/102] test_naive_bayes.py Added more tests to improve coverage --- sklearn/tests/test_naive_bayes.py | 56 ++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index ffc75d2d5f06c..02b2bb8e47b5a 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1141,6 +1141,22 @@ def test_cwnb_estimators_1(): def test_cwnb_estimators_2(): + # Subestimators spec: error on 
empty list + clf = ColumnwiseNB( + estimatorNBs=[], + ) + msg = "A list of naive Bayes estimators must be provided*" + with pytest.raises(ValueError, match=msg): + clf.fit(X1, y1) + + # Subestimators spec: error on None + clf = ColumnwiseNB( + estimatorNBs=None, + ) + msg = "A list of naive Bayes estimators must be provided*" + with pytest.raises(ValueError, match=msg): + clf.fit(X1, y1) + # Subestimators spec: error when some don't support _joint_log_likelihood class notNB(BaseEstimator): def __init__(self): @@ -1201,6 +1217,20 @@ def predict(self, X): with pytest.raises(TypeError, match=msg): clf1.partial_fit(X, y) + # _estimators setter works + clf1 = ColumnwiseNB( + estimatorNBs=[("g1", GaussianNB(), [0]), ("b1", BernoulliNB(), [1])] + ) + clf1.fit(X1, y1) + clf1._estimators = [ + ("x1", clf1.named_estimators_["g1"]), + ("x2", clf1.named_estimators_["g1"]), + ] + assert clf1.estimatorNBs[0][0] == "x1" + assert clf1.estimatorNBs[0][1] is clf1.named_estimators_["g1"] + assert clf1.estimatorNBs[1][0] == "x2" + assert clf1.estimatorNBs[1][1] is clf1.named_estimators_["g1"] + def test_cwnb_prior(): # prior spec: error when negative, sum!=1 or bad length @@ -1229,16 +1259,22 @@ def test_cwnb_prior(): clf1.fit(X, y) # prior spec: specified prior equals calculated and subestimators' priors + # prior spec: str prior ties subestimators' clf1 = ColumnwiseNB( estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], priors=np.array([0.5, 0.5]), ) clf2 = ColumnwiseNB( + estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + priors="g1", + ) + clf3 = ColumnwiseNB( estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] ) clf1.fit(X, y) clf2.fit(X, y) - assert clf2.priors is None + clf3.fit(X, y) + assert clf3.priors is None assert_array_almost_equal( clf1.class_prior_, clf1.named_estimators_["g1"].class_prior_, 8 ) @@ -1246,6 +1282,8 @@ def test_cwnb_prior(): clf1.class_prior_, clf1.named_estimators_["g2"].class_prior_, 8 ) 
assert_array_almost_equal(clf1.class_prior_, clf2.class_prior_, 8) + assert_array_almost_equal(clf1.class_prior_, clf3.class_prior_, 8) + assert_array_equal(clf1.class_prior_, clf1.named_estimators_["g1"].class_prior_) def test_cwnb_zero_prior(): @@ -1496,3 +1534,19 @@ def test_cwnb_example(): ) clf.fit(X, y) clf.predict(X) + + +def test_cwnb_verbose(capsys): + # Setting verbose=True does not result in an error. + # This DOES NOT test if the desired output is generated. + clf = ColumnwiseNB( + estimatorNBs=[ + ("mnb1", MultinomialNB(), [0, 1]), + ("mnb2", MultinomialNB(), [3, 4]), + ("gnb1", GaussianNB(), [5]), + ], + verbose=True, + n_jobs=4, + ) + clf.fit(X2, y2) + clf.predict(X2) From de1123bf233adac813ac7cba09b80a0f9671225d Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 26 Feb 2022 23:39:12 -0500 Subject: [PATCH 017/102] Add DOC entry and corrections to DOCSTRING --- doc/modules/classes.rst | 1 + sklearn/naive_bayes.py | 24 +++++++++++++----------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index b7000bcf7cbb2..54b379486eda2 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -1307,6 +1307,7 @@ Model validation naive_bayes.ComplementNB naive_bayes.GaussianNB naive_bayes.MultinomialNB + naive_bayes.ColumnwiseNB .. _neighbors_ref: diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 82e8173977273..e2b5468f1e0fc 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1479,10 +1479,12 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): estimators to be combined into a single naive Bayes meta-estimator. name : str - Name of the naive Bayes estimator. Like in Pipeline, FeatureUnion, - and ColumnTransformer, this allows the subestimator and its - parameters to be set using ``set_params`` and searched in grid - search. + Name of the naive Bayes estimator. 
Like in + :class:`~sklearn.pipeline.Pipeline`, + :class:`~sklearn.pipeline.FeatureUnion`, + and :class:`~sklearn.compose.ColumnTransformer`, this allows the + subestimator and its parameters to be set using :term:`set_params` + and searched in grid search. estimatorNB : estimator The estimator must support :term:`fit` or :term:`partial_fit`, depending on how the meta-estimator is fitted. In addition, the @@ -1499,7 +1501,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): otherwise a 2d array will be passed to the transformer. A callable is passed the input data `X` and can return any of the above. To select multiple columns by name or dtype, you can use - :obj:`make_column_selector`. + :obj:`~sklearn.compose.make_column_selector`. priors : array-like of shape (n_classes,) or str, default=None Prior probabilities of classes. If unspecified, the priors are @@ -1520,11 +1522,11 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): ---------- estimators_ : list of tuples List of ``(name, fitted_estimatorNB, columns)`` tuples, which follow - the order of `estimators`. ``fitted_estimatorNB`` is a fitted naive + the order of `estimatorNBs`. Here, ``fitted_estimatorNB`` is a fitted naive Bayes estimator, except when ``columns`` presents an empty selection of - columns, in which case it is the original unfitted ``estimatorNB``. - Here ``columns`` is converted to a list of column indices, if the - original specification in `estimators` was a callable. + columns, in which case it is the original unfitted ``estimatorNB``. If + the original specification of ``columns`` in ``estimatorNBs`` was a + callable, then ``columns`` is converted to a list of column indices. named_estimators_ : :class:`~sklearn.utils.Bunch` Read-only attribute to access any subestimator by given name. @@ -1559,7 +1561,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): ComplementNB : Complement Naive Bayes classifier. MultinomialNB : Naive Bayes classifier for multinomial models. 
GaussianNB : Gaussian Naive Bayes. - ColumnTransformer : Applies transformers to columns. + :class:`~sklearn.compose.ColumnTransformer` : Applies transformers to columns. Notes ----- @@ -1600,7 +1602,7 @@ def _log_message(self, name, idx, total): return None return "(%d of %d) Processing %s" % (idx, total, name) - def __init__(self, estimatorNBs, priors=None, n_jobs=None, verbose=False): + def __init__(self, estimatorNBs, *, priors=None, n_jobs=None, verbose=False): self.estimatorNBs = estimatorNBs self.priors = priors self.n_jobs = n_jobs From e515b6a6c888abaab639c773be6f9d43584bb149 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sun, 27 Feb 2022 03:46:47 -0500 Subject: [PATCH 018/102] ColumnwiseNB: DOCSTRING correction --- sklearn/naive_bayes.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index e2b5468f1e0fc..0fcb1f31cee6f 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1550,7 +1550,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): classes_ : ndarray of shape (n_classes,) Class labels known to the classifier. - feature_names_in_ : ndarray of shape (n_features_in_,) + feature_names_in_ : ndarray of shape (`n_features_in_`,) Names of features seen during :term:`fit`. Only defined if `X` has feature names that are all strings. @@ -1567,12 +1567,16 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): ----- ColumnwiseNB combines multiple naive Bayes estimators by expressing the overall joint probability ``P(x,y)`` through ``P(x_i,y)``, the joint - probabilities of the subestimators: + probabilities of the subestimators:: + ``Log P(x,y) = Log P(x_1,y) + ... + Log P(x_N,y) - (N - 1) Log P(y)``, + where ``N`` denotes ``n_estimators``, the number of estimators. 
It is implicitly assumed that the class log priors are finite and agree - between the estimators and the subestimator: + between the estimators and the subestimator:: + ``- inf < Log P(y) = Log P(y|1) = ... = Log P(y|N)``. + The meta-estimators does not check if this condition holds. Meaningless results, including ``NaN``, may be produced by ColumnwiseNB if the class priors differ or contain a zero probability. From eba339fcd8f5376cd0d95a1b1f6f11b18ba8682d Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Thu, 31 Mar 2022 20:04:45 -0400 Subject: [PATCH 019/102] Replace ColumnwiseNB exception from init test in VALIDATE_ESTIMATOR_INIT --- sklearn/tests/test_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 4ff94b11793d2..d3d52433bbff9 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -415,6 +415,7 @@ def test_transformers_get_feature_names_out(transformer): VALIDATE_ESTIMATOR_INIT = [ "SGDOneClassSVM", "TheilSenRegressor", + "ColumnwiseNB", ] VALIDATE_ESTIMATOR_INIT = set(VALIDATE_ESTIMATOR_INIT) From 0dbb07ffc438d7a49d880bd4ef298b3c7abfa712 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Thu, 31 Mar 2022 20:33:48 -0400 Subject: [PATCH 020/102] Reformatting to comply with black=22.3.0 --- sklearn/naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index ed15411b19279..6763e6361f868 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1745,7 +1745,7 @@ def _iter(self, *, fitted=False, replace_strings=False): parameter management: get_params_, set_params_, _estimators. 
""" if fitted: - for (name, estimator, cols) in self.estimators_: + for name, estimator, cols in self.estimators_: if replace_strings and _is_empty_column_selection(cols): continue else: From 5add60bed1aa9b62d45598a2ae2bc2365df1572c Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Thu, 31 Mar 2022 21:30:14 -0400 Subject: [PATCH 021/102] Fixing the init and set_params test. Cf. #22537 --- sklearn/naive_bayes.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 6763e6361f868..05c37f8b6a5cb 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1956,15 +1956,22 @@ def _estimators(self): which expects lists of tuples of len 2. """ # Implemented in the image and likeness of ColumnTranformer._transformers - return [(name, e) for name, e, _ in self.estimatorNBs] + try: + return [(name, e) for name, e, _ in self.estimatorNBs] + except (TypeError, ValueError): # to pass init test in test_common.py + return self.estimatorNBs @_estimators.setter def _estimators(self, value): # Implemented in the image and likeness of ColumnTranformer._transformers # TODO: Is renaming or changing the order legal? Swap `name` and `_`? - self.estimatorNBs = [ - (name, e, col) for ((name, e), (_, _, col)) in zip(value, self.estimatorNBs) - ] + try: + self.estimatorNBs = [ + (name, e, col) + for ((name, e), (_, _, col)) in zip(value, self.estimatorNBs) + ] + except (TypeError, ValueError): # to pass init test in test_common.py + self.estimatorNBs = value def get_params(self, deep=True): """Get parameters for this estimator. 
From f0fd8b3ef934451ffe637d42fb62af5d9389a304 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Thu, 31 Mar 2022 23:04:19 -0400 Subject: [PATCH 022/102] ColumnwiseNB: rename estimatorsNBs to nb_estimators --- sklearn/naive_bayes.py | 74 +++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 05c37f8b6a5cb..260bff802a0b9 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1475,8 +1475,8 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Parameters ---------- - estimatorNBs : list of tuples - List of (name, estimatorNB, columns) tuples specifying the naive Bayes + nb_estimators : list of tuples + List of (name, nb_estimator, columns) tuples specifying the naive Bayes estimators to be combined into a single naive Bayes meta-estimator. name : str @@ -1486,7 +1486,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): and :class:`~sklearn.compose.ColumnTransformer`, this allows the subestimator and its parameters to be set using :term:`set_params` and searched in grid search. - estimatorNB : estimator + nb_estimator : estimator The estimator must support :term:`fit` or :term:`partial_fit`, depending on how the meta-estimator is fitted. In addition, the estimator must support ``_joint_log_likelihood`` method, which @@ -1498,7 +1498,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Indexes the data on its second axis. Integers are interpreted as positional columns, while strings can reference DataFrame columns by name. A scalar string or int should be used where - ``estimatorNB`` expects X to be a 1d array-like (vector), + ``nb_estimator`` expects X to be a 1d array-like (vector), otherwise a 2d array will be passed to the transformer. A callable is passed the input data `X` and can return any of the above. 
To select multiple columns by name or dtype, you can use @@ -1522,11 +1522,11 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Attributes ---------- estimators_ : list of tuples - List of ``(name, fitted_estimatorNB, columns)`` tuples, which follow - the order of `estimatorNBs`. Here, ``fitted_estimatorNB`` is a fitted naive + List of ``(name, fitted_nb_estimator, columns)`` tuples, which follow + the order of `nb_estimators`. Here, ``fitted_nb_estimator`` is a fitted naive Bayes estimator, except when ``columns`` presents an empty selection of - columns, in which case it is the original unfitted ``estimatorNB``. If - the original specification of ``columns`` in ``estimatorNBs`` was a + columns, in which case it is the original unfitted ``nb_estimator``. If + the original specification of ``columns`` in ``nb_estimators`` was a callable, then ``columns`` is converted to a list of column indices. named_estimators_ : :class:`~sklearn.utils.Bunch` @@ -1589,26 +1589,26 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): >>> X = rng.randint(5, size=(6, 100)) >>> y = np.array([0, 0, 1, 1, 2, 2]) >>> from sklearn.naive_bayes import MultinomialNB, GaussianNB, ColumnwiseNB - >>> clf = ColumnwiseNB(estimatorNBs=[('mnb1', MultinomialNB(), [0, 1]), + >>> clf = ColumnwiseNB(nb_estimators=[('mnb1', MultinomialNB(), [0, 1]), ... ('mnb2', MultinomialNB(), [3, 4]), ... 
('gnb1', GaussianNB(), [5])]) >>> clf.fit(X, y) - ColumnwiseNB(estimatorNBs=[('mnb1', MultinomialNB(), [0, 1]), + ColumnwiseNB(nb_estimators=[('mnb1', MultinomialNB(), [0, 1]), ('mnb2', MultinomialNB(), [3, 4]), ('gnb1', GaussianNB(), [5])]) >>> print(clf.predict(X)) [0 0 1 0 2 2] """ - _required_parameters = ["estimatorNBs"] + _required_parameters = ["nb_estimators"] def _log_message(self, name, idx, total): if not self.verbose: return None return "(%d of %d) Processing %s" % (idx, total, name) - def __init__(self, estimatorNBs, *, priors=None, n_jobs=None, verbose=False): - self.estimatorNBs = estimatorNBs + def __init__(self, nb_estimators, *, priors=None, n_jobs=None, verbose=False): + self.nb_estimators = nb_estimators self.priors = priors self.n_jobs = n_jobs self.verbose = verbose @@ -1627,13 +1627,13 @@ def _check_X(self, X): def _joint_log_likelihood(self, X): """Calculate the meta-estimator's joint log likelihood ``log P(x,c)``.""" # Because data must follow the same path as it would in subestimators, - # _jll_one(estimatorNB, X) passes it through estimatorNB._check_X to + # _jll_one(nb_estimator, X) passes it through nb_estimator._check_X to # match the implementation of _BaseNB.predict_log_proba. # Changes therein must be reflected in _jll_one or here. 
estimators = self._iter(fitted=True, replace_strings=True) all_jlls = Parallel(n_jobs=self.n_jobs)( - delayed(_jll_one)(estimator=estimatorNB, X=_safe_indexing(X, cols, axis=1)) - for (_, estimatorNB, cols) in estimators + delayed(_jll_one)(estimator=nb_estimator, X=_safe_indexing(X, cols, axis=1)) + for (_, nb_estimator, cols) in estimators ) n_estimators = len(all_jlls) log_prior = np.log(self.class_prior_) @@ -1642,12 +1642,12 @@ def _joint_log_likelihood(self, X): def _validate_estimators(self, check_partial=False): # Check if estimators have fit/partial_fit and jll methods # Validate estimator names via _BaseComposition._validate_names(self, names) - if not self.estimatorNBs: + if not self.nb_estimators: raise ValueError( "A list of naive Bayes estimators must be provided " - "in the form [(name, estimatorNB, columns), ... ]." + "in the form [(name, nb_estimator, columns), ... ]." ) - names, estimators, _ = zip(*self.estimatorNBs) + names, estimators, _ = zip(*self.nb_estimators) for e in estimators: if (not check_partial) and ( not (hasattr(e, "fit") and hasattr(e, "_joint_log_likelihood")) @@ -1679,7 +1679,7 @@ def _validate_column_callables(self, X): # ColumnTransformer code. all_columns = [] estimator_to_input_indices = {} - for name, _, columns in self.estimatorNBs: + for name, _, columns in self.nb_estimators: if callable(columns): columns = columns(X) all_columns.append(columns) @@ -1700,7 +1700,7 @@ def named_estimators_(self): return Bunch(**{name: e for name, e, _ in self.estimators_}) def _iter(self, *, fitted=False, replace_strings=False): - """Generate ``(name, estimatorNB, columns)`` tuples. + """Generate ``(name, nb_estimator, columns)`` tuples. This is a private method, similar to ColumnTransformer._iter. Must not be called before _validate_column_callables. @@ -1722,7 +1722,7 @@ def _iter(self, *, fitted=False, replace_strings=False): Yields ------ tuple - of the form ``(name, estimatorNB, columns)``. 
+ of the form ``(name, nb_estimator, columns)``. Notes ----- @@ -1751,7 +1751,7 @@ def _iter(self, *, fitted=False, replace_strings=False): else: yield (name, estimator, cols) else: # fitted=False - for (name, estimator, _), cols in zip(self.estimatorNBs, self._columns): + for (name, estimator, _), cols in zip(self.nb_estimators, self._columns): if replace_strings and _is_empty_column_selection(cols): continue else: @@ -1795,18 +1795,18 @@ def _update_fitted_estimators(self, fitted_estimators): estimators_ = [] fitted_estimators = iter(fitted_estimators) - for name, estimatorNB, cols in self._iter(): + for name, nb_estimator, cols in self._iter(): if not _is_empty_column_selection(cols): - updated_estimatorNB = next(fitted_estimators) + updated_nb_estimator = next(fitted_estimators) else: # don't advance fitted_estimators; use original - updated_estimatorNB = estimatorNB - estimators_.append((name, updated_estimatorNB, cols)) + updated_nb_estimator = nb_estimator + estimators_.append((name, updated_nb_estimator, cols)) self.estimators_ = estimators_ def fit(self, X, y, sample_weight=None): """Fit the naive Bayes meta-estimator. - Calls `fit` of each subestimator ``estimatorNB``. Only a corresponding + Calls `fit` of each subestimator ``nb_estimator``. Only a corresponding subset of columns of `X` is passed to each subestimator; `sample_weight` and `y` are passed to the subestimators as they are. 
@@ -1854,14 +1854,14 @@ def fit(self, X, y, sample_weight=None): estimators = list(self._iter(fitted=False, replace_strings=True)) fitted_estimators = Parallel(n_jobs=self.n_jobs)( delayed(_fit_one)( - estimator=clone(estimatorNB), + estimator=clone(nb_estimator), X=_safe_indexing(X, cols, axis=1), y=y, message_clsname="ColumnwiseNB", message=self._log_message(name, idx, len(estimators)), sample_weight=sample_weight, ) - for idx, (name, estimatorNB, cols) in enumerate(estimators, 1) + for idx, (name, nb_estimator, cols) in enumerate(estimators, 1) ) self._update_fitted_estimators(fitted_estimators) self._update_class_prior() @@ -1934,7 +1934,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): estimators = list(self._iter(fitted=not first_call, replace_strings=True)) fitted_estimators = Parallel(n_jobs=self.n_jobs)( delayed(_partial_fit_one)( - estimator=clone(estimatorNB) if first_call else estimatorNB, + estimator=clone(nb_estimator) if first_call else nb_estimator, X=_safe_indexing(X, cols, axis=1), y=y, message_clsname="ColumnwiseNB", @@ -1942,7 +1942,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): classes=classes, sample_weight=sample_weight, ) - for idx, (name, estimatorNB, cols) in enumerate(estimators, 1) + for idx, (name, nb_estimator, cols) in enumerate(estimators, 1) ) self._update_fitted_estimators(fitted_estimators) self._update_class_prior() @@ -1957,21 +1957,21 @@ def _estimators(self): """ # Implemented in the image and likeness of ColumnTranformer._transformers try: - return [(name, e) for name, e, _ in self.estimatorNBs] + return [(name, e) for name, e, _ in self.nb_estimators] except (TypeError, ValueError): # to pass init test in test_common.py - return self.estimatorNBs + return self.nb_estimators @_estimators.setter def _estimators(self, value): # Implemented in the image and likeness of ColumnTranformer._transformers # TODO: Is renaming or changing the order legal? Swap `name` and `_`? 
try: - self.estimatorNBs = [ + self.nb_estimators = [ (name, e, col) - for ((name, e), (_, _, col)) in zip(value, self.estimatorNBs) + for ((name, e), (_, _, col)) in zip(value, self.nb_estimators) ] except (TypeError, ValueError): # to pass init test in test_common.py - self.estimatorNBs = value + self.nb_estimators = value def get_params(self, deep=True): """Get parameters for this estimator. From 1c884f39cd7d19cb4281672eecb543a5b1563ff0 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Fri, 1 Apr 2022 00:39:22 -0400 Subject: [PATCH 023/102] tests for ColumnwiseNB: rename estimatorNBs to nb_estimators --- sklearn/tests/test_naive_bayes.py | 120 +++++++++++++++--------------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index e79ffdfd29f3a..8d3e1d6309da7 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -954,7 +954,7 @@ def test_n_features_deprecation(Estimator): def test_cwnb_union(): # A union of GaussianNB's yields the same prediction a single GaussianNB (fit) clf1 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [0]), ("g2", GaussianNB(), [1])] + nb_estimators=[("g1", GaussianNB(), [0]), ("g2", GaussianNB(), [1])] ) clf2 = GaussianNB() clf1.fit(X, y) @@ -965,7 +965,7 @@ def test_cwnb_union(): # A union of BernoulliNB's yields the same prediction a single BernoulliNB (fit) clf1 = ColumnwiseNB( - estimatorNBs=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] + nb_estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] ) clf2 = BernoulliNB() clf1.fit(X1, y1) @@ -977,7 +977,7 @@ def test_cwnb_union(): # A union of BernoulliNB's yields the same prediction a single BernoulliNB # (partial_fit) clf1 = ColumnwiseNB( - estimatorNBs=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] + nb_estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] ) clf2 =
BernoulliNB() clf1.partial_fit(X1[:5], y1[:5], classes=[0, 1]) @@ -989,7 +989,7 @@ def test_cwnb_union(): # A union of several different NB's is permutation-invariant clf1 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("b1", BernoulliNB(binarize=2), [3]), ("g1", GaussianNB(), [0]), ("m1", MultinomialNB(), [0, 2]), @@ -998,7 +998,7 @@ def test_cwnb_union(): ) # permute (0, 1, 2, 3, 4) -> (1, 2, 0, 3, 4) both estimator specs and column numbers clf2 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("b1", BernoulliNB(binarize=2), [3]), ("g1", GaussianNB(), [1]), ("m1", MultinomialNB(), [1, 0]), @@ -1023,10 +1023,10 @@ def test_cwnb_estimators_pandas(): # Subestimators spec: cols can be lists of int or lists of str, if DataFrame clf1 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] ) clf2 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("g1", GaussianNB(), ["col1"]), ("g2", GaussianNB(), ["col0", "col1"]), ] @@ -1045,14 +1045,14 @@ def test_cwnb_estimators_pandas(): # when callable columns produce the empty set. 
select_none = make_column_selector(pattern="qwerasdf") clf1 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("g1", GaussianNB(), [1]), ("g2", GaussianNB(), select_none), ("g3", GaussianNB(), [0, 1]), ] ) clf2 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] + nb_estimators=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] ) clf1.fit(Xdf, y) clf2.fit(Xdf, y) @@ -1060,10 +1060,10 @@ def test_cwnb_estimators_pandas(): clf1.predict_log_proba(Xdf), clf2.predict_log_proba(Xdf), 8 ) # Empty-columns estimators are passed to estimators_ and the numbers match - assert len(clf1.estimatorNBs) == len(clf1.estimators_) == 3 - assert len(clf2.estimatorNBs) == len(clf2.estimators_) == 2 + assert len(clf1.nb_estimators) == len(clf1.estimators_) == 3 + assert len(clf2.nb_estimators) == len(clf2.estimators_) == 2 # No cloning of the empty-columns estimators took place: - assert id(clf1.estimatorNBs[1][1]) == id(clf1.named_estimators_["g2"]) + assert id(clf1.nb_estimators[1][1]) == id(clf1.named_estimators_["g2"]) # Subestimators spec: test callable columns select_int = make_column_selector(dtype_include=np.int_) @@ -1071,13 +1071,13 @@ def test_cwnb_estimators_pandas(): Xdf2 = Xdf Xdf2["col3"] = np.exp(Xdf["col0"]) - 0.5 * Xdf["col1"] clf1 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("g1", GaussianNB(), ["col3"]), ("m1", BernoulliNB(), ["col0", "col1"]), ] ) clf2 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("g1", GaussianNB(), select_float), ("g2", BernoulliNB(), select_int), ] @@ -1092,10 +1092,10 @@ def test_cwnb_estimators_pandas(): def test_cwnb_estimators_1(): # Subestimators spec: repeated col ints have the same effect as repeating data clf1 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1, 1]), ("b1", BernoulliNB(), [0, 0, 1, 1])] + nb_estimators=[("g1", GaussianNB(), [1, 1]), ("b1", BernoulliNB(), [0, 0, 1, 1])] ) clf2 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [0, 1]), ("b1", BernoulliNB(), [2, 
3, 4, 5])] + nb_estimators=[("g1", GaussianNB(), [0, 1]), ("b1", BernoulliNB(), [2, 3, 4, 5])] ) clf1.fit(X1, y1) clf2.fit(X1[:, [1, 1, 0, 0, 1, 1]], y1) @@ -1105,34 +1105,34 @@ def test_cwnb_estimators_1(): # Subestimators spec: empty cols have the same effect as an absent estimator clf1 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("g1", GaussianNB(), [1]), ("g2", GaussianNB(), []), ("g3", GaussianNB(), [0, 1]), ] ) clf2 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] + nb_estimators=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] ) clf1.fit(X1, y1) clf2.fit(X1, y1) assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) # Empty-columns estimators are passed to estimators_ and the numbers match - assert len(clf1.estimatorNBs) == len(clf1.estimators_) == 3 - assert len(clf2.estimatorNBs) == len(clf2.estimators_) == 2 + assert len(clf1.nb_estimators) == len(clf1.estimators_) == 3 + assert len(clf2.nb_estimators) == len(clf2.estimators_) == 2 # No cloning of the empty-columns estimators took place: - assert id(clf1.estimatorNBs[1][1]) == id(clf1.named_estimators_["g2"]) + assert id(clf1.nb_estimators[1][1]) == id(clf1.named_estimators_["g2"]) # Subestimators spec: error on repeated names clf1 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1]), ("g1", GaussianNB(), [0, 1])] + nb_estimators=[("g1", GaussianNB(), [1]), ("g1", GaussianNB(), [0, 1])] ) msg = "Names provided are not unique" with pytest.raises(ValueError, match=msg): clf1.fit(X, y) clf1 = ColumnwiseNB( - estimatorNBs=[["g1", GaussianNB(), [1]], ["g2", GaussianNB(), [0, 1]]] + nb_estimators=[["g1", GaussianNB(), [1]], ["g2", GaussianNB(), [0, 1]]] ) clf1.fit(X, y) @@ -1140,7 +1140,7 @@ def test_cwnb_estimators_1(): def test_cwnb_estimators_2(): # Subestimators spec: error on empty list clf = ColumnwiseNB( - estimatorNBs=[], + nb_estimators=[], ) msg = "A list of naive Bayes estimators must be provided*" with 
pytest.raises(ValueError, match=msg): @@ -1148,7 +1148,7 @@ def test_cwnb_estimators_2(): # Subestimators spec: error on None clf = ColumnwiseNB( - estimatorNBs=None, + nb_estimators=None, ) msg = "A list of naive Bayes estimators must be provided*" with pytest.raises(ValueError, match=msg): @@ -1169,7 +1169,7 @@ def partial_fit(self, X, y): def predict(self, X): pass - clf1 = ColumnwiseNB(estimatorNBs=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + clf1 = ColumnwiseNB(nb_estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "Estimators must be .aive Bayes estimators implementing *" with pytest.raises(TypeError, match=msg): clf1.partial_fit(X, y) @@ -1189,7 +1189,7 @@ def _joint_log_likelihood(self, X): def predict(self, X): pass - clf1 = ColumnwiseNB(estimatorNBs=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + clf1 = ColumnwiseNB(nb_estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "Estimators must be .aive Bayes estimators implementing *" with pytest.raises(TypeError, match=msg): clf1.fit(X, y) @@ -1209,30 +1209,30 @@ def _joint_log_likelihood(self, X): def predict(self, X): pass - clf1 = ColumnwiseNB(estimatorNBs=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + clf1 = ColumnwiseNB(nb_estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "Estimators must be .aive Bayes estimators implementing *" with pytest.raises(TypeError, match=msg): clf1.partial_fit(X, y) # _estimators setter works clf1 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [0]), ("b1", BernoulliNB(), [1])] + nb_estimators=[("g1", GaussianNB(), [0]), ("b1", BernoulliNB(), [1])] ) clf1.fit(X1, y1) clf1._estimators = [ ("x1", clf1.named_estimators_["g1"]), ("x2", clf1.named_estimators_["g1"]), ] - assert clf1.estimatorNBs[0][0] == "x1" - assert clf1.estimatorNBs[0][1] is clf1.named_estimators_["g1"] - assert clf1.estimatorNBs[1][0] == "x2" - assert clf1.estimatorNBs[1][1] is clf1.named_estimators_["g1"] + assert clf1.nb_estimators[0][0] 
== "x1" + assert clf1.nb_estimators[0][1] is clf1.named_estimators_["g1"] + assert clf1.nb_estimators[1][0] == "x2" + assert clf1.nb_estimators[1][1] is clf1.named_estimators_["g1"] def test_cwnb_prior(): # prior spec: error when negative, sum!=1 or bad length clf1 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], priors=np.array([-0.25, 1.25]), ) msg = "Priors must be non-negative." @@ -1240,7 +1240,7 @@ def test_cwnb_prior(): clf1.fit(X, y) clf1 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], priors=np.array([0.25, 0.7]), ) msg = "The sum of the priors should be 1." @@ -1248,7 +1248,7 @@ def test_cwnb_prior(): clf1.fit(X, y) clf1 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], priors=np.array([0.25, 0.25, 0.25, 0.25]), ) msg = "Number of priors must match number of classes." 
@@ -1258,15 +1258,15 @@ def test_cwnb_prior(): # prior spec: specified prior equals calculated and subestimators' priors # prior spec: str prior ties subestimators' clf1 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], priors=np.array([0.5, 0.5]), ) clf2 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], priors="g1", ) clf3 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] ) clf1.fit(X, y) clf2.fit(X, y) @@ -1286,7 +1286,7 @@ def test_cwnb_prior(): def test_cwnb_zero_prior(): # P(y)=0 in a subestimator results in P(y|x)=0 of meta-estimator clf1 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("g1", GaussianNB(), [1, 3, 5]), ("g2", GaussianNB(priors=np.array([0.5, 0, 0.5])), [0, 1]), ] @@ -1309,7 +1309,7 @@ def test_cwnb_zero_prior(): # P(y)=0 in two subestimators results in P(y|x)=0 of meta-estimator clf1 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("g1", GaussianNB(priors=np.array([0.6, 0, 0.4])), [1, 3, 5]), ("g2", GaussianNB(priors=np.array([0.5, 0, 0.5])), [0, 1]), ] @@ -1330,10 +1330,10 @@ def test_cwnb_sample_weight(): # weights in fit have no effect if all ones weights = [1, 1, 1, 1, 1, 1] clf1 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] ) clf2 = ColumnwiseNB( - estimatorNBs=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] ) clf1.fit(X, y, sample_weight=weights) clf2.fit(X, y) @@ -1345,13 +1345,13 @@ def test_cwnb_sample_weight(): # weights in partial_fit have no effect if all ones clf1 = ColumnwiseNB( - 
estimatorNBs=[ + nb_estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] ) clf2 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] @@ -1369,13 +1369,13 @@ def test_cwnb_sample_weight(): idx = list(chain(*([i] * w for i, w in enumerate(weights)))) # var_smoothing=0.0 is for maximum precision in dealing with a small sample clf1 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("g1", GaussianNB(var_smoothing=0.0), [1]), ("g2", GaussianNB(var_smoothing=0.0), [0, 1]), ] ) clf2 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("g1", GaussianNB(var_smoothing=0.0), [1]), ("g2", GaussianNB(var_smoothing=0.0), [0, 1]), ] @@ -1392,13 +1392,13 @@ def test_cwnb_sample_weight(): # weights in partial_fit have the same effect as repeating data clf1 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] ) clf2 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] @@ -1415,10 +1415,10 @@ def test_cwnb_sample_weight(): def test_cwnb_partial_fit(): # partial_fit: consecutive calls yield the same prediction as a single call clf1 = ColumnwiseNB( - estimatorNBs=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + nb_estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] ) clf2 = ColumnwiseNB( - estimatorNBs=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + nb_estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] ) clf1.partial_fit(X2, y2, classes=np.unique(y2)) clf2.partial_fit(X2[:4], y2[:4], classes=np.unique(y2)) @@ -1433,7 +1433,7 @@ def test_cwnb_partial_fit(): # partial_fit: error when classes are not provided at the first call clf1 = ColumnwiseNB( - estimatorNBs=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + nb_estimators=[("b1", 
BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] ) msg = ".lasses must be passed on the first call to partial_fit" with pytest.raises(ValueError, match=msg): @@ -1443,7 +1443,7 @@ def test_cwnb_partial_fit(): def test_cwnb_consistency(): # class_count_, classes_, class_prior_ are consistent in meta-, sub-estimators clf1 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] @@ -1463,7 +1463,7 @@ def test_cwnb_params(): # Can get and set subestimators' parameters through name__paramname # clone() works on ColumnwiseNB clf1 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("b1", BernoulliNB(alpha=0.2, binarize=2), [1]), ("m1", MultinomialNB(class_prior=[0.2, 0.2, 0.6]), [0, 2, 3]), ] @@ -1474,8 +1474,8 @@ def test_cwnb_params(): assert p["b1__binarize"] == 2 assert p["m1__class_prior"] == [0.2, 0.2, 0.6] clf1.set_params(b1__alpha=123, m1__class_prior=[0.3, 0.3, 0.4]) - assert clf1.estimatorNBs[0][1].alpha == 123 - assert_array_equal(clf1.estimatorNBs[1][1].class_prior, [0.3, 0.3, 0.4]) + assert clf1.nb_estimators[0][1].alpha == 123 + assert_array_equal(clf1.nb_estimators[1][1].class_prior, [0.3, 0.3, 0.4]) # After cloning and fitting, we can check through named_estimators, which # maps to fitted estimators_: clf2 = clone(clf1).fit(X2, y2) @@ -1487,7 +1487,7 @@ def test_cwnb_params(): def test_cwnb_n_jobs(): # n_jobs: same results wether with it or without clf1 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("b2", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), @@ -1496,7 +1496,7 @@ def test_cwnb_n_jobs(): n_jobs=4, ) clf2 = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("b2", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), @@ -1523,7 +1523,7 @@ def test_cwnb_example(): from sklearn.naive_bayes import MultinomialNB, GaussianNB, ColumnwiseNB clf = ColumnwiseNB( - 
estimatorNBs=[ + nb_estimators=[ ("mnb1", MultinomialNB(), [0, 1]), ("mnb2", MultinomialNB(), [3, 4]), ("gnb1", GaussianNB(), [5]), @@ -1537,7 +1537,7 @@ def test_cwnb_verbose(capsys): # Setting verbose=True does not result in an error. # This DOES NOT test if the desired output is generated. clf = ColumnwiseNB( - estimatorNBs=[ + nb_estimators=[ ("mnb1", MultinomialNB(), [0, 1]), ("mnb2", MultinomialNB(), [3, 4]), ("gnb1", GaussianNB(), [5]), From 0dab95b805b0dcc9a09effda673e7efec3bd3dcb Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Fri, 1 Apr 2022 00:42:40 -0400 Subject: [PATCH 024/102] flake8 fix in test_naive_bayes.py --- sklearn/tests/test_naive_bayes.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 8d3e1d6309da7..0fb984ebc0c4e 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1092,10 +1092,12 @@ def test_cwnb_estimators_pandas(): def test_cwnb_estimators_1(): # Subestimators spec: repeated col ints have the same effect as repeating data clf1 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1, 1]), ("b1", BernoulliNB(), [0, 0, 1, 1])] + nb_estimators=[("g1", GaussianNB(), [1, 1]), + ("b1", BernoulliNB(), [0, 0, 1, 1])] ) clf2 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [0, 1]), ("b1", BernoulliNB(), [2, 3, 4, 5])] + nb_estimators=[("g1", GaussianNB(), [0, 1]), + ("b1", BernoulliNB(), [2, 3, 4, 5])] ) clf1.fit(X1, y1) clf2.fit(X1[:, [1, 1, 0, 0, 1, 1]], y1) From 1bd1059a4584774d2d741c1413ca9c5b6413de40 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Fri, 1 Apr 2022 01:32:42 -0400 Subject: [PATCH 025/102] black fix in test_naive_bayes.py --- sklearn/tests/test_naive_bayes.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index
0fb984ebc0c4e..ef30c9cc4a810 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1092,12 +1092,16 @@ def test_cwnb_estimators_pandas(): def test_cwnb_estimators_1(): # Subestimators spec: repeated col ints have the same effect as repeating data clf1 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1, 1]), - ("b1", BernoulliNB(), [0, 0, 1, 1])] + nb_estimators=[ + ("g1", GaussianNB(), [1, 1]), + ("b1", BernoulliNB(), [0, 0, 1, 1]), + ] ) clf2 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [0, 1]), - ("b1", BernoulliNB(), [2, 3, 4, 5])] + nb_estimators=[ + ("g1", GaussianNB(), [0, 1]), + ("b1", BernoulliNB(), [2, 3, 4, 5]), + ] ) clf1.fit(X1, y1) clf2.fit(X1[:, [1, 1, 0, 0, 1, 1]], y1) From fcf032c0fc5f4dd25199176b71c8c59c9a4db295 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 9 Apr 2022 00:09:53 -0400 Subject: [PATCH 026/102] Added example: ColumnwiseNB for titanic dataset --- .../plot_combining_naive_bayes.py | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 examples/miscellaneous/plot_combining_naive_bayes.py diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py new file mode 100644 index 0000000000000..b2412fb513b1d --- /dev/null +++ b/examples/miscellaneous/plot_combining_naive_bayes.py @@ -0,0 +1,124 @@ +""" +=================================================== +Combining Naive Bayes Estimators using ColumnwiseNB +=================================================== + +.. currentmodule:: sklearn + +This example shows how to use :class:`~compose.ColumnTransformer` +meta-estimator to construct a naive Bayes model from base naive Bayes +estimators. The resulting model is applied to a dataset with a mixture of +discrete and continuous features. 
+ +We consider the titanic dataset, in which: +- numerical (continous) features "age" and "fare" are handled by +:class:`~naive_bayes.GaussianNB`; +- categorical (discrete) features "embarked", "sex", and "pclass" are handled +by :class:`~naive_bayes.CategoricalNB`. +""" + +# Author: Andrey V. Melnik +# (based on a related work by Pedro Morales ) +# +# License: BSD 3 clause + +# %% +import pandas as pd +from sklearn.datasets import fetch_openml + +X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True) +X["pclass"] = X["pclass"].astype("category") +# Add a category for NaNs to the "embarked" feature: +X["embarked"] = X["embarked"].cat.add_categories("N/A").fillna("N/A") + +# +# Build and use a pipeline around ``ColumnwiseNB`` +# ------------------------------------------------ +# +# %% +from sklearn.compose import ColumnTransformer +from sklearn.pipeline import Pipeline +from sklearn.impute import SimpleImputer +from sklearn.preprocessing import OrdinalEncoder +from sklearn.naive_bayes import GaussianNB, CategoricalNB, ColumnwiseNB +from sklearn.model_selection import train_test_split, GridSearchCV +from sklearn.metrics import accuracy_score + +numeric_features = ["age", "fare"] +numeric_transformer = SimpleImputer(strategy="median") + +categorical_features = ["embarked", "sex", "pclass"] +categories = [X[c].unique().to_list() for c in X[categorical_features]] +categorical_transformer = OrdinalEncoder(categories=categories) + +preprocessor = ColumnTransformer( + transformers=[ + ("num", numeric_transformer, numeric_features), + ("cat", categorical_transformer, categorical_features), + ] +) + +classifier = ColumnwiseNB( + nb_estimators=[ + ("gnb", GaussianNB(), [0, 1]), + ("cnb", CategoricalNB(), [2, 3, 4]), + ] +) + +pipe = Pipeline(steps=[("preprocessor", preprocessor), ("classifier", classifier)]) +pipe +# %% +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0) + +pipe.fit(X_train, y_train) +y_pred = 
pipe.predict(X_test) +print(f"Test accuracy: {accuracy_score(y_test, y_pred)}") + +# +# Compare choices of columns using ``GridSearchCV`` +# -------------------------------------------------- +# +# The allocation of columns to constituent subestimators can be regarded as a hyperparameter. +# We can explore the combinations of columns' choices and values of other hyperparameters +# with the help of :class:`~.model_selection.GridSearchCV`. +# %% +param_grid = { + "classifier__nb_estimators": [ + [("gnb", GaussianNB(), [0, 1]), ("cnb", CategoricalNB(), [2, 3, 4])], + [("gnb", GaussianNB(), []), ("cnb", CategoricalNB(), [3])], + [("gnb", GaussianNB(), [3]), ("cnb", CategoricalNB(), [])], + ], + "preprocessor__num__strategy": ["mean", "most_frequent"], +} + +grid_search = GridSearchCV(pipe, param_grid, cv=10) +grid_search + +# %% +# Calling `fit` triggers the cross-validated search for the best +# hyperparameters combination: +# +grid_search.fit(X_train, y_train) + +print("Best params:") +print(grid_search.best_params_) + +# It turns out, the best results are achieved by the naive Bayes model when "sex" +# is the only feature used: +# %% +cv_results = pd.DataFrame(grid_search.cv_results_) +cv_results = cv_results.sort_values("mean_test_score", ascending=False) +cv_results["Columns dictionary"] = cv_results["param_classifier__nb_estimators"].map( + lambda l: {e[0]: e[-1] for e in l} +) +cv_results["'gnb' columns"] = cv_results["Columns dictionary"].map(lambda d: d["gnb"]) +cv_results["'cnb' columns"] = cv_results["Columns dictionary"].map(lambda d: d["cnb"]) +cv_results[ + [ + "mean_test_score", + "std_test_score", + "param_preprocessor__num__strategy", + "'gnb' columns", + "'cnb' columns", + ] +] From e2e3ccf7d74f4a7642ad2e959c7e26c9b2738817 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 9 Apr 2022 00:31:00 -0400 Subject: [PATCH 027/102] flake8 fix --- examples/miscellaneous/plot_combining_naive_bayes.py | 6 +++--- 1 file 
changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py index b2412fb513b1d..318dbfe4f976a 100644 --- a/examples/miscellaneous/plot_combining_naive_bayes.py +++ b/examples/miscellaneous/plot_combining_naive_bayes.py @@ -78,9 +78,9 @@ # Compare choices of columns using ``GridSearchCV`` # -------------------------------------------------- # -# The allocation of columns to constituent subestimators can be regarded as a hyperparameter. -# We can explore the combinations of columns' choices and values of other hyperparameters -# with the help of :class:`~.model_selection.GridSearchCV`. +# The allocation of columns to constituent subestimators can be regarded as a +# hyperparameter. We can explore the combinations of columns' choices and values +# of other hyperparameters with the help of :class:`~.model_selection.GridSearchCV`. # %% param_grid = { "classifier__nb_estimators": [ From 75ce7af42c0ae891a3881d4395ca8b00743d7fb8 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 9 Apr 2022 17:17:34 -0400 Subject: [PATCH 028/102] ColumnwiseNB: added _check_n_features to fit and partial_fit --- sklearn/naive_bayes.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 260bff802a0b9..f8a61766cb484 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1829,6 +1829,7 @@ def fit(self, X, y, sample_weight=None): # TODO: Consider overriding BaseEstimator._check_feature_names # Currently, when X has all str feature names, all features are # registered in self.feature_names_in no matter if they are used or not. 
+ self._check_n_features(X, reset=True) self._validate_estimators() self._validate_column_callables(X) # Consistency checks for X, y are delegated to subestimators @@ -1900,10 +1901,12 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): first_call = not hasattr(self, "classes_") if first_call: self._check_feature_names(X, reset=True) + self._check_n_features(X, reset=True) self._validate_estimators(check_partial=True) self._validate_column_callables(X) else: self._check_feature_names(X, reset=False) + self._check_n_features(X, reset=False) # Consistency checks for X, y are delegated to subestimators # Subestimators get original sample_weight. This is for class counts: From 2d706bdd4625631596bd844c4b960540895520ee Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 9 Apr 2022 21:31:18 -0400 Subject: [PATCH 029/102] Correction to the example (ColumnwiseNB for titanic dataset) --- examples/miscellaneous/plot_combining_naive_bayes.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py index 318dbfe4f976a..a6b799b86db96 100644 --- a/examples/miscellaneous/plot_combining_naive_bayes.py +++ b/examples/miscellaneous/plot_combining_naive_bayes.py @@ -5,15 +5,16 @@ .. currentmodule:: sklearn -This example shows how to use :class:`~compose.ColumnTransformer` +This example shows how to use :class:`~naive_bayes.ColumnwiseNB` meta-estimator to construct a naive Bayes model from base naive Bayes estimators. The resulting model is applied to a dataset with a mixture of discrete and continuous features. 
We consider the titanic dataset, in which: -- numerical (continous) features "age" and "fare" are handled by + + - numerical (continous) features "age" and "fare" are handled by :class:`~naive_bayes.GaussianNB`; -- categorical (discrete) features "embarked", "sex", and "pclass" are handled + - categorical (discrete) features "embarked", "sex", and "pclass" are handled by :class:`~naive_bayes.CategoricalNB`. """ From 4203f7b80ffe603a994f4d13d86fa65e6240f037 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sun, 10 Apr 2022 04:13:19 -0400 Subject: [PATCH 030/102] Added a section to the naive bayes guide in documentation --- doc/modules/naive_bayes.rst | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/doc/modules/naive_bayes.rst b/doc/modules/naive_bayes.rst index b2dd4cf5a7cd3..96348815c8e55 100644 --- a/doc/modules/naive_bayes.rst +++ b/doc/modules/naive_bayes.rst @@ -281,3 +281,37 @@ For an overview of available strategies in scikit-learn, see also the The ``partial_fit`` method call of naive Bayes models introduces some computational overhead. It is recommended to use data chunk sizes that are as large as possible, that is as the available RAM allows. + +.. _columnwise_naive_bayes: + +Mix and match naive Bayes models +-------------------------------- + +A naive Bayes model that assumes different distribution families for different +features (or subsets of features) can be constructed using :class:`ColumnwiseNB`. +It is a meta-estimator, whose operation relies on naive Bayes +sub-estimators, which can be chosen in any number of combination from +:class:`GaussianNB`, :class:`MultinomialNB`, :class:`ComplementNB`, +:class:`BernoulliNB`, :class:`CategoricalNB`, and user-defined models +(provided they expose the necessary methods). + +To initialize :class:`ColumnwiseNB`, one must pass a list of tuples specifying +sub-estimators and their respective column subsets. 
+Each sub-estimator is fitted and evaluated independently of +others and "sees" only the features assigned to it. The estimators' predictions are +combined via + +.. math:: + + \log P(x,y)=\log P(x_{1},y) + \dots + \log P(x_{M},y) - (M - 1)\log P(y), + +where :math:`\log P(x,y)` is the joint log-likelihood of the meta-estimator, +:math:`\log P(x_{m},y)` is that of the :math:`m` th sub-estimator, +:math:`\log P(y)` is the class prior used by the meta-estimator, and +:math:`M\geq1` is the total number of sub-estimators. + +See :ref:`sphx_glr_auto_examples_miscellaneous_plot_combining_naive_bayes.py` +for an example of a mixed naive Bayes model. +See also :ref:`voting_classifier` for a way of combining general classifiers. +For an introduction to processing datasets with heterogeneous features, +see :ref:`column_transformer`. From ffc4b60f27f7b4bcd1bb94b9a09949fbf6c106f1 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sun, 10 Apr 2022 09:35:09 -0400 Subject: [PATCH 031/102] CI fix: try n_retries=10 in fetch_openml --- examples/miscellaneous/plot_combining_naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py index a6b799b86db96..5153775b387eb 100644 --- a/examples/miscellaneous/plot_combining_naive_bayes.py +++ b/examples/miscellaneous/plot_combining_naive_bayes.py @@ -27,7 +27,7 @@ import pandas as pd from sklearn.datasets import fetch_openml -X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True) +X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True, n_retries=10) X["pclass"] = X["pclass"].astype("category") # Add a category for NaNs to the "embarked" feature: X["embarked"] = X["embarked"].cat.add_categories("N/A").fillna("N/A") From 31b47a035171eb20c47b0fa2937478ec444e0851 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date:
Sun, 10 Apr 2022 10:04:41 -0400 Subject: [PATCH 032/102] Fix formatting in the gallery example --- examples/miscellaneous/plot_combining_naive_bayes.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py index 5153775b387eb..c01ee6e248fc6 100644 --- a/examples/miscellaneous/plot_combining_naive_bayes.py +++ b/examples/miscellaneous/plot_combining_naive_bayes.py @@ -12,10 +12,8 @@ We consider the titanic dataset, in which: - - numerical (continous) features "age" and "fare" are handled by -:class:`~naive_bayes.GaussianNB`; - - categorical (discrete) features "embarked", "sex", and "pclass" are handled -by :class:`~naive_bayes.CategoricalNB`. + - numerical (continous) features "age" and "fare" are handled by :class:`~naive_bayes.GaussianNB`; + - categorical (discrete) features "embarked", "sex", and "pclass" are handled by :class:`~naive_bayes.CategoricalNB`. """ # Author: Andrey V. Melnik From 64adf79c100d713f34e2c9eae0ce7debb730608b Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sun, 10 Apr 2022 10:06:14 -0400 Subject: [PATCH 033/102] Fix formatting in the gallery example --- examples/miscellaneous/plot_combining_naive_bayes.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py index c01ee6e248fc6..09867cc3d85d9 100644 --- a/examples/miscellaneous/plot_combining_naive_bayes.py +++ b/examples/miscellaneous/plot_combining_naive_bayes.py @@ -12,8 +12,10 @@ We consider the titanic dataset, in which: - - numerical (continous) features "age" and "fare" are handled by :class:`~naive_bayes.GaussianNB`; - - categorical (discrete) features "embarked", "sex", and "pclass" are handled by :class:`~naive_bayes.CategoricalNB`. 
+ - numerical (continous) features "age" and "fare" are handled by + :class:`~naive_bayes.GaussianNB`; + - categorical (discrete) features "embarked", "sex", and "pclass" are handled + by :class:`~naive_bayes.CategoricalNB`. """ # Author: Andrey V. Melnik From f6eaab763680e528ef5fdcfb8b93e8badea3c00b Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sun, 10 Apr 2022 12:56:12 -0400 Subject: [PATCH 034/102] Improve documentation and the gallery example --- doc/modules/naive_bayes.rst | 19 ++++++++------- .../plot_combining_naive_bayes.py | 24 +++++++++---------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/doc/modules/naive_bayes.rst b/doc/modules/naive_bayes.rst index 96348815c8e55..3fc7e039027b7 100644 --- a/doc/modules/naive_bayes.rst +++ b/doc/modules/naive_bayes.rst @@ -290,15 +290,15 @@ Mix and match naive Bayes models A naive Bayes model that assumes different distribution families for different features (or subsets of features) can be constructed using :class:`ColumnwiseNB`. It is a meta-estimator, whose operation relies on naive Bayes -sub-estimators, which can be chosen in any number of combination from +sub-estimators, which can be chosen in any number or combination from :class:`GaussianNB`, :class:`MultinomialNB`, :class:`ComplementNB`, :class:`BernoulliNB`, :class:`CategoricalNB`, and user-defined models -(provided they expose the necessary methods). +(provided they implement necessary methods). -To initialize :class:`ColumnwiseNB`, one must pass a list of tuples specifying -sub-estimators and their respective column subsets. -Each sub-estimator is fitted and evaluated independently of -others and "sees" only the features assigned to it. The estimators' predictions are +When creating a :class:`ColumnwiseNB` estimator, one specifies sub-estimators +and their respective column subsets as a list of tuples. 
+Each sub-estimator is fitted and evaluated independently of the +others and "sees" only the features assigned to it. The predictions of sub-estimators are combined via .. math:: @@ -311,7 +311,8 @@ where :math:`\log P(x,y)` is the joint log-likelihood of the meta-estimator, :math:`M\geq1` is the total number of sub-estimators. See :ref:`sphx_glr_auto_examples_miscellaneous_plot_combining_naive_bayes.py` -for an example of a mixed naive Bayes model. +for an example of a mixed naive Bayes model implementation. + See also :ref:`voting_classifier` for a way of combining general classifiers. -For an introduction to processing datasets with heterogeneous features, -see :ref:`column_transformer`. +An introduction to processing datasets with heterogeneous features is available at +:ref:`column_transformer`. diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py index 09867cc3d85d9..be4b30f6248e7 100644 --- a/examples/miscellaneous/plot_combining_naive_bayes.py +++ b/examples/miscellaneous/plot_combining_naive_bayes.py @@ -12,10 +12,10 @@ We consider the titanic dataset, in which: - - numerical (continous) features "age" and "fare" are handled by - :class:`~naive_bayes.GaussianNB`; - - categorical (discrete) features "embarked", "sex", and "pclass" are handled - by :class:`~naive_bayes.CategoricalNB`. +- numerical (continous) features "age" and "fare" are handled by + :class:`~naive_bayes.GaussianNB`; +- categorical (discrete) features "embarked", "sex", and "pclass" are handled + by :class:`~naive_bayes.CategoricalNB`. """ # Author: Andrey V. 
Melnik @@ -32,11 +32,10 @@ # Add a category for NaNs to the "embarked" feature: X["embarked"] = X["embarked"].cat.add_categories("N/A").fillna("N/A") -# +# %% # Build and use a pipeline around ``ColumnwiseNB`` # ------------------------------------------------ -# -# %% + from sklearn.compose import ColumnTransformer from sklearn.pipeline import Pipeline from sklearn.impute import SimpleImputer @@ -75,14 +74,14 @@ y_pred = pipe.predict(X_test) print(f"Test accuracy: {accuracy_score(y_test, y_pred)}") -# +# %% # Compare choices of columns using ``GridSearchCV`` # -------------------------------------------------- # # The allocation of columns to constituent subestimators can be regarded as a # hyperparameter. We can explore the combinations of columns' choices and values # of other hyperparameters with the help of :class:`~.model_selection.GridSearchCV`. -# %% + param_grid = { "classifier__nb_estimators": [ [("gnb", GaussianNB(), [0, 1]), ("cnb", CategoricalNB(), [2, 3, 4])], @@ -98,15 +97,16 @@ # %% # Calling `fit` triggers the cross-validated search for the best # hyperparameters combination: -# + grid_search.fit(X_train, y_train) print("Best params:") print(grid_search.best_params_) -# It turns out, the best results are achieved by the naive Bayes model when "sex" -# is the only feature used: # %% +# As it turns out, the best results are achieved by the naive Bayes model when "sex" +# is the only feature used: + cv_results = pd.DataFrame(grid_search.cv_results_) cv_results = cv_results.sort_values("mean_test_score", ascending=False) cv_results["Columns dictionary"] = cv_results["param_classifier__nb_estimators"].map( From a40fdcaa72d13eb7b6dc227e8db29e941d594377 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sun, 10 Apr 2022 20:12:04 -0400 Subject: [PATCH 035/102] Re #21355 'no validation at init'-test: Remove the logic at setter, keep at getter --- sklearn/naive_bayes.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 
deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index f8a61766cb484..b2ff55aa07e0d 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1961,20 +1961,20 @@ def _estimators(self): # Implemented in the image and likeness of ColumnTranformer._transformers try: return [(name, e) for name, e, _ in self.nb_estimators] - except (TypeError, ValueError): # to pass init test in test_common.py + except (TypeError, ValueError): + # This try-except clause is needed to pass the test from test_common.py: + # test_estimators_do_not_raise_errors_in_init_or_set_params(). + # ColumnTransformer does the same. See PR #21355 for details. return self.nb_estimators @_estimators.setter def _estimators(self, value): # Implemented in the image and likeness of ColumnTranformer._transformers # TODO: Is renaming or changing the order legal? Swap `name` and `_`? - try: - self.nb_estimators = [ - (name, e, col) - for ((name, e), (_, _, col)) in zip(value, self.nb_estimators) - ] - except (TypeError, ValueError): # to pass init test in test_common.py - self.nb_estimators = value + self.nb_estimators = [ + (name, e, col) + for ((name, e), (_, _, col)) in zip(value, self.nb_estimators) + ] def get_params(self, deep=True): """Get parameters for this estimator. 
From 40085e5ba2a7f759d7b65b17321a0b21ff9fe672 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sun, 10 Apr 2022 22:25:53 -0400 Subject: [PATCH 036/102] Add test for error when subestimator does not compute class priors, but is expected to --- sklearn/tests/test_naive_bayes.py | 45 +++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index b8b80061145c6..5e5de18781d3b 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1291,6 +1291,51 @@ def test_cwnb_prior(): assert_array_almost_equal(clf1.class_prior_, clf3.class_prior_, 8) assert_array_equal(clf1.class_prior_, clf1.named_estimators_["g1"].class_prior_) + # prior spec: error message when can't extract prior from subestimator + class GaussianNB_hide_prior(GaussianNB): + def fit(self, X, y, sample_weight=None): + super().fit(X, y, sample_weight=None) + self.qwerqwer = self.class_prior_ + del self.class_prior_ + + def _joint_log_likelihood(self, X): + self.class_prior_ = self.qwerqwer + super()._joint_log_likelihood(X) + del self.class_prior_ + + class MultinomialNB_hide_log_prior(MultinomialNB): + def fit(self, X, y, sample_weight=None): + super().fit(X, y, sample_weight=None) + self.qwerqwer = self.class_log_prior_ + del self.class_log_prior_ + + def _joint_log_likelihood(self, X): + self.class_log_prior_ = self.qwerqwer + super()._joint_log_likelihood(X) + del self.class_log_prior_ + + clf = ColumnwiseNB( + nb_estimators=[ + ("g1", GaussianNB(), [1]), + ("g2", GaussianNB_hide_prior(), [0, 1]), + ], + priors="g2", + ) + msg = "Unable to extract class prior from estimator g2*" + with pytest.raises(AttributeError, match=msg): + clf.fit(X, y) + + clf = ColumnwiseNB( + nb_estimators=[ + ("g1", GaussianNB(), [0]), + ("m1", MultinomialNB_hide_log_prior(), [1, 2, 3, 4, 5]), + ], + priors="m1", + ) + msg = "Unable to extract class prior from estimator m1*" 
+ with pytest.raises(AttributeError, match=msg): + clf.fit(X2, y2) + def test_cwnb_zero_prior(): # P(y)=0 in a subestimator results in P(y|x)=0 of meta-estimator From 019c0e01b430a50567dec820370d92394a1c872c Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sun, 10 Apr 2022 22:42:04 -0400 Subject: [PATCH 037/102] Extend the test for class priors extraction to MultinomialNB to cover class_log_prior case --- sklearn/tests/test_naive_bayes.py | 39 ++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 5e5de18781d3b..54f70114169b0 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1267,29 +1267,46 @@ def test_cwnb_prior(): # prior spec: specified prior equals calculated and subestimators' priors # prior spec: str prior ties subestimators' clf1 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], - priors=np.array([0.5, 0.5]), + nb_estimators=[ + ("g1", GaussianNB(), [0, 1]), + ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), + ], + priors=np.array([1 / 3, 1 / 3, 1 / 3]), ) - clf2 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + clf2a = ColumnwiseNB( + nb_estimators=[ + ("g1", GaussianNB(), [0, 1]), + ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), + ], priors="g1", ) + clf2b = ColumnwiseNB( + nb_estimators=[ + ("g1", GaussianNB(), [0, 1]), + ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), + ], + priors="m1", + ) clf3 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + nb_estimators=[ + ("g1", GaussianNB(), [0, 1]), + ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), + ], ) - clf1.fit(X, y) - clf2.fit(X, y) - clf3.fit(X, y) + clf1.fit(X2, y2) + clf2a.fit(X2, y2) + clf2b.fit(X2, y2) + clf3.fit(X2, y2) assert clf3.priors is None assert_array_almost_equal( clf1.class_prior_, 
clf1.named_estimators_["g1"].class_prior_, 8 ) assert_array_almost_equal( - clf1.class_prior_, clf1.named_estimators_["g2"].class_prior_, 8 + np.log(clf1.class_prior_), clf1.named_estimators_["m1"].class_log_prior_, 8 ) - assert_array_almost_equal(clf1.class_prior_, clf2.class_prior_, 8) + assert_array_almost_equal(clf1.class_prior_, clf2a.class_prior_, 8) + assert_array_almost_equal(clf1.class_prior_, clf2b.class_prior_, 8) assert_array_almost_equal(clf1.class_prior_, clf3.class_prior_, 8) - assert_array_equal(clf1.class_prior_, clf1.named_estimators_["g1"].class_prior_) # prior spec: error message when can't extract prior from subestimator class GaussianNB_hide_prior(GaussianNB): From 197740db21ad103a128489b862162153254b9583 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 11 Apr 2022 00:32:44 -0400 Subject: [PATCH 038/102] Add ColumnwiseNB._sk_visual_block_ method for better HTML representation --- sklearn/naive_bayes.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index b2ff55aa07e0d..f8cb0b457d74a 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -39,6 +39,7 @@ from .utils import _print_elapsed_time from .utils import Bunch from .utils.fixes import delayed +from .utils._estimator_html_repr import _VisualBlock from .compose._column_transformer import _is_empty_column_selection @@ -2017,3 +2018,10 @@ def set_params(self, **kwargs): # Implemented in the image and likeness of ColumnTranformer.set_params self._set_params("_estimators", **kwargs) return self + + def _sk_visual_block_(self): + """HTML representation of this estimator.""" + names, estimators, name_details = zip(*self.nb_estimators) + return _VisualBlock( + "parallel", estimators, names=names, name_details=name_details + ) From b27599326eda726ad777511933d4379f24b80340 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 11 Apr 2022 01:11:19 -0400 
Subject: [PATCH 039/102] Add test for ColumnwiseNB._sk_visual_block() --- sklearn/tests/test_naive_bayes.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 54f70114169b0..ab124323f820f 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1618,3 +1618,22 @@ def test_cwnb_verbose(capsys): ) clf.fit(X2, y2) clf.predict(X2) + + +def test_cwnb_sk_visual_block(capsys): + # Setting verbose=True does not result in an error. + # This DOES NOT test if the desired output is generated. + estimators = (MultinomialNB(), MultinomialNB(), GaussianNB()) + clf = ColumnwiseNB( + nb_estimators=[ + ("mnb1", estimators[0], [0, 1]), + ("mnb2", estimators[1], [3, 4]), + ("gnb1", estimators[2], [5]), + ], + verbose=True, + n_jobs=4, + ) + visual_block = clf._sk_visual_block_() + assert visual_block.names == ('mnb1', 'mnb2', 'gnb1') + assert visual_block.name_details == ([0, 1], [3, 4], [5]) + assert visual_block.estimators == estimators From 5a8d59eb52b8d4d481ee2ce7d6dcf57699144543 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 11 Apr 2022 01:12:54 -0400 Subject: [PATCH 040/102] Black formatting correction --- sklearn/tests/test_naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index ab124323f820f..080d81b8f80c4 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1634,6 +1634,6 @@ def test_cwnb_sk_visual_block(capsys): n_jobs=4, ) visual_block = clf._sk_visual_block_() - assert visual_block.names == ('mnb1', 'mnb2', 'gnb1') + assert visual_block.names == ("mnb1", "mnb2", "gnb1") assert visual_block.name_details == ([0, 1], [3, 4], [5]) assert visual_block.estimators == estimators From be835129ea7e93d54fbd16a34a7e5609451c63be Mon Sep 17 00:00:00 2001 From: avm19 
<52547519avm19@users.noreply.github.com> Date: Mon, 11 Apr 2022 03:38:09 -0400 Subject: [PATCH 041/102] Tests for ColumnwiseNB priors: Remove unnecessary definitions --- sklearn/tests/test_naive_bayes.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 080d81b8f80c4..65e95f93e6c9f 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1315,22 +1315,12 @@ def fit(self, X, y, sample_weight=None): self.qwerqwer = self.class_prior_ del self.class_prior_ - def _joint_log_likelihood(self, X): - self.class_prior_ = self.qwerqwer - super()._joint_log_likelihood(X) - del self.class_prior_ - class MultinomialNB_hide_log_prior(MultinomialNB): def fit(self, X, y, sample_weight=None): super().fit(X, y, sample_weight=None) self.qwerqwer = self.class_log_prior_ del self.class_log_prior_ - def _joint_log_likelihood(self, X): - self.class_log_prior_ = self.qwerqwer - super()._joint_log_likelihood(X) - del self.class_log_prior_ - clf = ColumnwiseNB( nb_estimators=[ ("g1", GaussianNB(), [1]), From b93f39d0d0aa4e51539e828025398db0934839be Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 12 Apr 2022 14:04:35 -0400 Subject: [PATCH 042/102] Change log: add an entry for ColumnwiseNB --- doc/whats_new/v1.1.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 8f6f2e2ad7cb7..f06ced1c11b69 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -857,6 +857,13 @@ Changelog - |Fix| :meth:`multiclass.OneVsOneClassifier.predict` returns correct predictions when the inner classifier only has a :term:`predict_proba`. :pr:`22604` by `Thomas Fan`_. +:mod:`sklearn.naive_bayes` +.......................... 
+ +- |Feature| Add :class:`naive_bayes.ColumnwiseNB`, a meta-estimator that allows + existing naive Bayes classifiers to be combined and applied to different columns + of `X`. :pr:`22574` by :user:`Andrey Melnik `. + :mod:`sklearn.neighbors` ........................ From e423db915b85148d49603afd2e541d257806e999 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sun, 1 May 2022 01:31:30 -0400 Subject: [PATCH 043/102] Changelog entry moved from v1.1.rst to v1.2.rst --- doc/whats_new/v1.1.rst | 7 ------- doc/whats_new/v1.2.rst | 7 +++++++ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index b04a039a3831c..0463ae35f3052 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -920,13 +920,6 @@ Changelog - |Fix| :meth:`multiclass.OneVsOneClassifier.predict` returns correct predictions when the inner classifier only has a :term:`predict_proba`. :pr:`22604` by `Thomas Fan`_. -:mod:`sklearn.naive_bayes` -.......................... - -- |Feature| Add :class:`naive_bayes.ColumnwiseNB`, a meta-estimator that allows - existing naive Bayes classifiers to be combined and applied to different columns - of `X`. :pr:`22574` by :user:`Andrey Melnik `. - :mod:`sklearn.neighbors` ........................ diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index bdb9f3018aba8..69d8099b6e952 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -33,6 +33,13 @@ Changelog :pr:`123456` by :user:`Joe Bloggs `. where 123456 is the *pull request* number, not the issue number. +:mod:`sklearn.naive_bayes` +.......................... + +- |Feature| Add :class:`naive_bayes.ColumnwiseNB`, a meta-estimator that allows + existing naive Bayes classifiers to be combined and applied to different columns + of `X`. :pr:`22574` by :user:`Andrey Melnik `. 
+ Code and Documentation Contributors ----------------------------------- From 261bc5d1a3a1d641028dac13b1c5e79998accb63 Mon Sep 17 00:00:00 2001 From: avm19 Date: Sat, 4 Jun 2022 15:28:52 -0400 Subject: [PATCH 044/102] Format cited code in docstring sklearn/naive_bayes.py Co-authored-by: Alexandre Gramfort --- sklearn/naive_bayes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index f8cb0b457d74a..6cda0ff42d37f 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1591,8 +1591,8 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): >>> y = np.array([0, 0, 1, 1, 2, 2]) >>> from sklearn.naive_bayes import MultinomialNB, GaussianNB, ColumnwiseNB >>> clf = ColumnwiseNB(nb_estimators=[('mnb1', MultinomialNB(), [0, 1]), - ... ('mnb2', MultinomialNB(), [3, 4]), - ... ('gnb1', GaussianNB(), [5])]) + ... ('mnb2', MultinomialNB(), [3, 4]), + ... ('gnb1', GaussianNB(), [5])]) >>> clf.fit(X, y) ColumnwiseNB(nb_estimators=[('mnb1', MultinomialNB(), [0, 1]), ('mnb2', MultinomialNB(), [3, 4]), From c19e76b91c6015f979e5e0def5a9f9626bd187c1 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 4 Jun 2022 15:39:24 -0400 Subject: [PATCH 045/102] Remove unnecessary import in test.naive_bayes.py::test_cwnb_example https://github.com/scikit-learn/scikit-learn/pull/22574#discussion_r889539665 --- sklearn/tests/test_naive_bayes.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 65e95f93e6c9f..6117b54cedee1 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1576,12 +1576,9 @@ def test_cwnb_n_jobs(): def test_cwnb_example(): # Test the Example from ColumnwiseNB docstring in naive_bayes.py - import numpy as np - rng = np.random.RandomState(1) X = rng.randint(5, size=(6, 100)) y = np.array([0, 0, 1, 1, 2, 2]) - from sklearn.naive_bayes import MultinomialNB, 
GaussianNB, ColumnwiseNB clf = ColumnwiseNB( nb_estimators=[ From 1eb5d499fd3f5abb508c51b9a4dccb6cd50d9603 Mon Sep 17 00:00:00 2001 From: avm19 Date: Sat, 4 Jun 2022 15:56:01 -0400 Subject: [PATCH 046/102] Update authors in examples/miscellaneous/plot_combining_naive_bayes.py Co-authored-by: Alexandre Gramfort --- examples/miscellaneous/plot_combining_naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py index be4b30f6248e7..3f0845492fb51 100644 --- a/examples/miscellaneous/plot_combining_naive_bayes.py +++ b/examples/miscellaneous/plot_combining_naive_bayes.py @@ -19,7 +19,7 @@ """ # Author: Andrey V. Melnik -# (based on a related work by Pedro Morales ) +# Pedro Morales # # License: BSD 3 clause From 7a8ee1871f6b8c31835108f36c5b8652233c9dc3 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 4 Jun 2022 17:44:38 -0400 Subject: [PATCH 047/102] Split test functions and give better names in test_naive_bayes.py --- sklearn/tests/test_naive_bayes.py | 106 ++++++++++++++++++++---------- 1 file changed, 73 insertions(+), 33 deletions(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 6117b54cedee1..87b4f5f7decce 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -954,7 +954,7 @@ def test_n_features_deprecation(Estimator): est.n_features_ -def test_cwnb_union(): +def test_cwnb_union_gnb_fit(): # A union of GaussianNB's yields the same prediction a single GaussianNB (fit) clf1 = ColumnwiseNB( nb_estimators=[("g1", GaussianNB(), [0]), ("g2", GaussianNB(), [1])] @@ -966,6 +966,8 @@ def test_cwnb_union(): assert_array_almost_equal(clf1.predict_proba(X), clf2.predict_proba(X), 8) assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) + +def test_cwnb_union_bnb_fit(): # A union of BernoulliNB's yields 
the same prediction a single BernoulliNB (fit) clf1 = ColumnwiseNB( nb_estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] @@ -977,6 +979,8 @@ def test_cwnb_union(): assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) assert_array_almost_equal(clf1.predict(X1), clf2.predict(X1), 8) + +def test_cwnb_union_bnb_partial_fit(): # A union of BernoulliNB's yields the same prediction a single BernoulliNB # (partial_fit) clf1 = ColumnwiseNB( @@ -990,6 +994,8 @@ def test_cwnb_union(): assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) assert_array_almost_equal(clf1.predict(X1), clf2.predict(X1), 8) + +def test_cwnb_union_permutation(): # A union of several different NB's is permutation-invariant clf1 = ColumnwiseNB( nb_estimators=[ @@ -1092,7 +1098,7 @@ def test_cwnb_estimators_pandas(): ) -def test_cwnb_estimators_1(): +def test_cwnb_repeated_columns(): # Subestimators spec: repeated col ints have the same effect as repeating data clf1 = ColumnwiseNB( nb_estimators=[ @@ -1112,6 +1118,8 @@ def test_cwnb_estimators_1(): clf1.predict_log_proba(X1), clf2.predict_log_proba(X1[:, [1, 1, 0, 0, 1, 1]]), 8 ) + +def test_cwnb_empty_columns(): # Subestimators spec: empty cols have the same effect as an absent estimator clf1 = ColumnwiseNB( nb_estimators=[ @@ -1132,6 +1140,8 @@ def test_cwnb_estimators_1(): # No cloning of the empty-columns estimators took place: assert id(clf1.nb_estimators[1][1]) == id(clf1.named_estimators_["g2"]) + +def test_cwnb_estimators_unique_names(): # Subestimators spec: error on repeated names clf1 = ColumnwiseNB( nb_estimators=[("g1", GaussianNB(), [1]), ("g1", GaussianNB(), [0, 1])] @@ -1146,7 +1156,7 @@ def test_cwnb_estimators_1(): clf1.fit(X, y) -def test_cwnb_estimators_2(): +def test_cwnb_estimators_nonempty_list(): # Subestimators spec: error on empty list clf = ColumnwiseNB( nb_estimators=[], @@ -1163,6 +1173,8 @@ def test_cwnb_estimators_2(): with 
pytest.raises(ValueError, match=msg): clf.fit(X1, y1) + +def test_cwnb_estimators_support_jll(): # Subestimators spec: error when some don't support _joint_log_likelihood class notNB(BaseEstimator): def __init__(self): @@ -1183,6 +1195,8 @@ def predict(self, X): with pytest.raises(TypeError, match=msg): clf1.partial_fit(X, y) + +def test_cwnb_estimators_support_fit(): # Subestimators spec: error when some don't support fit class notNB(BaseEstimator): def __init__(self): @@ -1203,6 +1217,8 @@ def predict(self, X): with pytest.raises(TypeError, match=msg): clf1.fit(X, y) + +def test_cwnb_estimators_support_partial_fit(): # Subestimators spec: error when some don't support partial_fit class notNB(BaseEstimator): def __init__(self): @@ -1223,6 +1239,8 @@ def predict(self, X): with pytest.raises(TypeError, match=msg): clf1.partial_fit(X, y) + +def test_cwnb_estimators_setter(): # _estimators setter works clf1 = ColumnwiseNB( nb_estimators=[("g1", GaussianNB(), [0]), ("b1", BernoulliNB(), [1])] @@ -1238,7 +1256,7 @@ def predict(self, X): assert clf1.nb_estimators[1][1] is clf1.named_estimators_["g1"] -def test_cwnb_prior(): +def test_cwnb_prior_valid_spec(): # prior spec: error when negative, sum!=1 or bad length clf1 = ColumnwiseNB( nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], @@ -1264,34 +1282,39 @@ def test_cwnb_prior(): with pytest.raises(ValueError, match=msg): clf1.fit(X, y) - # prior spec: specified prior equals calculated and subestimators' priors - # prior spec: str prior ties subestimators' + +def test_cwnb_prior_match(): + # prior spec: all these ways work (and agree in our example) + # (1) an array of values + # (2a) a str name of a subestimator supporting class_prior_ + # (2b) a str name of a subestimator supporting class_log_prior_ + # (3) nothing (ColumnwiseNB will calculate relative frequencies) clf1 = ColumnwiseNB( nb_estimators=[ ("g1", GaussianNB(), [0, 1]), ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), ], - priors=np.array([1 / 
3, 1 / 3, 1 / 3]), + priors=np.array([1 / 3, 1 / 3, 1 / 3]), # prior is provided by user ) clf2a = ColumnwiseNB( nb_estimators=[ ("g1", GaussianNB(), [0, 1]), ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), ], - priors="g1", + priors="g1", # prior will be estimated by sub-estimator "g1" ) clf2b = ColumnwiseNB( nb_estimators=[ ("g1", GaussianNB(), [0, 1]), ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), ], - priors="m1", + priors="m1", # prior will be estimated by sub-estimator "m1" ) clf3 = ColumnwiseNB( nb_estimators=[ ("g1", GaussianNB(), [0, 1]), ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), - ], + ], # prior will be estimated by the meta-estimator ) clf1.fit(X2, y2) clf2a.fit(X2, y2) @@ -1308,19 +1331,17 @@ def test_cwnb_prior(): assert_array_almost_equal(clf1.class_prior_, clf2b.class_prior_, 8) assert_array_almost_equal(clf1.class_prior_, clf3.class_prior_, 8) + +def test_cwnb_estimators_support_class_prior_gnb(): # prior spec: error message when can't extract prior from subestimator + # ColumnwiseNB tries both class_prior_ and class_log_prior, which is tested + # in test_cwnb_prior_match() class GaussianNB_hide_prior(GaussianNB): def fit(self, X, y, sample_weight=None): super().fit(X, y, sample_weight=None) self.qwerqwer = self.class_prior_ del self.class_prior_ - class MultinomialNB_hide_log_prior(MultinomialNB): - def fit(self, X, y, sample_weight=None): - super().fit(X, y, sample_weight=None) - self.qwerqwer = self.class_log_prior_ - del self.class_log_prior_ - clf = ColumnwiseNB( nb_estimators=[ ("g1", GaussianNB(), [1]), @@ -1332,6 +1353,17 @@ def fit(self, X, y, sample_weight=None): with pytest.raises(AttributeError, match=msg): clf.fit(X, y) + +def test_cwnb_estimators_support_class_prior_mnb(): + # prior spec: error message when can't extract prior from subestimator + # ColumnwiseNB tries both class_prior_ and class_log_prior, which is tested + # in test_cwnb_prior_match() + class MultinomialNB_hide_log_prior(MultinomialNB): + def fit(self, X, y, 
sample_weight=None): + super().fit(X, y, sample_weight=None) + self.qwerqwer = self.class_log_prior_ + del self.class_log_prior_ + clf = ColumnwiseNB( nb_estimators=[ ("g1", GaussianNB(), [0]), @@ -1344,8 +1376,13 @@ def fit(self, X, y, sample_weight=None): clf.fit(X2, y2) -def test_cwnb_zero_prior(): - # P(y)=0 in a subestimator results in P(y|x)=0 of meta-estimator +def test_cwnb_prior_nonzero(): + # P(y)=0 in one or two subestimators results in P(y|x)=0 of meta-estimator. + # Despite attempted Log[0], predicted class probabilities are all finite. + # On a related note, meaningless results (including NaNs) may be produced + # - if P(y)=0 in the meta-estimator, or/and + # - if class priors differ across subestimators, + # but this is not what is tested here. clf1 = ColumnwiseNB( nb_estimators=[ ("g1", GaussianNB(), [1, 3, 5]), @@ -1364,30 +1401,25 @@ def test_cwnb_zero_prior(): assert_almost_equal(np.abs(p).sum(), 0) assert np.isfinite(p).all() - # P(y)=0 in the meta-estimator, as well as class priors that differ across - # subestimators may produce meaningless results, including NaNs. This case - # is not tested here. 
- - # P(y)=0 in two subestimators results in P(y|x)=0 of meta-estimator clf1 = ColumnwiseNB( nb_estimators=[ ("g1", GaussianNB(priors=np.array([0.6, 0, 0.4])), [1, 3, 5]), - ("g2", GaussianNB(priors=np.array([0.5, 0, 0.5])), [0, 1]), + ("g2", GaussianNB(priors=np.array([0.5, 0.5, 0])), [0, 1]), ] ) clf1.fit(X2, y2) with pytest.warns(RuntimeWarning, match=msg): - p = clf1.predict_proba(X2)[:, 1] + p = clf1.predict_proba(X2)[:, 1:] assert_almost_equal(np.abs(p).sum(), 0) assert np.isfinite(p).all() Xt = rng.randint(5, size=(6, 100)) with pytest.warns(RuntimeWarning, match=msg): - p = clf1.predict_proba(Xt)[:, 1] + p = clf1.predict_proba(Xt)[:, 1:] assert_almost_equal(np.abs(p).sum(), 0) assert np.isfinite(p).all() -def test_cwnb_sample_weight(): +def test_cwnb_fit_sample_weight_ones(): # weights in fit have no effect if all ones weights = [1, 1, 1, 1, 1, 1] clf1 = ColumnwiseNB( @@ -1404,7 +1436,10 @@ def test_cwnb_sample_weight(): assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) assert_array_equal(clf1.predict(X), clf2.predict(X)) + +def test_cwnb_partial_fit_sample_weight_ones(): # weights in partial_fit have no effect if all ones + weights = [1, 1, 1, 1, 1, 1] clf1 = ColumnwiseNB( nb_estimators=[ ("b1", BernoulliNB(binarize=2), [1]), @@ -1425,6 +1460,8 @@ def test_cwnb_sample_weight(): assert_array_almost_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2), 8) assert_array_equal(clf1.predict(X2), clf2.predict(X2)) + +def test_cwnb_fit_sample_weight_repeated(): # weights in fit have the same effect as repeating data weights = [1, 2, 3, 1, 4, 2] idx = list(chain(*([i] * w for i, w in enumerate(weights)))) @@ -1451,7 +1488,11 @@ def test_cwnb_sample_weight(): for attr_name in ("class_count_", "class_prior_", "classes_"): assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) + +def test_cwnb_partial_fit_sample_weight_repeated(): # weights in partial_fit have the same effect as repeating data + weights = [1, 2, 
3, 1, 4, 2] + idx = list(chain(*([i] * w for i, w in enumerate(weights)))) clf1 = ColumnwiseNB( nb_estimators=[ ("b1", BernoulliNB(binarize=2), [1]), @@ -1492,6 +1533,8 @@ def test_cwnb_partial_fit(): for attr_name in ("class_count_", "class_prior_", "classes_"): assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) + +def test_cwnb_partial_fit_classes(): # partial_fit: error when classes are not provided at the first call clf1 = ColumnwiseNB( nb_estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] @@ -1501,7 +1544,7 @@ def test_cwnb_partial_fit(): clf1.partial_fit(X2, y2) -def test_cwnb_consistency(): +def test_cwnb_class_attributes_consistency(): # class_count_, classes_, class_prior_ are consistent in meta-, sub-estimators clf1 = ColumnwiseNB( nb_estimators=[ @@ -1546,7 +1589,7 @@ def test_cwnb_params(): def test_cwnb_n_jobs(): - # n_jobs: same results wether with it or without + # n_jobs: same result whether with it or without clf1 = ColumnwiseNB( nb_estimators=[ ("b1", BernoulliNB(binarize=2), [1]), @@ -1608,8 +1651,7 @@ def test_cwnb_verbose(capsys): def test_cwnb_sk_visual_block(capsys): - # Setting verbose=True does not result in an error. - # This DOES NOT test if the desired output is generated. 
+ # visual block representation correctly extracts names, cols and estimators estimators = (MultinomialNB(), MultinomialNB(), GaussianNB()) clf = ColumnwiseNB( nb_estimators=[ @@ -1617,8 +1659,6 @@ def test_cwnb_sk_visual_block(capsys): ("mnb2", estimators[1], [3, 4]), ("gnb1", estimators[2], [5]), ], - verbose=True, - n_jobs=4, ) visual_block = clf._sk_visual_block_() assert visual_block.names == ("mnb1", "mnb2", "gnb1") From 03e48e644684ced0804b20c3f336f59e2ec073e7 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 4 Jun 2022 23:30:29 -0400 Subject: [PATCH 048/102] Namechange and minor comments in test_naive_bayes.py --- sklearn/tests/test_naive_bayes.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 87b4f5f7decce..229fb1c904d4c 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -954,8 +954,8 @@ def test_n_features_deprecation(Estimator): est.n_features_ -def test_cwnb_union_gnb_fit(): - # A union of GaussianNB's yields the same prediction a single GaussianNB (fit) +def test_cwnb_union_gnb(): + # A union of GaussianNB's yields the same prediction as a single GaussianNB clf1 = ColumnwiseNB( nb_estimators=[("g1", GaussianNB(), [0]), ("g2", GaussianNB(), [1])] ) @@ -968,7 +968,8 @@ def test_cwnb_union_gnb_fit(): def test_cwnb_union_bnb_fit(): - # A union of BernoulliNB's yields the same prediction a single BernoulliNB (fit) + # A union of BernoulliNB's yields the same prediction as a single BernoulliNB + # (fit) clf1 = ColumnwiseNB( nb_estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] ) @@ -981,7 +982,7 @@ def test_cwnb_union_bnb_fit(): def test_cwnb_union_bnb_partial_fit(): - # A union of BernoulliNB's yields the same prediction a single BernoulliNB + # A union of BernoulliNB's yields the same prediction as a single BernoulliNB # (partial_fit) clf1 = ColumnwiseNB( 
nb_estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] From e1449a0d3660bd33356da03c0ddf010ba740cfa8 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sun, 5 Jun 2022 02:28:56 -0400 Subject: [PATCH 049/102] Test union GaussianNBs matches single one when priors are specified Also added a comment to priors parameter in ColumnwiseNB docstring. --- sklearn/naive_bayes.py | 2 ++ sklearn/tests/test_naive_bayes.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 6cda0ff42d37f..37e104c04f43a 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1509,6 +1509,8 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Prior probabilities of classes. If unspecified, the priors are calculated as relative frequencies of classes in the training data. If str, the priors are taken from the estimator with the given name. + If array-like, the same priors might have to be specified manually in + each sub-estimator, in order to ensure consistent predictions. n_jobs : int, default=None Number of jobs to run in parallel. 
diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 229fb1c904d4c..968af84328afe 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -967,6 +967,25 @@ def test_cwnb_union_gnb(): assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) +def test_cwnb_union_prior_gnb(): + # A union of GaussianNB's yields the same prediction as a single GaussianNB + # when class priors are provided by user + priors = np.array([1 / 3, 2 / 3]) + clf1 = ColumnwiseNB( + nb_estimators=[ + ("g1", GaussianNB(priors=priors), [0]), + ("g2", GaussianNB(priors=priors), [1]), + ], + priors=priors, + ) + clf2 = GaussianNB(priors=priors) + clf1.fit(X, y) + clf2.fit(X, y) + assert_array_almost_equal(clf1.predict(X), clf2.predict(X), 8) + assert_array_almost_equal(clf1.predict_proba(X), clf2.predict_proba(X), 8) + assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) + + def test_cwnb_union_bnb_fit(): # A union of BernoulliNB's yields the same prediction as a single BernoulliNB # (fit) From 0b1829c89dd038375535a0d6906cfa01db5addbb Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 18 Jun 2022 18:03:28 -0400 Subject: [PATCH 050/102] Implement _BaseNB.predict_joint_log_proba method and test for it --- sklearn/naive_bayes.py | 30 ++++++++++++++++++++++++++++-- sklearn/tests/test_naive_bayes.py | 11 +++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index ca7be2d3799a3..4da0e0195acea 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -51,8 +51,10 @@ def _joint_log_likelihood(self, X): I.e. ``log P(c) + log P(x|c)`` for all rows x of X, as an array-like of shape (n_samples, n_classes). - predict, predict_proba, and predict_log_proba pass the input through - _check_X and handle it over to _joint_log_likelihood. 
+ Public methods predict, predict_proba, predict_log_proba, and + predict_joint_log_proba pass the input through _check_X before handing it + over to _joint_log_likelihood. The term "joint log likelihood" is used + interchangeably with "joint log probability". """ @abstractmethod @@ -62,6 +64,30 @@ def _check_X(self, X): Only used in predict* methods. """ + def predict_joint_log_proba(self, X): + """Compute the joint log probability ``log P(X, y)``. + + For each row x of X and class y, the joint log probability is given by + ``log P(x, y) = log P(y) + log P(x|y),`` + where ``log P(y)`` is the class prior probability and ``log P(x|y)`` is + the class-conditional probability. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The input samples. + + Returns + ------- + C : ndarray of shape (n_samples, n_classes) + Returns the joint log-probability of the samples for each class in + the model. The columns correspond to the classes in sorted + order, as they appear in the attribute :term:`classes_`. + """ + check_is_fitted(self) + X = self._check_X(X) + return self._joint_log_likelihood(X) + def predict(self, X): """ Perform classification on an array of test vectors X. 
diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 47fd6821ad305..59c904d4381ae 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -5,6 +5,8 @@ import pytest import warnings +from scipy.special import logsumexp + from sklearn.datasets import load_digits, load_iris from sklearn.model_selection import train_test_split @@ -945,3 +947,12 @@ def test_n_features_deprecation(Estimator): with pytest.warns(FutureWarning, match="`n_features_` was deprecated"): est.n_features_ + + +@pytest.mark.parametrize("Estimator", ALL_NAIVE_BAYES_CLASSES) +def test_predict_joint_proba(Estimator): + est = Estimator().fit(X2, y2) + jll = est.predict_joint_log_proba(X2) + log_prob_x = logsumexp(jll, axis=1) + log_prob_x_y = jll - np.atleast_2d(log_prob_x).T + assert_array_almost_equal(log_prob_x_y, est.predict_log_proba(X2), 8) From ec9033e47baa92fee6e271424261cee6b88a35a1 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 18 Jun 2022 18:24:56 -0400 Subject: [PATCH 051/102] _BaseNB.predict_join_log_proba improve docstring --- sklearn/naive_bayes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 4da0e0195acea..4f6007d047ffd 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -65,10 +65,10 @@ def _check_X(self, X): """ def predict_joint_log_proba(self, X): - """Compute the joint log probability ``log P(X, y)``. + """Return joint log probability estimates for the test vector X. For each row x of X and class y, the joint log probability is given by - ``log P(x, y) = log P(y) + log P(x|y),`` + ``log P(x, y) = log P(y) + log P(x|y),`` where ``log P(y)`` is the class prior probability and ``log P(x|y)`` is the class-conditional probability. 
From c321f6387baf013abc9f32b3c2a937fd7e99ebf4 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 18 Jun 2022 19:31:40 -0400 Subject: [PATCH 052/102] Changelog entry --- doc/whats_new/v1.2.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 7b203e94968e0..55f30c9b5eb64 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -154,6 +154,12 @@ Changelog of a binary classification problem. :pr:`22518` by :user:`Arturo Amor `. +:mod:`sklearn.naive_bayes` +.......................... + +- |Feature| Add methods `predict_joint_log_proba` to all naive Bayes classifiers. + :pr:`23683` by :user:`Andrey Melnik `. + :mod:`sklearn.neighbors` ........................ From 2be313e8dd6f164be20557083cd2871af091b760 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 18 Jun 2022 19:57:14 -0400 Subject: [PATCH 053/102] Use predict_joint_log_proba instead of _joint_log_likelihood in sub-estimators --- sklearn/naive_bayes.py | 26 ++++++++++++-------------- sklearn/tests/test_naive_bayes.py | 22 ++++++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index ac6ab7301228b..215c4df20110a 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1484,11 +1484,11 @@ def _partial_fit_one(estimator, X, y, message_clsname="", message=None, **fit_pa def _jll_one(estimator, X): - """Call ``estimator._joint_log_likelihood``. + """Call ``estimator.predict_joint_log_proba``. See :func:`sklearn.pipeline._transform_one`. 
""" - return estimator._joint_log_likelihood(estimator._check_X(X)) + return estimator.predict_joint_log_proba(X) class ColumnwiseNB(_BaseNB, _BaseComposition): @@ -1516,10 +1516,10 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): nb_estimator : estimator The estimator must support :term:`fit` or :term:`partial_fit`, depending on how the meta-estimator is fitted. In addition, the - estimator must support ``_joint_log_likelihood`` method, which + estimator must support ``predict_joint_log_proba`` method, which takes :term:`X` of shape (n_samples, n_features) and returns a numpy array of shape (n_samples, n_classes) containing joint - log-likelihoods, ``log P(x,c)`` for each sample point and class. + log-probabilities, ``log P(x,y)`` for each sample point and class. columns : str, array-like of str, int, array-like of int, \ array-like of bool, slice or callable Indexes the data on its second axis. Integers are interpreted as @@ -1654,11 +1654,7 @@ def _check_X(self, X): return X def _joint_log_likelihood(self, X): - """Calculate the meta-estimator's joint log likelihood ``log P(x,c)``.""" - # Because data must follow the same path as it would in subestimators, - # _jll_one(nb_estimator, X) passes it through nb_estimator._check_X to - # match the implementation of _BaseNB.predict_log_proba. - # Changes therein must be reflected in _jll_one or here. 
+ """Calculate the meta-estimator's joint log-probability ``log P(x,y)``.""" estimators = self._iter(fitted=True, replace_strings=True) all_jlls = Parallel(n_jobs=self.n_jobs)( delayed(_jll_one)(estimator=nb_estimator, X=_safe_indexing(X, cols, axis=1)) @@ -1669,7 +1665,7 @@ def _joint_log_likelihood(self, X): return np.sum(all_jlls, axis=0) - (n_estimators - 1) * log_prior def _validate_estimators(self, check_partial=False): - # Check if estimators have fit/partial_fit and jll methods + # Check if estimators have fit/partial_fit and joint log prob methods # Validate estimator names via _BaseComposition._validate_names(self, names) if not self.nb_estimators: raise ValueError( @@ -1679,18 +1675,20 @@ def _validate_estimators(self, check_partial=False): names, estimators, _ = zip(*self.nb_estimators) for e in estimators: if (not check_partial) and ( - not (hasattr(e, "fit") and hasattr(e, "_joint_log_likelihood")) + not (hasattr(e, "fit") and hasattr(e, "predict_joint_log_proba")) ): raise TypeError( "Estimators must be naive Bayes estimators implementing " - "`fit` and `_joint_log_likelihood` methods." + "`fit` and `predict_joint_log_proba` methods." ) if check_partial and ( - not (hasattr(e, "partial_fit") and hasattr(e, "_joint_log_likelihood")) + not ( + hasattr(e, "partial_fit") and hasattr(e, "predict_joint_log_proba") + ) ): raise TypeError( "Estimators must be Naive Bayes estimators implementing " - "`partial_fit` and `_joint_log_likelihood` methods." + "`partial_fit` and `predict_joint_log_proba` methods." 
) self._validate_names(names) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index e22136423c6fa..4d26e214c5d78 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1206,7 +1206,7 @@ def test_cwnb_estimators_nonempty_list(): def test_cwnb_estimators_support_jll(): - # Subestimators spec: error when some don't support _joint_log_likelihood + # Subestimators spec: error when some don't support predict_joint_log_proba class notNB(BaseEstimator): def __init__(self): pass @@ -1217,7 +1217,7 @@ def fit(self, X, y): def partial_fit(self, X, y): pass - # def _joint_log_likelihood(self, X): pass + # def predict_joint_log_proba(self, X): pass def predict(self, X): pass @@ -1237,7 +1237,7 @@ def __init__(self): def partial_fit(self, X, y): pass - def _joint_log_likelihood(self, X): + def predict_joint_log_proba(self, X): pass def predict(self, X): @@ -1259,7 +1259,7 @@ def fit(self, X, y): pass # def partial_fit(self, X, y): pass - def _joint_log_likelihood(self, X): + def predict_joint_log_proba(self, X): pass def predict(self, X): @@ -1462,7 +1462,7 @@ def test_cwnb_fit_sample_weight_ones(): clf1.fit(X, y, sample_weight=weights) clf2.fit(X, y) assert_array_almost_equal( - clf1._joint_log_likelihood(X), clf2._joint_log_likelihood(X), 8 + clf1.predict_joint_log_proba(X), clf2.predict_joint_log_proba(X), 8 ) assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) assert_array_equal(clf1.predict(X), clf2.predict(X)) @@ -1486,7 +1486,7 @@ def test_cwnb_partial_fit_sample_weight_ones(): clf1.partial_fit(X2, y2, sample_weight=weights, classes=np.unique(y2)) clf2.partial_fit(X2, y2, classes=np.unique(y2)) assert_array_almost_equal( - clf1._joint_log_likelihood(X2), clf2._joint_log_likelihood(X2), 8 + clf1.predict_joint_log_proba(X2), clf2.predict_joint_log_proba(X2), 8 ) assert_array_almost_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2), 8) 
assert_array_equal(clf1.predict(X2), clf2.predict(X2)) @@ -1512,7 +1512,7 @@ def test_cwnb_fit_sample_weight_repeated(): clf1.fit(X, y, sample_weight=weights) clf2.fit(X[idx], y[idx]) assert_array_almost_equal( - clf1._joint_log_likelihood(X), clf2._joint_log_likelihood(X), 8 + clf1.predict_joint_log_proba(X), clf2.predict_joint_log_proba(X), 8 ) assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) assert_array_equal(clf1.predict(X), clf2.predict(X), 8) @@ -1538,7 +1538,9 @@ def test_cwnb_partial_fit_sample_weight_repeated(): ) clf1.partial_fit(X2, y2, sample_weight=weights, classes=np.unique(y2)) clf2.partial_fit(X2[idx], y2[idx], classes=np.unique(y2)) - assert_array_equal(clf1._joint_log_likelihood(X2), clf2._joint_log_likelihood(X2)) + assert_array_equal( + clf1.predict_joint_log_proba(X2), clf2.predict_joint_log_proba(X2) + ) assert_array_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2)) assert_array_equal(clf1.predict(X2), clf2.predict(X2)) for attr_name in ("class_count_", "class_prior_", "classes_"): @@ -1557,7 +1559,7 @@ def test_cwnb_partial_fit(): clf2.partial_fit(X2[:4], y2[:4], classes=np.unique(y2)) clf2.partial_fit(X2[4:], y2[4:]) assert_array_almost_equal( - clf1._joint_log_likelihood(X2), clf2._joint_log_likelihood(X2), 8 + clf1.predict_joint_log_proba(X2), clf2.predict_joint_log_proba(X2), 8 ) assert_array_almost_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2), 8) assert_array_equal(clf1.predict(X2), clf2.predict(X2)) @@ -1642,7 +1644,7 @@ def test_cwnb_n_jobs(): clf2.partial_fit(X2, y2, classes=np.unique(y2)) assert_array_almost_equal( - clf1._joint_log_likelihood(X2), clf2._joint_log_likelihood(X2), 8 + clf1.predict_joint_log_proba(X2), clf2.predict_joint_log_proba(X2), 8 ) assert_array_almost_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2), 8) assert_array_equal(clf1.predict(X2), clf2.predict(X2)) From aae81e9d1ccaebf566a97289875f94af8aa67137 Mon Sep 17 00:00:00 2001 From: 
avm19 <52547519avm19@users.noreply.github.com> Date: Thu, 7 Jul 2022 18:59:55 -0400 Subject: [PATCH 054/102] Common parameter validation towards #23462 and custom test Custom test of the changes is implemented in test_naive_bayes.py::test_cwnb_check_param_validation --- sklearn/naive_bayes.py | 9 +++++++++ sklearn/tests/test_naive_bayes.py | 12 ++++++++++++ 2 files changed, 21 insertions(+) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 4613dcb34dc2a..9c2ae1e3d9d7b 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1667,6 +1667,13 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): _required_parameters = ["nb_estimators"] + _parameter_constraints = { + "nb_estimators": "no_validation", + "priors": ["array-like", str, None], + "n_jobs": [Integral, None], + "verbose": ["verbose"], + } + def _log_message(self, name, idx, total): if not self.verbose: return None @@ -1888,6 +1895,7 @@ def fit(self, X, y, sample_weight=None): self : object Returns the instance itself. 
""" + self._validate_params() self._check_feature_names(X, reset=True) # TODO: Consider overriding BaseEstimator._check_feature_names # Currently, when X has all str feature names, all features are @@ -1963,6 +1971,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): """ first_call = not hasattr(self, "classes_") if first_call: + self._validate_params() self._check_feature_names(X, reset=True) self._check_n_features(X, reset=True) self._validate_estimators(check_partial=True) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 355d5d1db619e..0918d3645fe1e 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -17,6 +17,7 @@ from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_allclose +from sklearn.utils.estimator_checks import check_param_validation from sklearn.naive_bayes import GaussianNB, BernoulliNB from sklearn.naive_bayes import MultinomialNB, ComplementNB @@ -1671,3 +1672,14 @@ def test_cwnb_sk_visual_block(capsys): assert visual_block.names == ("mnb1", "mnb2", "gnb1") assert visual_block.name_details == ([0, 1], [3, 4], [5]) assert visual_block.estimators == estimators + + +def test_cwnb_check_param_validation(): + # This test replaces test_common.py::test_check_param_validation and is + # needed because utils.estimator_checks._construct_instance() is unable to + # create an instance of ColumnwiseNB (also of some other estimators, such as + # ColumnTransformer and Pipeline). 
+ clf = ColumnwiseNB( + nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + ) + check_param_validation("ColumnwiseNB", clf) From 4ba7bac8c33ccb5e47045418b9690f554b6d2ce5 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Thu, 7 Jul 2022 19:43:30 -0400 Subject: [PATCH 055/102] Docs terminology log-likelihood -> log-probability --- doc/modules/naive_bayes.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/modules/naive_bayes.rst b/doc/modules/naive_bayes.rst index 3fc7e039027b7..aa19353b6f459 100644 --- a/doc/modules/naive_bayes.rst +++ b/doc/modules/naive_bayes.rst @@ -305,8 +305,8 @@ combined via \log P(x,y)=\log P(x_{1},y) + \dots + \log P(x_{M},y) - (M - 1)\log P(y), -where :math:`\log P(x,y)` is the joint log-likelihood of the meta-estimator, -:math:`\log P(x_{m},y)` is that of the :math:`m` th sub-estimator, +where :math:`\log P(x,y)` is the joint log-probability predicted by the meta-estimator, +:math:`\log P(x_{m},y)` is that by the :math:`m` th sub-estimator, :math:`\log P(y)` is the class prior used by the meta-estimator, and :math:`M\geq1` is the total number of sub-estimators. From 1224e9ac1556151bf40e1a2184075408cb086e43 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Fri, 12 Aug 2022 17:30:40 -0400 Subject: [PATCH 056/102] Empty commit to trigger pipeline From 397cc1b26a622d918436e6066fa3381250ea957f Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 17 Oct 2022 23:15:23 -0400 Subject: [PATCH 057/102] Parameter parser='auto' in fetch_openml. 
See #21938 --- examples/miscellaneous/plot_combining_naive_bayes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py index 3f0845492fb51..a2d0473fba3ef 100644 --- a/examples/miscellaneous/plot_combining_naive_bayes.py +++ b/examples/miscellaneous/plot_combining_naive_bayes.py @@ -27,7 +27,9 @@ import pandas as pd from sklearn.datasets import fetch_openml -X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True, n_retries=10) +X, y = fetch_openml( + "titanic", version=1, as_frame=True, return_X_y=True, n_retries=10, parser="auto" +) X["pclass"] = X["pclass"].astype("category") # Add a category for NaNs to the "embarked" feature: X["embarked"] = X["embarked"].cat.add_categories("N/A").fillna("N/A") From 33622aeea2242297aa2f35943675178a4392a92c Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 18 Oct 2022 16:01:51 -0400 Subject: [PATCH 058/102] Use set_config(transform_output=pandas) and string feature names. 
See SLEP018 and #23734 --- .../plot_combining_naive_bayes.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py index a2d0473fba3ef..2f38b0d2f1e13 100644 --- a/examples/miscellaneous/plot_combining_naive_bayes.py +++ b/examples/miscellaneous/plot_combining_naive_bayes.py @@ -25,8 +25,11 @@ # %% import pandas as pd +from sklearn import set_config from sklearn.datasets import fetch_openml +set_config(transform_output="pandas") + X, y = fetch_openml( "titanic", version=1, as_frame=True, return_X_y=True, n_retries=10, parser="auto" ) @@ -57,13 +60,14 @@ transformers=[ ("num", numeric_transformer, numeric_features), ("cat", categorical_transformer, categorical_features), - ] + ], + verbose_feature_names_out=False, ) classifier = ColumnwiseNB( nb_estimators=[ - ("gnb", GaussianNB(), [0, 1]), - ("cnb", CategoricalNB(), [2, 3, 4]), + ("gnb", GaussianNB(), numeric_features), + ("cnb", CategoricalNB(), categorical_features), ] ) @@ -86,9 +90,12 @@ param_grid = { "classifier__nb_estimators": [ - [("gnb", GaussianNB(), [0, 1]), ("cnb", CategoricalNB(), [2, 3, 4])], - [("gnb", GaussianNB(), []), ("cnb", CategoricalNB(), [3])], - [("gnb", GaussianNB(), [3]), ("cnb", CategoricalNB(), [])], + [ + ("gnb", GaussianNB(), ["age", "fare"]), + ("cnb", CategoricalNB(), categorical_features), + ], + [("gnb", GaussianNB(), []), ("cnb", CategoricalNB(), ["pclass"])], + [("gnb", GaussianNB(), ["embarked"]), ("cnb", CategoricalNB(), [])], ], "preprocessor__num__strategy": ["mean", "most_frequent"], } From a5ea3243fd8635debd5798bddae69fb6f742dc26 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 28 Dec 2022 15:14:41 +0100 Subject: [PATCH 059/102] DOC update changelog --- doc/whats_new/v1.2.rst | 4 ---- doc/whats_new/v1.3.rst | 6 ++++++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/doc/whats_new/v1.2.rst 
b/doc/whats_new/v1.2.rst index 45a4ba3124e0e..6442979db402e 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -621,10 +621,6 @@ Changelog :mod:`sklearn.naive_bayes` .......................... -- |Feature| Add :class:`naive_bayes.ColumnwiseNB`, a meta-estimator that allows - existing naive Bayes classifiers to be combined and applied to different columns - of `X`. :pr:`22574` by :user:`Andrey Melnik `. - - |Feature| Add methods `predict_joint_log_proba` to all naive Bayes classifiers. :pr:`23683` by :user:`Andrey Melnik `. diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index eb9f0cc473e27..cd9f2957d7fef 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -48,6 +48,12 @@ Changelog :class:`ensemble.ExtraTreesClassifier` and :class:`ensemble.ExtraTreesRegressor`. :pr:`25177` by :user:`Tim Head `. +:mod:`sklearn.naive_bayes` +.......................... +- |Feature| Add :class:`naive_bayes.ColumnwiseNB`, a meta-estimator that allows + existing naive Bayes classifiers to be combined and applied to different columns + of `X`. :pr:`22574` by :user:`Andrey Melnik `. + :mod:`sklearn.pipeline` ....................... - |Feature| :class:`pipeline.FeatureUnion` can now use indexing notation (e.g. 
From 3743b17b39256048ba8ad465fc88c93f5cc6af26 Mon Sep 17 00:00:00 2001 From: avm19 Date: Wed, 28 Dec 2022 23:18:02 -0500 Subject: [PATCH 060/102] Minor format suggestions from glemaitre's review Co-authored-by: Guillaume Lemaitre --- sklearn/naive_bayes.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index c1734d8810cce..bca27f69344b0 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -34,9 +34,7 @@ from .utils.validation import _check_sample_weight from .utils.validation import column_or_1d from .utils.metaestimators import _BaseComposition -from .utils import _safe_indexing, _get_column_indices -from .utils import _print_elapsed_time -from .utils import Bunch +from .utils import _safe_indexing, _get_column_indices, _print_elapsed_time, Bunch from .utils.fixes import delayed from .utils._estimator_html_repr import _VisualBlock from .compose._column_transformer import _is_empty_column_selection @@ -1570,8 +1568,7 @@ def _jll_one(estimator, X): class ColumnwiseNB(_BaseNB, _BaseComposition): - """ - Column-wise Naive Bayes meta-estimator. + """Column-wise Naive Bayes meta-estimator. This estimator combines various naive Bayes estimators by applying them to different column subsets of the input and joining their predictions @@ -1598,7 +1595,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): takes :term:`X` of shape (n_samples, n_features) and returns a numpy array of shape (n_samples, n_classes) containing joint log-probabilities, ``log P(x,y)`` for each sample point and class. - columns : str, array-like of str, int, array-like of int, \ + columns : str, array-like of str, int, array-like of int, \ array-like of bool, slice or callable Indexes the data on its second axis. 
Integers are interpreted as positional columns, while strings can reference DataFrame columns @@ -1629,12 +1626,12 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Attributes ---------- estimators_ : list of tuples - List of ``(name, fitted_nb_estimator, columns)`` tuples, which follow - the order of `nb_estimators`. Here, ``fitted_nb_estimator`` is a fitted naive - Bayes estimator, except when ``columns`` presents an empty selection of - columns, in which case it is the original unfitted ``nb_estimator``. If - the original specification of ``columns`` in ``nb_estimators`` was a - callable, then ``columns`` is converted to a list of column indices. + List of `(name, fitted_estimator, columns)` tuples, which follow + the order of `estimators`. Here, `fitted_estimator` is a fitted naive + Bayes estimator, except when `columns` presents an empty selection of + columns, in which case it is the original unfitted `nb_estimator`. If + the original specification of `columns` in `estimators` was a + callable, then `columns` is converted to a list of column indices. named_estimators_ : :class:`~sklearn.utils.Bunch` Read-only attribute to access any subestimator by given name. From 75a00f49a2561e587756ff56cf50231a8b76e01c Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Thu, 29 Dec 2022 12:34:38 -0500 Subject: [PATCH 061/102] Docstring: versionadded note and a reference to the User Guide --- sklearn/naive_bayes.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index bca27f69344b0..93f5ba67948c7 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1575,6 +1575,10 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): according to the naive Bayes assumption. This is useful when features are heterogeneous and follow different kinds of distributions. + Read more in the :ref:`User Guide `. + + .. 
versionadded:: 1.3 + Parameters ---------- nb_estimators : list of tuples From e70e6c8ffb0c1da5cf133bc609692fd75ae093d2 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Thu, 29 Dec 2022 17:54:49 -0500 Subject: [PATCH 062/102] Docstring for n_features_in_ --- sklearn/naive_bayes.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 93f5ba67948c7..da95d6e4a75b7 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1659,6 +1659,9 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): classes_ : ndarray of shape (n_classes,) Class labels known to the classifier. + n_features_in_ : int + Number of features seen during :term:`fit`. + feature_names_in_ : ndarray of shape (`n_features_in_`,) Names of features seen during :term:`fit`. Only defined if `X` has feature names that are all strings. From f19914d23fa98f01fe9b5088556cd46f3af2b2f0 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 31 Dec 2022 20:24:14 -0500 Subject: [PATCH 063/102] Remove n_classes_ attribute --- sklearn/naive_bayes.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 91933cb1adc25..2780f48bfde72 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1655,9 +1655,6 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Number of samples encountered for each class during fitting. This value is weighted by the sample weight when provided. - n_classes_ : int - The number of classes known to the naive Bayes classifier, `n_classes`. - classes_ : ndarray of shape (n_classes,) Class labels known to the classifier. @@ -1895,7 +1892,7 @@ def _update_class_prior(self): else: # check the provided prior priors = np.asarray(self.priors) # Check the prior in any case. 
- if len(priors) != self.n_classes_: + if len(priors) != len(self.classes_): raise ValueError("Number of priors must match number of classes.") if not np.isclose(priors.sum(), 1.0): raise ValueError("The sum of the priors should be 1.") @@ -1969,7 +1966,6 @@ def fit(self, X, y, sample_weight=None): for i, c in enumerate(self.classes_): counts[i] = (weights * (column_or_1d(y) == c)).sum() self.class_count_ = counts - self.n_classes_ = len(self.classes_) estimators = list(self._iter(fitted=False, replace_strings=True)) fitted_estimators = Parallel(n_jobs=self.n_jobs)( @@ -2049,7 +2045,6 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): counts[i] = (weights * (column_or_1d(y) == c)).sum() if first_call: - self.n_classes_ = len(self.classes_) self.class_count_ = counts else: self.class_count_ += counts From 23469d798bd06705f498fd764cdd6451b64455b3 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 31 Dec 2022 20:42:22 -0500 Subject: [PATCH 064/102] named_estimators_ is now not a property, but a field --- sklearn/naive_bayes.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 2780f48bfde72..d55b8e08c4f5d 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1802,18 +1802,6 @@ def _validate_column_callables(self, X): self._columns = all_columns self._estimator_to_input_indices = estimator_to_input_indices - @property - def named_estimators_(self): - """Access the fitted naive Bayes subestimators by name. - - Read-only attribute to access any estimators by given name. - Keys are estimators names and values are the fitted estimator - objects. 
- """ - # Almost a verbatim copy of ColumnTransformer.named_transformers_ - # Use Bunch object to improve autocomplete - return Bunch(**{name: e for name, e, _ in self.estimators_}) - def _iter(self, *, fitted=False, replace_strings=False): """Generate ``(name, nb_estimator, columns)`` tuples. @@ -1917,6 +1905,7 @@ def _update_fitted_estimators(self, fitted_estimators): updated_nb_estimator = nb_estimator estimators_.append((name, updated_nb_estimator, cols)) self.estimators_ = estimators_ + self.named_estimators_ = Bunch(**{name: e for name, e, _ in estimators_}) def fit(self, X, y, sample_weight=None): """Fit the naive Bayes meta-estimator. From 9683a2b711d3d22f114b0440395586e4d42230bb Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sat, 31 Dec 2022 21:54:25 -0500 Subject: [PATCH 065/102] Minor formatting: f-string in place of old style --- sklearn/naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index d55b8e08c4f5d..501318d04f801 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1722,7 +1722,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): def _log_message(self, name, idx, total): if not self.verbose: return None - return "(%d of %d) Processing %s" % (idx, total, name) + return f"({idx} of {total}) Processing {name}" def __init__(self, nb_estimators, *, priors=None, n_jobs=None, verbose=False): self.nb_estimators = nb_estimators From 03ef84d2382ed1f50567629b5808819e41267762 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 2 Jan 2023 01:58:44 -0500 Subject: [PATCH 066/102] Factor out _fit_partial from fit and fit_partial --- sklearn/naive_bayes.py | 156 +++++++++++++++++------------------------ 1 file changed, 64 insertions(+), 92 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 501318d04f801..7d0105801d911 100644 --- a/sklearn/naive_bayes.py +++ 
b/sklearn/naive_bayes.py @@ -1548,6 +1548,9 @@ def _fit_one(estimator, X, y, message_clsname="", message=None, **fit_params): See :func:`sklearn.pipeline._fit_one`. """ + # The dummy parameter is needed in _fit_partial to factorise fit/fit_partial + if fit_params["classes"] is None: + fit_params.pop("classes") with _print_elapsed_time(message_clsname, message): return estimator.fit(X, y, **fit_params) @@ -1907,6 +1910,63 @@ def _update_fitted_estimators(self, fitted_estimators): self.estimators_ = estimators_ self.named_estimators_ = Bunch(**{name: e for name, e, _ in estimators_}) + def _partial_fit(self, X, y, partial=False, classes=None, sample_weight=None): + """ + partial : bool, default=False + True for partial_fit, False for fit. + """ + first_call = not hasattr(self, "classes_") + if first_call: # in fit() or the first call of partial_fit() + self._validate_params() + self._check_feature_names(X, reset=True) + self._check_n_features(X, reset=True) + self._validate_estimators(check_partial=partial) + self._validate_column_callables(X) + else: + self._check_feature_names(X, reset=False) + self._check_n_features(X, reset=False) + + y_ = column_or_1d(y) + + if sample_weight is not None: + weights = _check_sample_weight(sample_weight, X=y_, copy=True) + + if not partial: + self.classes_, counts = np.unique(y_, return_counts=True) + else: + _check_partial_fit_first_call(self, classes) + + if sample_weight is not None: + counts = np.zeros(len(self.classes_)) + for i, c in enumerate(self.classes_): + counts[i] = (weights * (y_ == c)).sum() + elif partial: + counts = np.zeros(len(self.classes_)) + for i, c in enumerate(self.classes_): + counts[i] = (y_ == c).sum() + + if not first_call: + self.class_count_ += counts + else: + self.class_count_ = counts + + estimators = list(self._iter(fitted=not first_call, replace_strings=True)) + fitted_estimators = Parallel(n_jobs=self.n_jobs)( + delayed(_partial_fit_one if partial else _fit_one)( + estimator=clone(nb_estimator) 
if first_call else nb_estimator, + X=_safe_indexing(X, cols, axis=1), + y=y, + message_clsname="ColumnwiseNB", + message=self._log_message(name, idx, len(estimators)), + classes=classes, + sample_weight=sample_weight, + ) + for idx, (name, nb_estimator, cols) in enumerate(estimators, 1) + ) + self._update_fitted_estimators(fitted_estimators) + self._update_class_prior() + return self + def fit(self, X, y, sample_weight=None): """Fit the naive Bayes meta-estimator. @@ -1929,48 +1989,9 @@ def fit(self, X, y, sample_weight=None): self : object Returns the instance itself. """ - self._validate_params() - self._check_feature_names(X, reset=True) - # TODO: Consider overriding BaseEstimator._check_feature_names - # Currently, when X has all str feature names, all features are - # registered in self.feature_names_in no matter if they are used or not. - self._check_n_features(X, reset=True) - self._validate_estimators() - self._validate_column_callables(X) - # Consistency checks for X, y are delegated to subestimators - - # Subestimators get original sample_weight. This is for class counts: - if sample_weight is not None: - weights = _check_sample_weight(sample_weight, X=y, copy=True) - - # We would use sklearn.utils.multiclass.class_distribution, but it does - # not return class_count, which we want as well. 
- if sample_weight is None: - self.classes_, self.class_count_ = np.unique( - column_or_1d(y), return_counts=True - ) - else: - self.classes_ = np.unique(column_or_1d(y)) - counts = np.zeros(len(self.classes_)) - for i, c in enumerate(self.classes_): - counts[i] = (weights * (column_or_1d(y) == c)).sum() - self.class_count_ = counts - - estimators = list(self._iter(fitted=False, replace_strings=True)) - fitted_estimators = Parallel(n_jobs=self.n_jobs)( - delayed(_fit_one)( - estimator=clone(nb_estimator), - X=_safe_indexing(X, cols, axis=1), - y=y, - message_clsname="ColumnwiseNB", - message=self._log_message(name, idx, len(estimators)), - sample_weight=sample_weight, - ) - for idx, (name, nb_estimator, cols) in enumerate(estimators, 1) + return self._partial_fit( + X, y, partial=False, classes=None, sample_weight=sample_weight ) - self._update_fitted_estimators(fitted_estimators) - self._update_class_prior() - return self def partial_fit(self, X, y, classes=None, sample_weight=None): """Fit incrementally the naive Bayes meta-estimator on a batch of samples. @@ -2002,58 +2023,9 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): self : object Returns the instance itself. """ - first_call = not hasattr(self, "classes_") - if first_call: - self._validate_params() - self._check_feature_names(X, reset=True) - self._check_n_features(X, reset=True) - self._validate_estimators(check_partial=True) - self._validate_column_callables(X) - else: - self._check_feature_names(X, reset=False) - self._check_n_features(X, reset=False) - # Consistency checks for X, y are delegated to subestimators - - # Subestimators get original sample_weight. This is for class counts: - if sample_weight is not None: - weights = _check_sample_weight(sample_weight, X=y, copy=True) - - # Subestimators should've checked classes. We set classes_ for counts - # and so that first_call becomes False at next partial_fit call. 
- _check_partial_fit_first_call(self, classes) - - # We don't use sklearn.utils.multiclass.class_distribution, because it - # neither returns class_count, nor is suitable for partial_fit. - if sample_weight is None: - counts = np.zeros(len(self.classes_)) - for i, c in enumerate(self.classes_): - counts[i] = (column_or_1d(y) == c).sum() - else: - counts = np.zeros(len(self.classes_)) - for i, c in enumerate(self.classes_): - counts[i] = (weights * (column_or_1d(y) == c)).sum() - - if first_call: - self.class_count_ = counts - else: - self.class_count_ += counts - - estimators = list(self._iter(fitted=not first_call, replace_strings=True)) - fitted_estimators = Parallel(n_jobs=self.n_jobs)( - delayed(_partial_fit_one)( - estimator=clone(nb_estimator) if first_call else nb_estimator, - X=_safe_indexing(X, cols, axis=1), - y=y, - message_clsname="ColumnwiseNB", - message=self._log_message(name, idx, len(estimators)), - classes=classes, - sample_weight=sample_weight, - ) - for idx, (name, nb_estimator, cols) in enumerate(estimators, 1) + return self._partial_fit( + X, y, partial=True, classes=classes, sample_weight=sample_weight ) - self._update_fitted_estimators(fitted_estimators) - self._update_class_prior() - return self @property def _estimators(self): From 1198f25a7060d7e2e44570eb52285375f3d3b029 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 2 Jan 2023 12:04:44 -0500 Subject: [PATCH 067/102] Ensure ColumnwiseNB.class_count_ is float64 --- sklearn/naive_bayes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 7d0105801d911..3d1d16f6aa616 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1937,18 +1937,18 @@ def _partial_fit(self, X, y, partial=False, classes=None, sample_weight=None): _check_partial_fit_first_call(self, classes) if sample_weight is not None: - counts = np.zeros(len(self.classes_)) + counts = 
np.zeros(len(self.classes_), dtype=np.float64) for i, c in enumerate(self.classes_): counts[i] = (weights * (y_ == c)).sum() elif partial: - counts = np.zeros(len(self.classes_)) + counts = np.zeros(len(self.classes_), dtype=np.float64) for i, c in enumerate(self.classes_): counts[i] = (y_ == c).sum() if not first_call: self.class_count_ += counts else: - self.class_count_ = counts + self.class_count_ = counts.astype(np.float64, copy=False) estimators = list(self._iter(fitted=not first_call, replace_strings=True)) fitted_estimators = Parallel(n_jobs=self.n_jobs)( From 21ae456d53c2dddc5c75a3820f85548111837da5 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 3 Jan 2023 15:15:35 -0500 Subject: [PATCH 068/102] Docstring: clarify callable columns are evaluated only once --- sklearn/naive_bayes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 3d1d16f6aa616..1a9ceafe5b9bd 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1613,7 +1613,8 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): otherwise a 2d array will be passed to the transformer. A callable is passed the input data `X` and can return any of the above. To select multiple columns by name or dtype, you can use - :obj:`~sklearn.compose.make_column_selector`. + :obj:`~sklearn.compose.make_column_selector`. The callable is evaluated + on the first batch, but not on subsequent calls of `partial_fit`. priors : array-like of shape (n_classes,) or str, default=None Prior probabilities of classes. 
If unspecified, the priors are From 51b3a39e3a22ff6c8e4aa64283896461368a5418 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 3 Jan 2023 21:41:41 -0500 Subject: [PATCH 069/102] Decorate partial_fit with available_if --- sklearn/naive_bayes.py | 24 +++++++++++++++++++++++- sklearn/tests/test_naive_bayes.py | 4 ++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 1a9ceafe5b9bd..5a5667852728b 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -33,7 +33,7 @@ from .utils.validation import check_is_fitted, check_non_negative from .utils.validation import _check_sample_weight from .utils.validation import column_or_1d -from .utils.metaestimators import _BaseComposition +from .utils.metaestimators import _BaseComposition, available_if from .utils import _safe_indexing, _get_column_indices, _print_elapsed_time, Bunch from .utils.fixes import delayed from .utils._estimator_html_repr import _VisualBlock @@ -1543,6 +1543,27 @@ def _joint_log_likelihood(self, X): return total_ll +def _nb_estimators_have(attr): + """Check if all self.nb_estimators or self.nb_estimators_ have attr. + + Used together with `available_if` in `ColumnwiseNB`.""" + + # This function is used with `_available_if` before validation. + # The try statement suppresses errors caused by incorrect specification of + # self.nb_estimators. Informative errors are raised at validation elsewhere. + def chk(obj): + try: + if hasattr(obj, "nb_estimators_"): + out = all(hasattr(triplet[1], attr) for triplet in obj.nb_estimators_) + else: + out = all(hasattr(triplet[1], attr) for triplet in obj.nb_estimators) + except (TypeError, IndexError, AttributeError): + return False + return out + + return chk + + def _fit_one(estimator, X, y, message_clsname="", message=None, **fit_params): """Call ``estimator.fit`` and print elapsed time message. 
@@ -1994,6 +2015,7 @@ def fit(self, X, y, sample_weight=None): X, y, partial=False, classes=None, sample_weight=sample_weight ) + @available_if(_nb_estimators_have("partial_fit")) def partial_fit(self, X, y, classes=None, sample_weight=None): """Fit incrementally the naive Bayes meta-estimator on a batch of samples. diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index e3485aea925e4..9b68c330259f4 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1283,8 +1283,8 @@ def predict(self, X): pass clf1 = ColumnwiseNB(nb_estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) - msg = "Estimators must be .aive Bayes estimators implementing *" - with pytest.raises(TypeError, match=msg): + msg = "This 'ColumnwiseNB' has no attribute 'partial_fit'*" + with pytest.raises(AttributeError, match=msg): clf1.partial_fit(X, y) From 0eb36a93565daeba5b7fb97e771d3a4497b7ef0c Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 3 Jan 2023 21:50:51 -0500 Subject: [PATCH 070/102] Improve _validate_estimators for when non-tuples are passed --- sklearn/naive_bayes.py | 7 ++++--- sklearn/tests/test_naive_bayes.py | 7 +++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 5a5667852728b..0d685a7e07ae5 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1780,12 +1780,13 @@ def _joint_log_likelihood(self, X): def _validate_estimators(self, check_partial=False): # Check if estimators have fit/partial_fit and joint log prob methods # Validate estimator names via _BaseComposition._validate_names(self, names) - if not self.nb_estimators: + try: + names, estimators, _ = zip(*self.nb_estimators) + except (TypeError, AttributeError, ValueError) as exc: raise ValueError( "A list of naive Bayes estimators must be provided " "in the form [(name, nb_estimator, columns), ... ]." 
- ) - names, estimators, _ = zip(*self.nb_estimators) + ) from exc for e in estimators: if (not check_partial) and ( not (hasattr(e, "fit") and hasattr(e, "predict_joint_log_proba")) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 9b68c330259f4..766e6671e0698 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1221,6 +1221,13 @@ def test_cwnb_estimators_nonempty_list(): with pytest.raises(ValueError, match=msg): clf.fit(X1, y1) + # Subestimators spec: error on non-tuple + clf = ColumnwiseNB( + nb_estimators=GaussianNB(), + ) + msg = "A list of naive Bayes estimators must be provided*" + with pytest.raises(ValueError, match=msg): + clf.fit(X1, y1) def test_cwnb_estimators_support_jll(): # Subestimators spec: error when some don't support predict_joint_log_proba From 8d27d6e83f98ae3df1a1353998552b703d4f53d1 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 3 Jan 2023 22:15:27 -0500 Subject: [PATCH 071/102] Improve _validate_estimators and test --- sklearn/naive_bayes.py | 6 +----- sklearn/tests/test_naive_bayes.py | 7 +++++++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 0d685a7e07ae5..a25ad2816faab 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1795,11 +1795,7 @@ def _validate_estimators(self, check_partial=False): "Estimators must be naive Bayes estimators implementing " "`fit` and `predict_joint_log_proba` methods." ) - if check_partial and ( - not ( - hasattr(e, "partial_fit") and hasattr(e, "predict_joint_log_proba") - ) - ): + if check_partial and not hasattr(e, "predict_joint_log_proba"): raise TypeError( "Estimators must be Naive Bayes estimators implementing " "`partial_fit` and `predict_joint_log_proba` methods." 
diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 766e6671e0698..1ae0d0089ee16 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1229,6 +1229,7 @@ def test_cwnb_estimators_nonempty_list(): with pytest.raises(ValueError, match=msg): clf.fit(X1, y1) + def test_cwnb_estimators_support_jll(): # Subestimators spec: error when some don't support predict_joint_log_proba class notNB(BaseEstimator): @@ -1272,6 +1273,12 @@ def predict(self, X): with pytest.raises(TypeError, match=msg): clf1.fit(X, y) + delattr(notNB, 'predict_joint_log_proba') + clf1 = ColumnwiseNB(nb_estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + msg = "Estimators must be .aive Bayes estimators implementing *" + with pytest.raises(TypeError, match=msg): + clf1.fit(X, y) + def test_cwnb_estimators_support_partial_fit(): # Subestimators spec: error when some don't support partial_fit From 298f3aef6d9920a609f04ff0d1f0a7e30f418a70 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Wed, 4 Jan 2023 00:00:13 -0500 Subject: [PATCH 072/102] black formatting --- sklearn/tests/test_naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 1ae0d0089ee16..a9a75c6f92dd1 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1273,7 +1273,7 @@ def predict(self, X): with pytest.raises(TypeError, match=msg): clf1.fit(X, y) - delattr(notNB, 'predict_joint_log_proba') + delattr(notNB, "predict_joint_log_proba") clf1 = ColumnwiseNB(nb_estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "Estimators must be .aive Bayes estimators implementing *" with pytest.raises(TypeError, match=msg): From 09eaf7a05598748165554ae85d4c1aca11efbfc2 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Wed, 4 Jan 2023 08:45:30 -0500 Subject: 
[PATCH 073/102] Use .utils._encode._unique instead of np.unique --- sklearn/naive_bayes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index a25ad2816faab..9ec842f69ea48 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -36,6 +36,7 @@ from .utils.metaestimators import _BaseComposition, available_if from .utils import _safe_indexing, _get_column_indices, _print_elapsed_time, Bunch from .utils.fixes import delayed +from .utils._encode import _unique from .utils._estimator_html_repr import _VisualBlock from .compose._column_transformer import _is_empty_column_selection @@ -1951,7 +1952,7 @@ def _partial_fit(self, X, y, partial=False, classes=None, sample_weight=None): weights = _check_sample_weight(sample_weight, X=y_, copy=True) if not partial: - self.classes_, counts = np.unique(y_, return_counts=True) + self.classes_, counts = _unique(y_, return_counts=True) else: _check_partial_fit_first_call(self, classes) From b38206b62f3d04001e3eb5b4ccf2b41d95f45da7 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Wed, 4 Jan 2023 15:44:20 -0500 Subject: [PATCH 074/102] Docstring: replace double backticks with single backticks --- sklearn/naive_bayes.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 9ec842f69ea48..5b081e1b470ca 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1622,16 +1622,16 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): nb_estimator : estimator The estimator must support :term:`fit` or :term:`partial_fit`, depending on how the meta-estimator is fitted. 
In addition, the - estimator must support ``predict_joint_log_proba`` method, which + estimator must support `predict_joint_log_proba` method, which takes :term:`X` of shape (n_samples, n_features) and returns a numpy array of shape (n_samples, n_classes) containing joint - log-probabilities, ``log P(x,y)`` for each sample point and class. + log-probabilities, `log P(x,y)` for each sample point and class. columns : str, array-like of str, int, array-like of int, \ array-like of bool, slice or callable Indexes the data on its second axis. Integers are interpreted as positional columns, while strings can reference DataFrame columns by name. A scalar string or int should be used where - ``nb_estimator`` expects X to be a 1d array-like (vector), + `nb_estimator` expects X to be a 1d array-like (vector), otherwise a 2d array will be passed to the transformer. A callable is passed the input data `X` and can return any of the above. To select multiple columns by name or dtype, you can use @@ -1647,8 +1647,8 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): n_jobs : int, default=None Number of jobs to run in parallel. - ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. - ``-1`` means using all processors. See :term:`Glossary ` + `None` means 1 unless in a :obj:`joblib.parallel_backend` context. + `-1` means using all processors. See :term:`Glossary ` for more details. verbose : bool, default=False @@ -1668,7 +1668,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): named_estimators_ : :class:`~sklearn.utils.Bunch` Read-only attribute to access any subestimator by given name. Keys are estimator names and values are the fitted estimators, except - when a subestimator does not require fitting (i.e., when ``columns`` is + when a subestimator does not require fitting (i.e., when `columns` is an empty set of indices). 
class_prior_ : ndarray of shape (n_classes,) @@ -1703,19 +1703,19 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Notes ----- ColumnwiseNB combines multiple naive Bayes estimators by expressing the - overall joint probability ``P(x,y)`` through ``P(x_i,y)``, the joint + overall joint probability `P(x,y)` through `P(x_i,y)`, the joint probabilities of the subestimators:: - ``Log P(x,y) = Log P(x_1,y) + ... + Log P(x_N,y) - (N - 1) Log P(y)``, + Log P(x,y) = Log P(x_1,y) + ... + Log P(x_N,y) - (N - 1) Log P(y), - where ``N`` denotes ``n_estimators``, the number of estimators. + where `N` denotes `n_estimators`, the number of estimators. It is implicitly assumed that the class log priors are finite and agree between the estimators and the subestimator:: - ``- inf < Log P(y) = Log P(y|1) = ... = Log P(y|N)``. + - inf < Log P(y) = Log P(y|1) = ... = Log P(y|N). The meta-estimators does not check if this condition holds. Meaningless - results, including ``NaN``, may be produced by ColumnwiseNB if the class + results, including `NaN`, may be produced by ColumnwiseNB if the class priors differ or contain a zero probability. Examples @@ -1768,7 +1768,7 @@ def _check_X(self, X): return X def _joint_log_likelihood(self, X): - """Calculate the meta-estimator's joint log-probability ``log P(x,y)``.""" + """Calculate the meta-estimator's joint log-probability `log P(x,y)`.""" estimators = self._iter(fitted=True, replace_strings=True) all_jlls = Parallel(n_jobs=self.n_jobs)( delayed(_jll_one)(estimator=nb_estimator, X=_safe_indexing(X, cols, axis=1)) @@ -1826,7 +1826,7 @@ def _validate_column_callables(self, X): self._estimator_to_input_indices = estimator_to_input_indices def _iter(self, *, fitted=False, replace_strings=False): - """Generate ``(name, nb_estimator, columns)`` tuples. + """Generate `(name, nb_estimator, columns)` tuples. This is a private method, similar to ColumnTransformer._iter. Must not be called before _validate_column_callables. 
@@ -1848,7 +1848,7 @@ def _iter(self, *, fitted=False, replace_strings=False): Yields ------ tuple - of the form ``(name, nb_estimator, columns)``. + of the form `(name, nb_estimator, columns)`. Notes ----- @@ -1990,7 +1990,7 @@ def _partial_fit(self, X, y, partial=False, classes=None, sample_weight=None): def fit(self, X, y, sample_weight=None): """Fit the naive Bayes meta-estimator. - Calls `fit` of each subestimator ``nb_estimator``. Only a corresponding + Calls `fit` of each subestimator `nb_estimator`. Only a corresponding subset of columns of `X` is passed to each subestimator; `sample_weight` and `y` are passed to the subestimators as they are. @@ -2097,7 +2097,7 @@ def get_params(self, deep=True): def set_params(self, **kwargs): """Set the parameters of this estimator. - Valid parameter keys can be listed with ``get_params()``. Note that you + Valid parameter keys can be listed with `get_params()`. Note that you can directly set the parameters of the estimators contained in `estimators` of `ColumnwiseNB`. From 844afec7afba0a60c23a4d699bc2b9507dd47e7f Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Thu, 5 Jan 2023 00:03:11 -0500 Subject: [PATCH 075/102] Remove and/or correct comments --- sklearn/naive_bayes.py | 36 ++++++++++-------------------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 5b081e1b470ca..2c09277c0d4e6 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1613,19 +1613,15 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): estimators to be combined into a single naive Bayes meta-estimator. name : str - Name of the naive Bayes estimator. Like in - :class:`~sklearn.pipeline.Pipeline`, - :class:`~sklearn.pipeline.FeatureUnion`, - and :class:`~sklearn.compose.ColumnTransformer`, this allows the - subestimator and its parameters to be set using :term:`set_params` - and searched in grid search. 
+ Name of the naive Bayes estimator, by which the subestimator and + its parameters can be set using :term:`set_params` and searched in + grid search. nb_estimator : estimator The estimator must support :term:`fit` or :term:`partial_fit`, depending on how the meta-estimator is fitted. In addition, the estimator must support `predict_joint_log_proba` method, which - takes :term:`X` of shape (n_samples, n_features) and returns a - numpy array of shape (n_samples, n_classes) containing joint - log-probabilities, `log P(x,y)` for each sample point and class. + returns a numpy array of shape (n_samples, n_classes) containing + joint log-probabilities, `log P(x,y)` for each sample point and class. columns : str, array-like of str, int, array-like of int, \ array-like of bool, slice or callable Indexes the data on its second axis. Integers are interpreted as @@ -1643,10 +1639,11 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): calculated as relative frequencies of classes in the training data. If str, the priors are taken from the estimator with the given name. If array-like, the same priors might have to be specified manually in - each sub-estimator, in order to ensure consistent predictions. + each subestimator, in order to ensure consistent predictions. n_jobs : int, default=None - Number of jobs to run in parallel. + Number of jobs to run in parallel. Appropriate fit or predict methods + of subestimators are invoked in parallel. `None` means 1 unless in a :obj:`joblib.parallel_backend` context. `-1` means using all processors. See :term:`Glossary ` for more details. 
@@ -1779,8 +1776,6 @@ def _joint_log_likelihood(self, X): return np.sum(all_jlls, axis=0) - (n_estimators - 1) * log_prior def _validate_estimators(self, check_partial=False): - # Check if estimators have fit/partial_fit and joint log prob methods - # Validate estimator names via _BaseComposition._validate_names(self, names) try: names, estimators, _ = zip(*self.nb_estimators) except (TypeError, AttributeError, ValueError) as exc: @@ -1809,12 +1804,6 @@ def _validate_column_callables(self, X): Empty-set columns do not enjoy any special treatment. """ - # Almost a verbatim copy of ColumnTransformer._validate_column_callables(). - # Consider refactoring in the future. - # Unlike ColumnTransformer, this estimator does not need to output a - # dataframe or validate a the remainder, so _estimator_to_input_indices - # is not really needed, but retained for consistency with - # ColumnTransformer code. all_columns = [] estimator_to_input_indices = {} for name, _, columns in self.nb_estimators: @@ -2055,7 +2044,6 @@ def _estimators(self): This is for the implementation of get_params via BaseComposition._get_params, which expects lists of tuples of len 2. """ - # Implemented in the image and likeness of ColumnTranformer._transformers try: return [(name, e) for name, e, _ in self.nb_estimators] except (TypeError, ValueError): @@ -2066,8 +2054,6 @@ def _estimators(self): @_estimators.setter def _estimators(self, value): - # Implemented in the image and likeness of ColumnTranformer._transformers - # TODO: Is renaming or changing the order legal? Swap `name` and `_`? self.nb_estimators = [ (name, e, col) for ((name, e), (_, _, col)) in zip(value, self.nb_estimators) @@ -2077,7 +2063,7 @@ def get_params(self, deep=True): """Get parameters for this estimator. Returns the parameters listed in the constructor as well as the - subestimators contained within the `estimators` of the `ColumnwiseNB` + subestimators contained within the `nb_estimators` of the `ColumnwiseNB` instance. 
Parameters @@ -2091,7 +2077,6 @@ def get_params(self, deep=True): params : dict Parameter names mapped to their values. """ - # Implemented in the image and likeness of ColumnTranformer.get_params return self._get_params("_estimators", deep=deep) def set_params(self, **kwargs): @@ -2099,7 +2084,7 @@ def set_params(self, **kwargs): Valid parameter keys can be listed with `get_params()`. Note that you can directly set the parameters of the estimators contained in - `estimators` of `ColumnwiseNB`. + `nb_estimators` of `ColumnwiseNB`. Parameters ---------- @@ -2111,7 +2096,6 @@ def set_params(self, **kwargs): self : ColumnwiseNB This estimator. """ - # Implemented in the image and likeness of ColumnTranformer.set_params self._set_params("_estimators", **kwargs) return self From 75b2cfc87f0619b600b5a4435f1e1b1d6cc12e27 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Fri, 6 Jan 2023 15:29:55 -0500 Subject: [PATCH 076/102] Correct mistake that fit() does not fit from scratch. Test --- sklearn/naive_bayes.py | 2 ++ sklearn/tests/test_naive_bayes.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 2c09277c0d4e6..9b13df190c9e4 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1998,6 +1998,8 @@ def fit(self, X, y, sample_weight=None): self : object Returns the instance itself. 
""" + if hasattr(self, "classes_"): + delattr(self, "classes_") return self._partial_fit( X, y, partial=False, classes=None, sample_weight=sample_weight ) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index a9a75c6f92dd1..059bb7cc71f7b 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1598,6 +1598,26 @@ def test_cwnb_partial_fit(): assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) +def test_cwnb_fit_refits(): + # fit: re-fits the estimator de novo when called on a fitted estimator + clf1 = ColumnwiseNB( + nb_estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + ) + clf2 = ColumnwiseNB( + nb_estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + ) + clf1.fit(X2, y2) + clf2.partial_fit(X2[:4], y2[:4], classes=np.unique(y2)) + clf2.fit(X2, y2) + assert_array_almost_equal( + clf1.predict_joint_log_proba(X2), clf2.predict_joint_log_proba(X2), 8 + ) + assert_array_almost_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2), 8) + assert_array_equal(clf1.predict(X2), clf2.predict(X2)) + for attr_name in ("class_count_", "class_prior_", "classes_"): + assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) + + def test_cwnb_partial_fit_classes(): # partial_fit: error when classes are not provided at the first call clf1 = ColumnwiseNB( From a3a065375e18ccdb7c2f80e9c13c27a6dee57c59 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Fri, 6 Jan 2023 18:51:19 -0500 Subject: [PATCH 077/102] Common tests and estimator checks for ColumnwiseNB (more) 1. estimator_checks.py can now build an instance of ColumnwiseNB 2. _select_half_first|second in naive_bayes.py are used as column selectors by the aforementioned instance. 3. X is converted to numpy ndarray, unless X is pandas DataFrame. 
This is to permit column indexing by str in pandas DataFrame, while passing tests when X is list of lists etc. Previously, I avoided any conversion of X. 4. The meta-estimator now does _check_n_features on predict. Previously, it was allowed to predict on more features than seen while fitting, because not all features could be selected. --- sklearn/naive_bayes.py | 33 ++++++++++++++++++++++++------- sklearn/tests/test_naive_bayes.py | 12 ++++++++--- sklearn/utils/estimator_checks.py | 9 +++++++++ 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 9b13df190c9e4..886e9c8fd1699 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -32,7 +32,7 @@ from .utils.multiclass import _check_partial_fit_first_call from .utils.validation import check_is_fitted, check_non_negative from .utils.validation import _check_sample_weight -from .utils.validation import column_or_1d +from .utils.validation import column_or_1d, check_array from .utils.metaestimators import _BaseComposition, available_if from .utils import _safe_indexing, _get_column_indices, _print_elapsed_time, Bunch from .utils.fixes import delayed @@ -1544,6 +1544,16 @@ def _joint_log_likelihood(self, X): return total_ll +def _select_half_first(X): + """Column selector that selects the first half of columns""" + return list(range((X.shape[1] + 1) // 2)) + + +def _select_half_second(X): + """Column selector that selects the second half of columns""" + return list(range((X.shape[1] + 1) // 2, X.shape[1])) + + def _nb_estimators_have(attr): """Check if all self.nb_estimators or self.nb_estimators_ have attr. @@ -1755,15 +1765,23 @@ def __init__(self, nb_estimators, *, priors=None, n_jobs=None, verbose=False): def _check_X(self, X): """Validate X, used only in predict* methods.""" - # The meta-estimator checks for feature names only. Other checks - # and conversion to numpy array are performed by subestimators. 
- # It is important that X is not modified by the meta-estimator, and X's - # columns are passed to an estimator as they are. Note that estimators - # may modify (a copy of) X. E.g., BernoulliNB._check_X binarises the - # input. + # Defer conversion and validation of a pandas DataFrame to subestimators, + # in order to allow column indexing by str or int (if DataFrame). + # Convert other kinds here to allow column indexing by int (otherwise). + # Note that subestimators may modify (a copy of) X. For example, + # BernoulliNB._check_X binarises the input. + X = self._check_array_if_not_pandas(X) self._check_feature_names(X, reset=False) + self._check_n_features(X, reset=False) return X + def _check_array_if_not_pandas(self, array): + """Convert to ndarray, unless a pandas DataFrame""" + if hasattr(array, "dtypes") and hasattr(array.dtypes, "__array__"): + return array + else: + return check_array(array) + def _joint_log_likelihood(self, X): """Calculate the meta-estimator's joint log-probability `log P(x,y)`.""" estimators = self._iter(fitted=True, replace_strings=True) @@ -1924,6 +1942,7 @@ def _partial_fit(self, X, y, partial=False, classes=None, sample_weight=None): partial : bool, default=False True for partial_fit, False for fit. 
""" + X = self._check_array_if_not_pandas(X) first_call = not hasattr(self, "classes_") if first_call: # in fit() or the first call of partial_fit() self._validate_params() diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 059bb7cc71f7b..0a71e16879a73 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1065,12 +1065,18 @@ def test_cwnb_union_permutation(): clf1.fit(X2[:, [0, 1, 2, 3, 4]], y2) # (0, 1, 2, 3, 4) -> (1, 2, 0, 3, 4) clf2.fit(X2[:, [2, 0, 1, 3, 4]], y2) # (0, 1, 2, 3, 4) <- (2, 0, 1, 3, 4) assert_array_almost_equal( - clf1.predict_proba(X2), clf2.predict_proba(X2[:, [2, 0, 1, 3, 4]]), 8 + clf1.predict_proba(X2[:, [0, 1, 2, 3, 4]]), + clf2.predict_proba(X2[:, [2, 0, 1, 3, 4]]), + 8, + ) + assert_array_almost_equal( + clf1.predict_log_proba(X2[:, [0, 1, 2, 3, 4]]), + clf2.predict_log_proba(X2[:, [2, 0, 1, 3, 4]]), + 8, ) assert_array_almost_equal( - clf1.predict_log_proba(X2), clf2.predict_log_proba(X2[:, [2, 0, 1, 3, 4]]), 8 + clf1.predict(X2[:, [0, 1, 2, 3, 4]]), clf2.predict(X2[:, [2, 0, 1, 3, 4]]), 8 ) - assert_array_almost_equal(clf1.predict(X2), clf2.predict(X2[:, [2, 0, 1, 3, 4]]), 8) def test_cwnb_estimators_pandas(): diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 872f1b77eca9c..0c570e8b872f1 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -34,6 +34,7 @@ from ..linear_model import RANSACRegressor from ..linear_model import Ridge from ..linear_model import SGDRegressor +from ..naive_bayes import GaussianNB, _select_half_first, _select_half_second from ..base import ( clone, @@ -420,6 +421,14 @@ def _construct_instance(Estimator): ("est2", LogisticRegression(C=1)), ] ) + elif required_parameters in (["nb_estimators"],): + # ColumnwiseNB (naive Bayes meta-classifier) + estimator = Estimator( + nb_estimators=[ + ("gnb1", GaussianNB(var_smoothing=1e-13), _select_half_first), + ("gnb2", 
GaussianNB(), _select_half_second), + ] + ) else: msg = ( f"Can't instantiate estimator {Estimator.__name__} " From 5201045fbe8423ae2d0e3d0465c023da3ed7fa85 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Fri, 6 Jan 2023 19:49:58 -0500 Subject: [PATCH 078/102] Pass tests by removing memory address in column selector __repr__ Tests in CI pipeline return error: Different tests were collected between gw0 and gw1 For details and similar situation see:https://github.com/scikit-learn/scikit-learn/pull/18811#issuecomment-727226988 --- sklearn/naive_bayes.py | 24 ++++++++++++++++++------ sklearn/utils/estimator_checks.py | 6 +++--- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 886e9c8fd1699..212b25b94f5e0 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1544,14 +1544,26 @@ def _joint_log_likelihood(self, X): return total_ll -def _select_half_first(X): - """Column selector that selects the first half of columns""" - return list(range((X.shape[1] + 1) // 2)) +class _select_half: + """Column selector that selects the first half of columns + Used for testing purposes only. + """ + + def __init__(self, half="first"): + self.half = half -def _select_half_second(X): - """Column selector that selects the second half of columns""" - return list(range((X.shape[1] + 1) // 2, X.shape[1])) + def __repr__(self): + # Only required when using pytest-xdist to get an id not associated + # with the memory location. 
See: + # https://github.com/scikit-learn/scikit-learn/pull/18811#issuecomment-727226988 + return f'_select_half("{str(self.half)}")' + + def __call__(self, X): + if self.half == "first": + return list(range((X.shape[1] + 1) // 2)) + else: + return list(range((X.shape[1] + 1) // 2, X.shape[1])) def _nb_estimators_have(attr): diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 0c570e8b872f1..629d1f8b4f923 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -34,7 +34,7 @@ from ..linear_model import RANSACRegressor from ..linear_model import Ridge from ..linear_model import SGDRegressor -from ..naive_bayes import GaussianNB, _select_half_first, _select_half_second +from ..naive_bayes import GaussianNB, _select_half from ..base import ( clone, @@ -425,8 +425,8 @@ def _construct_instance(Estimator): # ColumnwiseNB (naive Bayes meta-classifier) estimator = Estimator( nb_estimators=[ - ("gnb1", GaussianNB(var_smoothing=1e-13), _select_half_first), - ("gnb2", GaussianNB(), _select_half_second), + ("gnb1", GaussianNB(var_smoothing=1e-13), _select_half("first")), + ("gnb2", GaussianNB(), _select_half("second")), ] ) else: From 043dbd60b11a47b41e25ff9136ca8103f71ca95f Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Sun, 8 Jan 2023 12:19:34 -0500 Subject: [PATCH 079/102] Empty commit to trigger build pipeline From 6c5c7d816c7405138f7c1bb0759b83ee700a16d0 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Fri, 27 Jan 2023 17:08:57 -0500 Subject: [PATCH 080/102] Use utils.parallel.delayed not utils.fixes.delayed See PR #25242 for details --- sklearn/naive_bayes.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 212b25b94f5e0..08545826b8bbc 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -21,7 +21,6 @@ import numpy as np from scipy.special import 
logsumexp -from joblib import Parallel from .base import BaseEstimator, ClassifierMixin from .base import clone @@ -35,7 +34,7 @@ from .utils.validation import column_or_1d, check_array from .utils.metaestimators import _BaseComposition, available_if from .utils import _safe_indexing, _get_column_indices, _print_elapsed_time, Bunch -from .utils.fixes import delayed +from .utils.parallel import delayed, Parallel from .utils._encode import _unique from .utils._estimator_html_repr import _VisualBlock from .compose._column_transformer import _is_empty_column_selection From a8a07c4fdd37a5ca78fb29bd52f88b183a7b511e Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 27 Mar 2023 00:22:43 -0400 Subject: [PATCH 081/102] TST use global_random_seed towards #22827 --- sklearn/tests/test_naive_bayes.py | 71 ++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index b8b6c3ce2295f..7b9e8d946dea3 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1030,9 +1030,10 @@ def test_cwnb_union_prior_gnb(): assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) -def test_cwnb_union_bnb_fit(): +def test_cwnb_union_bnb_fit(global_random_seed): # A union of BernoulliNB's yields the same prediction as a single BernoulliNB # (fit) + X1, y1 = get_random_normal_x_binary_y(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] ) @@ -1041,12 +1042,17 @@ def test_cwnb_union_bnb_fit(): clf2.fit(X1, y1) assert_array_almost_equal(clf1.predict_proba(X1), clf2.predict_proba(X1), 8) assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) - assert_array_almost_equal(clf1.predict(X1), clf2.predict(X1), 8) + # BernoulliNB is likely to yield 1/1 odds that eventually round to different + # y_pred as a result of 
unavoidable log/exp transformations done by ColumnwiseNB. + # We don't want to test for these discretisation discrepancies. + ii = abs(clf1.predict_proba(X1)[:, 0] - clf1.predict_proba(X1)[:, 1]) > 1e-8 + assert_array_almost_equal(clf1.predict(X1)[ii], clf2.predict(X1)[ii], 8) -def test_cwnb_union_bnb_partial_fit(): +def test_cwnb_union_bnb_partial_fit(global_random_seed): # A union of BernoulliNB's yields the same prediction as a single BernoulliNB # (partial_fit) + X1, y1 = get_random_normal_x_binary_y(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] ) @@ -1056,11 +1062,16 @@ def test_cwnb_union_bnb_partial_fit(): clf2.fit(X1, y1) assert_array_almost_equal(clf1.predict_proba(X1), clf2.predict_proba(X1), 8) assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) - assert_array_almost_equal(clf1.predict(X1), clf2.predict(X1), 8) + # BernoulliNB is likely to yield 1/1 odds that eventually round to different + # y_pred as a result of unavoidable log/exp transformations done by ColumnwiseNB. + # We don't want to test for these discretisation discrepancies. 
+ ii = abs(clf1.predict_proba(X1)[:, 0] - clf1.predict_proba(X1)[:, 1]) > 1e-8 + assert_array_almost_equal(clf1.predict(X1)[ii], clf2.predict(X1)[ii], 8) -def test_cwnb_union_permutation(): +def test_cwnb_union_permutation(global_random_seed): # A union of several different NB's is permutation-invariant + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[ ("b1", BernoulliNB(binarize=2), [3]), @@ -1168,8 +1179,9 @@ def test_cwnb_estimators_pandas(): ) -def test_cwnb_repeated_columns(): +def test_cwnb_repeated_columns(global_random_seed): # Subestimators spec: repeated col ints have the same effect as repeating data + X1, y1 = get_random_normal_x_binary_y(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[ ("g1", GaussianNB(), [1, 1]), @@ -1189,8 +1201,9 @@ def test_cwnb_repeated_columns(): ) -def test_cwnb_empty_columns(): +def test_cwnb_empty_columns(global_random_seed): # Subestimators spec: empty cols have the same effect as an absent estimator + X1, y1 = get_random_normal_x_binary_y(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[ ("g1", GaussianNB(), [1]), @@ -1226,8 +1239,9 @@ def test_cwnb_estimators_unique_names(): clf1.fit(X, y) -def test_cwnb_estimators_nonempty_list(): +def test_cwnb_estimators_nonempty_list(global_random_seed): # Subestimators spec: error on empty list + X1, y1 = get_random_normal_x_binary_y(global_random_seed) clf = ColumnwiseNB( nb_estimators=[], ) @@ -1324,8 +1338,9 @@ def predict(self, X): clf1.partial_fit(X, y) -def test_cwnb_estimators_setter(): +def test_cwnb_estimators_setter(global_random_seed): # _estimators setter works + X1, y1 = get_random_normal_x_binary_y(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[("g1", GaussianNB(), [0]), ("b1", BernoulliNB(), [1])] ) @@ -1367,12 +1382,13 @@ def test_cwnb_prior_valid_spec(): clf1.fit(X, y) -def test_cwnb_prior_match(): +def test_cwnb_prior_match(global_random_seed): # prior spec: all these ways work (and agree 
in our example) # (1) an array of values # (2a) a str name of a subestimator supporting class_prior_ # (2b) a str name of a subestimator supporting class_log_prior_ # (3) nothing (ColumnwiseNB will calculate relative frequencies) + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[ ("g1", GaussianNB(), [0, 1]), @@ -1438,10 +1454,12 @@ def fit(self, X, y, sample_weight=None): clf.fit(X, y) -def test_cwnb_estimators_support_class_prior_mnb(): +def test_cwnb_estimators_support_class_prior_mnb(global_random_seed): # prior spec: error message when can't extract prior from subestimator # ColumnwiseNB tries both class_prior_ and class_log_prior, which is tested # in test_cwnb_prior_match() + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + class MultinomialNB_hide_log_prior(MultinomialNB): def fit(self, X, y, sample_weight=None): super().fit(X, y, sample_weight=None) @@ -1460,13 +1478,15 @@ def fit(self, X, y, sample_weight=None): clf.fit(X2, y2) -def test_cwnb_prior_nonzero(): +def test_cwnb_prior_nonzero(global_random_seed): # P(y)=0 in one or two subestimators results in P(y|x)=0 of meta-estimator. # Despite attempted Log[0], predicted class probabilities are all finite. # On a related note, meaningless results (including NaNs) may be produced # - if P(y)=0 in the meta-estimator, or/and # - if class priors differ across subestimators, # but this is not what is tested here. 
+ X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + rng = np.random.RandomState(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[ ("g1", GaussianNB(), [1, 3, 5]), @@ -1521,8 +1541,9 @@ def test_cwnb_fit_sample_weight_ones(): assert_array_equal(clf1.predict(X), clf2.predict(X)) -def test_cwnb_partial_fit_sample_weight_ones(): +def test_cwnb_partial_fit_sample_weight_ones(global_random_seed): # weights in partial_fit have no effect if all ones + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) weights = [1, 1, 1, 1, 1, 1] clf1 = ColumnwiseNB( nb_estimators=[ @@ -1573,8 +1594,9 @@ def test_cwnb_fit_sample_weight_repeated(): assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) -def test_cwnb_partial_fit_sample_weight_repeated(): +def test_cwnb_partial_fit_sample_weight_repeated(global_random_seed): # weights in partial_fit have the same effect as repeating data + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) weights = [1, 2, 3, 1, 4, 2] idx = list(chain(*([i] * w for i, w in enumerate(weights)))) clf1 = ColumnwiseNB( @@ -1600,8 +1622,9 @@ def test_cwnb_partial_fit_sample_weight_repeated(): assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) -def test_cwnb_partial_fit(): +def test_cwnb_partial_fit(global_random_seed): # partial_fit: consecutive calls yield the same prediction as a single call + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] ) @@ -1620,8 +1643,9 @@ def test_cwnb_partial_fit(): assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) -def test_cwnb_fit_refits(): +def test_cwnb_fit_refits(global_random_seed): # fit: re-fits the estimator de novo when called on a fitted estimator + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[("b1", BernoulliNB(), [1]), ("m1", 
MultinomialNB(), [0, 2, 3])] ) @@ -1640,8 +1664,9 @@ def test_cwnb_fit_refits(): assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) -def test_cwnb_partial_fit_classes(): +def test_cwnb_partial_fit_classes(global_random_seed): # partial_fit: error when classes are not provided at the first call + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] ) @@ -1650,8 +1675,9 @@ def test_cwnb_partial_fit_classes(): clf1.partial_fit(X2, y2) -def test_cwnb_class_attributes_consistency(): +def test_cwnb_class_attributes_consistency(global_random_seed): # class_count_, classes_, class_prior_ are consistent in meta-, sub-estimators + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[ ("b1", BernoulliNB(binarize=2), [1]), @@ -1669,9 +1695,10 @@ def test_cwnb_class_attributes_consistency(): ) -def test_cwnb_params(): +def test_cwnb_params(global_random_seed): # Can get and set subestimators' parameters through name__paramname # clone() works on ColumnwiseNB + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[ ("b1", BernoulliNB(alpha=0.2, binarize=2), [1]), @@ -1694,8 +1721,9 @@ def test_cwnb_params(): assert id(clf2.named_estimators_["b1"]) != id(clf1.named_estimators_["b1"]) -def test_cwnb_n_jobs(): +def test_cwnb_n_jobs(global_random_seed): # n_jobs: same result whether with it or without + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( nb_estimators=[ ("b1", BernoulliNB(binarize=2), [1]), @@ -1740,9 +1768,10 @@ def test_cwnb_example(): clf.predict(X) -def test_cwnb_verbose(capsys): +def test_cwnb_verbose(capsys, global_random_seed): # Setting verbose=True does not result in an error. # This DOES NOT test if the desired output is generated. 
+ X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf = ColumnwiseNB( nb_estimators=[ ("mnb1", MultinomialNB(), [0, 1]), From 70f38fbe1bc4559d08fd2e5de4a63be4408fa5fb Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 27 Mar 2023 08:59:54 -0400 Subject: [PATCH 082/102] Trigger build From f8277ba615f90c874151d20e610f58d7c1f0fedc Mon Sep 17 00:00:00 2001 From: avm19 Date: Thu, 15 Jun 2023 00:04:28 -0400 Subject: [PATCH 083/102] Minor typo in a comment Co-authored-by: Albert Steppi --- sklearn/naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index bf0a2fdfffee6..4db994d3b1fd4 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1906,7 +1906,7 @@ def _iter(self, *, fitted=False, replace_strings=False): def _update_class_prior(self): """Update class prior after most of the fitting as done.""" - if self.priors is None: # calculcate empirical prior from counts + if self.priors is None: # calculate empirical prior from counts priors = self.class_count_ / self.class_count_.sum() elif isinstance(self.priors, str): # extract prior from estimator name = self.priors From 4aa9b83a2d77551c21edc280daf9c355d1e80e8e Mon Sep 17 00:00:00 2001 From: avm19 Date: Thu, 15 Jun 2023 00:20:46 -0400 Subject: [PATCH 084/102] Add np.where to cover the possibility of zero prior As kindly suggested by @steppi Co-authored-by: Albert Steppi --- sklearn/naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 4db994d3b1fd4..fadf6cc304a14 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1805,7 +1805,7 @@ def _joint_log_likelihood(self, X): ) n_estimators = len(all_jlls) log_prior = np.log(self.class_prior_) - return np.sum(all_jlls, axis=0) - (n_estimators - 1) * log_prior + return np.where(np.isinf(log_prior), -np.inf, np.sum(all_jlls, axis=0) - 
(n_estimators - 1) * log_prior) def _validate_estimators(self, check_partial=False): try: From d14dae93db03edfeeb5e8c58e2af023feaa17dc8 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Thu, 15 Jun 2023 00:55:15 -0400 Subject: [PATCH 085/102] black formatting --- sklearn/naive_bayes.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index fadf6cc304a14..fae19e6c97175 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -638,11 +638,9 @@ def _check_alpha(self): if _force_alpha == "warn" and alpha_min < alpha_lower_bound: _force_alpha = False warnings.warn( - ( - "The default value for `force_alpha` will change to `True` in 1.4." - " To suppress this warning, manually set the value of" - " `force_alpha`." - ), + "The default value for `force_alpha` will change to `True` in 1.4." + " To suppress this warning, manually set the value of" + " `force_alpha`.", FutureWarning, ) if alpha_min < alpha_lower_bound and not _force_alpha: @@ -1805,7 +1803,11 @@ def _joint_log_likelihood(self, X): ) n_estimators = len(all_jlls) log_prior = np.log(self.class_prior_) - return np.where(np.isinf(log_prior), -np.inf, np.sum(all_jlls, axis=0) - (n_estimators - 1) * log_prior) + return np.where( + np.isinf(log_prior), + -np.inf, + np.sum(all_jlls, axis=0) - (n_estimators - 1) * log_prior, + ) def _validate_estimators(self, check_partial=False): try: From 8ea37ffaffb768fd9040fe1c0e6f8d6b8a4b4eb9 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Thu, 15 Jun 2023 01:22:21 -0400 Subject: [PATCH 086/102] Reformatting to comply with black=23.3.0 --- sklearn/naive_bayes.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 68217d0517af1..8f7a788b1910e 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -638,9 +638,11 @@ def 
_check_alpha(self): if _force_alpha == "warn" and alpha_min < alpha_lower_bound: _force_alpha = False warnings.warn( - "The default value for `force_alpha` will change to `True` in 1.4." - " To suppress this warning, manually set the value of" - " `force_alpha`.", + ( + "The default value for `force_alpha` will change to `True` in 1.4." + " To suppress this warning, manually set the value of" + " `force_alpha`." + ), FutureWarning, ) if alpha_min < alpha_lower_bound and not _force_alpha: From dbbeaf5db2d5ccb6ef705c862af110d60980d479 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Thu, 15 Jun 2023 17:38:06 -0400 Subject: [PATCH 087/102] Apply _fit_context decorator. Cf. #26473 --- sklearn/naive_bayes.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 8f7a788b1910e..d6855ad4b6cf0 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -2011,6 +2011,10 @@ def _partial_fit(self, X, y, partial=False, classes=None, sample_weight=None): self._update_class_prior() return self + @_fit_context( + # estimators in ColumnwiseNB.nb_estimators are not validated yet + prefer_skip_nested_validation=False + ) def fit(self, X, y, sample_weight=None): """Fit the naive Bayes meta-estimator. 
From ff02ae8ce1dc619069aa3dc0354c677747ce862b Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Thu, 15 Jun 2023 17:39:44 -0400 Subject: [PATCH 088/102] black formatting --- sklearn/naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index d6855ad4b6cf0..e0182d63df520 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -2014,7 +2014,7 @@ def _partial_fit(self, X, y, partial=False, classes=None, sample_weight=None): @_fit_context( # estimators in ColumnwiseNB.nb_estimators are not validated yet prefer_skip_nested_validation=False - ) + ) def fit(self, X, y, sample_weight=None): """Fit the naive Bayes meta-estimator. From cb5a43d0b10516ac78f8bb37abd3b03b463d289e Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 26 Jun 2023 11:45:58 -0400 Subject: [PATCH 089/102] Formatting ruff --- examples/miscellaneous/plot_combining_naive_bayes.py | 9 +++++---- sklearn/utils/estimator_checks.py | 3 +-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py index 2f38b0d2f1e13..f617ba8621e62 100644 --- a/examples/miscellaneous/plot_combining_naive_bayes.py +++ b/examples/miscellaneous/plot_combining_naive_bayes.py @@ -25,6 +25,7 @@ # %% import pandas as pd + from sklearn import set_config from sklearn.datasets import fetch_openml @@ -42,12 +43,12 @@ # ------------------------------------------------ from sklearn.compose import ColumnTransformer -from sklearn.pipeline import Pipeline from sklearn.impute import SimpleImputer -from sklearn.preprocessing import OrdinalEncoder -from sklearn.naive_bayes import GaussianNB, CategoricalNB, ColumnwiseNB -from sklearn.model_selection import train_test_split, GridSearchCV from sklearn.metrics import accuracy_score +from sklearn.model_selection import GridSearchCV, 
train_test_split +from sklearn.naive_bayes import CategoricalNB, ColumnwiseNB, GaussianNB +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import OrdinalEncoder numeric_features = ["age", "fare"] numeric_transformer = SimpleImputer(strategy="median") diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index c7c34a0acf216..b7d19c20bba8d 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -46,9 +46,8 @@ from ..pipeline import make_pipeline from ..preprocessing import StandardScaler, scale from ..random_projection import BaseRandomProjection -from ..utils._array_api import _convert_to_numpy +from ..utils._array_api import _convert_to_numpy, get_namespace from ..utils._array_api import device as array_device -from ..utils._array_api import get_namespace from ..utils._param_validation import ( InvalidParameterError, generate_invalid_param_val, From e2d48d8d7788531441153f42e09ff01f6e5fdd4b Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 11 Sep 2023 19:53:15 -0400 Subject: [PATCH 090/102] Move the changelog entry from 1.3 to 1.4 --- doc/whats_new/v1.3.rst | 4 ---- doc/whats_new/v1.4.rst | 7 +++++++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 8e31f698e1050..12c471bd915c5 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -682,10 +682,6 @@ Changelog :mod:`sklearn.naive_bayes` .......................... -- |Feature| Add :class:`naive_bayes.ColumnwiseNB`, a meta-estimator that allows - existing naive Bayes classifiers to be combined and applied to different columns - of `X`. :pr:`22574` by :user:`Andrey Melnik <avm19>`. - - |Fix| :class:`naive_bayes.GaussianNB` does not raise anymore a `ZeroDivisionError` when the provided `sample_weight` reduces the problem to a single class in `fit`.
diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 6531102bba9fd..21fd499023d83 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -223,6 +223,13 @@ Changelog - |Enhancement| :func:`sklearn.model_selection.train_test_split` now supports Array API compatible inputs. :pr:`26855` by `Tim Head`_. +:mod:`sklearn.naive_bayes` +.......................... + +- |Feature| Add :class:`naive_bayes.ColumnwiseNB`, a meta-estimator that allows + existing naive Bayes classifiers to be combined and applied to different columns + of `X`. :pr:`22574` by :user:`Andrey Melnik `. + :mod:`sklearn.neighbors` ........................ From 45e43819fffd3c3c6d09c7c3ccc5def35182688f Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 11 Sep 2023 20:19:10 -0400 Subject: [PATCH 091/102] Move ColumnwiseNB section before Out of Core section --- doc/modules/naive_bayes.rst | 46 ++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/doc/modules/naive_bayes.rst b/doc/modules/naive_bayes.rst index 1dbb073b0e28d..b30989a1e194b 100644 --- a/doc/modules/naive_bayes.rst +++ b/doc/modules/naive_bayes.rst @@ -259,29 +259,6 @@ that all categories for each feature :math:`i` are represented with numbers :math:`0, ..., n_i - 1` where :math:`n_i` is the number of available categories of feature :math:`i`. -Out-of-core naive Bayes model fitting -------------------------------------- - -Naive Bayes models can be used to tackle large scale classification problems -for which the full training set might not fit in memory. To handle this case, -:class:`MultinomialNB`, :class:`BernoulliNB`, and :class:`GaussianNB` -expose a ``partial_fit`` method that can be used -incrementally as done with other classifiers as demonstrated in -:ref:`sphx_glr_auto_examples_applications_plot_out_of_core_classification.py`. All naive Bayes -classifiers support sample weighting. 
- -Contrary to the ``fit`` method, the first call to ``partial_fit`` needs to be -passed the list of all the expected class labels. - -For an overview of available strategies in scikit-learn, see also the -:ref:`out-of-core learning ` documentation. - -.. note:: - - The ``partial_fit`` method call of naive Bayes models introduces some - computational overhead. It is recommended to use data chunk sizes that are as - large as possible, that is as the available RAM allows. - .. _columnwise_naive_bayes: Mix and match naive Bayes models @@ -316,3 +293,26 @@ for an example of a mixed naive Bayes model implementation. See also :ref:`voting_classifier` for a way of combining general classifiers. An introduction to processing datasets with heterogeneous features is available at :ref:`column_transformer`. + +Out-of-core naive Bayes model fitting +------------------------------------- + +Naive Bayes models can be used to tackle large scale classification problems +for which the full training set might not fit in memory. To handle this case, +:class:`MultinomialNB`, :class:`BernoulliNB`, and :class:`GaussianNB` +expose a ``partial_fit`` method that can be used +incrementally as done with other classifiers as demonstrated in +:ref:`sphx_glr_auto_examples_applications_plot_out_of_core_classification.py`. All naive Bayes +classifiers support sample weighting. + +Contrary to the ``fit`` method, the first call to ``partial_fit`` needs to be +passed the list of all the expected class labels. + +For an overview of available strategies in scikit-learn, see also the +:ref:`out-of-core learning ` documentation. + +.. note:: + + The ``partial_fit`` method call of naive Bayes models introduces some + computational overhead. It is recommended to use data chunk sizes that are as + large as possible, that is as the available RAM allows. 
From 95d00cd89ad69a8d6f933d6bc17eca8404fa9d3e Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 11 Sep 2023 20:47:25 -0400 Subject: [PATCH 092/102] Change versionadded from 1.3 to 1.4 --- sklearn/naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 5a48bd764ab3a..5c84fa4cc7ef1 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1624,7 +1624,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Read more in the :ref:`User Guide `. - .. versionadded:: 1.3 + .. versionadded:: 1.4 Parameters ---------- From a1ed149c0a6aa8e06f6c2e83ed6fe575d809ab1b Mon Sep 17 00:00:00 2001 From: avm19 Date: Mon, 11 Sep 2023 20:47:55 -0400 Subject: [PATCH 093/102] Update sklearn/naive_bayes.py Co-authored-by: Guillaume Lemaitre --- sklearn/naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 5a48bd764ab3a..07ffa6b904907 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1647,7 +1647,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Indexes the data on its second axis. Integers are interpreted as positional columns, while strings can reference DataFrame columns by name. A scalar string or int should be used where - `nb_estimator` expects X to be a 1d array-like (vector), + `naive_bayes_estimator` expects X to be a 1d array-like (vector), otherwise a 2d array will be passed to the transformer. A callable is passed the input data `X` and can return any of the above. 
To select multiple columns by name or dtype, you can use From 6c09328678c2ccc784c079092d97199ab40206ba Mon Sep 17 00:00:00 2001 From: avm19 Date: Mon, 11 Sep 2023 20:49:25 -0400 Subject: [PATCH 094/102] Update sklearn/naive_bayes.py Co-authored-by: Guillaume Lemaitre --- sklearn/naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 07ffa6b904907..2f2371eeac62f 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1629,7 +1629,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Parameters ---------- nb_estimators : list of tuples - List of (name, nb_estimator, columns) tuples specifying the naive Bayes + List of `(name, naive_bayes_estimator, columns)` tuples specifying the naive Bayes estimators to be combined into a single naive Bayes meta-estimator. name : str From b0ccfb07c5ef2102124258ebb514582abf90d69b Mon Sep 17 00:00:00 2001 From: avm19 Date: Mon, 11 Sep 2023 20:54:34 -0400 Subject: [PATCH 095/102] Update sklearn/naive_bayes.py Co-authored-by: Guillaume Lemaitre --- sklearn/naive_bayes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 2f2371eeac62f..1ede11dc5449f 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1756,7 +1756,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): _required_parameters = ["nb_estimators"] _parameter_constraints = { - "nb_estimators": "no_validation", + "nb_estimators": [list], "priors": ["array-like", str, None], "n_jobs": [Integral, None], "verbose": ["verbose"], From 53fb1a912727480c64788143da1945048a09668c Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 11 Sep 2023 21:50:00 -0400 Subject: [PATCH 096/102] Formatting --- sklearn/naive_bayes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index d9c74e1495354..c4188df17a0c5 100644 --- 
a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1629,8 +1629,8 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Parameters ---------- nb_estimators : list of tuples - List of `(name, naive_bayes_estimator, columns)` tuples specifying the naive Bayes - estimators to be combined into a single naive Bayes meta-estimator. + List of `(name, naive_bayes_estimator, columns)` tuples specifying the naive + Bayes estimators to be combined into a single naive Bayes meta-estimator. name : str Name of the naive Bayes estimator, by which the subestimator and From ddf5f554d628bd9b09b0b7463aaba7e6b25640d6 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 11 Sep 2023 22:21:49 -0400 Subject: [PATCH 097/102] Fix test re _parameter_constraints = {'nb_estimators': [list], ...} --- sklearn/tests/test_naive_bayes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index d024a6bec862f..56629a9617bbb 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1252,7 +1252,7 @@ def test_cwnb_estimators_nonempty_list(global_random_seed): clf = ColumnwiseNB( nb_estimators=None, ) - msg = "A list of naive Bayes estimators must be provided*" + msg = "The 'nb_estimators' parameter of ColumnwiseNB must be an instance of 'list'*" with pytest.raises(ValueError, match=msg): clf.fit(X1, y1) @@ -1260,7 +1260,7 @@ def test_cwnb_estimators_nonempty_list(global_random_seed): clf = ColumnwiseNB( nb_estimators=GaussianNB(), ) - msg = "A list of naive Bayes estimators must be provided*" + msg = "The 'nb_estimators' parameter of ColumnwiseNB must be an instance of 'list'*" with pytest.raises(ValueError, match=msg): clf.fit(X1, y1) From 78c70eb07bd73e35e929f49ece3a9685a15f248c Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Mon, 11 Sep 2023 22:25:53 -0400 Subject: [PATCH 098/102] Change nb_estimator to 
naive_bayes_estimator in all docstring --- sklearn/naive_bayes.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index c4188df17a0c5..ccc4c72b91552 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1636,7 +1636,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Name of the naive Bayes estimator, by which the subestimator and its parameters can be set using :term:`set_params` and searched in grid search. - nb_estimator : estimator + naive_bayes_estimator : estimator The estimator must support :term:`fit` or :term:`partial_fit`, depending on how the meta-estimator is fitted. In addition, the estimator must support `predict_joint_log_proba` method, which @@ -1678,8 +1678,8 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): List of `(name, fitted_estimator, columns)` tuples, which follow the order of `estimators`. Here, `fitted_estimator` is a fitted naive Bayes estimator, except when `columns` presents an empty selection of - columns, in which case it is the original unfitted `nb_estimator`. If - the original specification of `columns` in `estimators` was a + columns, in which case it is the original unfitted `naive_bayes_estimator`. + If the original specification of `columns` in `estimators` was a callable, then `columns` is converted to a list of column indices. named_estimators_ : :class:`~sklearn.utils.Bunch` @@ -1813,7 +1813,7 @@ def _validate_estimators(self, check_partial=False): except (TypeError, AttributeError, ValueError) as exc: raise ValueError( "A list of naive Bayes estimators must be provided " - "in the form [(name, nb_estimator, columns), ... ]." + "in the form [(name, naive_bayes_estimator, columns), ... ]." 
) from exc for e in estimators: if (not check_partial) and ( @@ -1847,7 +1847,7 @@ def _validate_column_callables(self, X): self._estimator_to_input_indices = estimator_to_input_indices def _iter(self, *, fitted=False, replace_strings=False): - """Generate `(name, nb_estimator, columns)` tuples. + """Generate `(name, naive_bayes_estimator, columns)` tuples. This is a private method, similar to ColumnTransformer._iter. Must not be called before _validate_column_callables. @@ -1869,7 +1869,7 @@ def _iter(self, *, fitted=False, replace_strings=False): Yields ------ tuple - of the form `(name, nb_estimator, columns)`. + of the form `(name, naive_bayes_estimator, columns)`. Notes ----- @@ -2016,9 +2016,9 @@ def _partial_fit(self, X, y, partial=False, classes=None, sample_weight=None): def fit(self, X, y, sample_weight=None): """Fit the naive Bayes meta-estimator. - Calls `fit` of each subestimator `nb_estimator`. Only a corresponding - subset of columns of `X` is passed to each subestimator; `sample_weight` - and `y` are passed to the subestimators as they are. + Calls `fit` of each subestimator `naive_bayes_estimator`. + Only a corresponding subset of columns of `X` is passed to each subestimator; + `sample_weight` and `y` are passed to the subestimators as they are. 
Parameters ---------- From 95c9c69ae2ee52bd618bcc542a60c707e23f21d1 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 12 Sep 2023 07:03:48 -0400 Subject: [PATCH 099/102] Rename 'nb_estimators' into 'estimators' --- .../plot_combining_naive_bayes.py | 2 +- sklearn/naive_bayes.py | 43 +++--- sklearn/tests/test_metaestimators.py | 1 + sklearn/tests/test_naive_bayes.py | 144 +++++++++--------- sklearn/utils/estimator_checks.py | 24 +-- 5 files changed, 109 insertions(+), 105 deletions(-) diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py index f617ba8621e62..a864cde64c7f3 100644 --- a/examples/miscellaneous/plot_combining_naive_bayes.py +++ b/examples/miscellaneous/plot_combining_naive_bayes.py @@ -66,7 +66,7 @@ ) classifier = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("gnb", GaussianNB(), numeric_features), ("cnb", CategoricalNB(), categorical_features), ] diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index ccc4c72b91552..e4f6e3121a124 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1565,19 +1565,19 @@ def __call__(self, X): def _nb_estimators_have(attr): - """Check if all self.nb_estimators or self.nb_estimators_ have attr. + """Check if all self.estimators or self.nb_estimators_ have attr. Used together with `available_if` in `ColumnwiseNB`.""" # This function is used with `_available_if` before validation. # The try statement suppresses errors caused by incorrect specification of - # self.nb_estimators. Informative errors are raised at validation elsewhere. + # self.estimators. Informative errors are raised at validation elsewhere. 
def chk(obj): try: if hasattr(obj, "nb_estimators_"): out = all(hasattr(triplet[1], attr) for triplet in obj.nb_estimators_) else: - out = all(hasattr(triplet[1], attr) for triplet in obj.nb_estimators) + out = all(hasattr(triplet[1], attr) for triplet in obj.estimators) except (TypeError, IndexError, AttributeError): return False return out @@ -1628,7 +1628,7 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): Parameters ---------- - nb_estimators : list of tuples + estimators : list of tuples List of `(name, naive_bayes_estimator, columns)` tuples specifying the naive Bayes estimators to be combined into a single naive Bayes meta-estimator. @@ -1742,21 +1742,21 @@ class ColumnwiseNB(_BaseNB, _BaseComposition): >>> X = rng.randint(5, size=(6, 100)) >>> y = np.array([0, 0, 1, 1, 2, 2]) >>> from sklearn.naive_bayes import MultinomialNB, GaussianNB, ColumnwiseNB - >>> clf = ColumnwiseNB(nb_estimators=[('mnb1', MultinomialNB(), [0, 1]), + >>> clf = ColumnwiseNB(estimators=[('mnb1', MultinomialNB(), [0, 1]), ... ('mnb2', MultinomialNB(), [3, 4]), ... 
('gnb1', GaussianNB(), [5])]) >>> clf.fit(X, y) - ColumnwiseNB(nb_estimators=[('mnb1', MultinomialNB(), [0, 1]), + ColumnwiseNB(estimators=[('mnb1', MultinomialNB(), [0, 1]), ('mnb2', MultinomialNB(), [3, 4]), ('gnb1', GaussianNB(), [5])]) >>> print(clf.predict(X)) [0 0 1 0 2 2] """ - _required_parameters = ["nb_estimators"] + _required_parameters = ["estimators"] _parameter_constraints = { - "nb_estimators": [list], + "estimators": [list], "priors": ["array-like", str, None], "n_jobs": [Integral, None], "verbose": ["verbose"], @@ -1767,8 +1767,8 @@ def _log_message(self, name, idx, total): return None return f"({idx} of {total}) Processing {name}" - def __init__(self, nb_estimators, *, priors=None, n_jobs=None, verbose=False): - self.nb_estimators = nb_estimators + def __init__(self, estimators, *, priors=None, n_jobs=None, verbose=False): + self.estimators = estimators self.priors = priors self.n_jobs = n_jobs self.verbose = verbose @@ -1809,7 +1809,7 @@ def _joint_log_likelihood(self, X): def _validate_estimators(self, check_partial=False): try: - names, estimators, _ = zip(*self.nb_estimators) + names, estimators, _ = zip(*self.estimators) except (TypeError, AttributeError, ValueError) as exc: raise ValueError( "A list of naive Bayes estimators must be provided " @@ -1838,7 +1838,7 @@ def _validate_column_callables(self, X): """ all_columns = [] estimator_to_input_indices = {} - for name, _, columns in self.nb_estimators: + for name, _, columns in self.estimators: if callable(columns): columns = columns(X) all_columns.append(columns) @@ -1898,7 +1898,7 @@ def _iter(self, *, fitted=False, replace_strings=False): else: yield (name, estimator, cols) else: # fitted=False - for (name, estimator, _), cols in zip(self.nb_estimators, self._columns): + for (name, estimator, _), cols in zip(self.estimators, self._columns): if replace_strings and _is_empty_column_selection(cols): continue else: @@ -2010,7 +2010,7 @@ def _partial_fit(self, X, y, partial=False, 
classes=None, sample_weight=None): return self @_fit_context( - # estimators in ColumnwiseNB.nb_estimators are not validated yet + # estimators in ColumnwiseNB.estimators are not validated yet prefer_skip_nested_validation=False ) def fit(self, X, y, sample_weight=None): @@ -2084,25 +2084,24 @@ def _estimators(self): which expects lists of tuples of len 2. """ try: - return [(name, e) for name, e, _ in self.nb_estimators] + return [(name, e) for name, e, _ in self.estimators] except (TypeError, ValueError): # This try-except clause is needed to pass the test from test_common.py: # test_estimators_do_not_raise_errors_in_init_or_set_params(). # ColumnTransformer does the same. See PR #21355 for details. - return self.nb_estimators + return self.estimators @_estimators.setter def _estimators(self, value): - self.nb_estimators = [ - (name, e, col) - for ((name, e), (_, _, col)) in zip(value, self.nb_estimators) + self.estimators = [ + (name, e, col) for ((name, e), (_, _, col)) in zip(value, self.estimators) ] def get_params(self, deep=True): """Get parameters for this estimator. Returns the parameters listed in the constructor as well as the - subestimators contained within the `nb_estimators` of the `ColumnwiseNB` + subestimators contained within the `estimators` of the `ColumnwiseNB` instance. Parameters @@ -2123,7 +2122,7 @@ def set_params(self, **kwargs): Valid parameter keys can be listed with `get_params()`. Note that you can directly set the parameters of the estimators contained in - `nb_estimators` of `ColumnwiseNB`. + `estimators` of `ColumnwiseNB`. 
Parameters ---------- @@ -2140,7 +2139,7 @@ def set_params(self, **kwargs): def _sk_visual_block_(self): """HTML representation of this estimator.""" - names, estimators, name_details = zip(*self.nb_estimators) + names, estimators, name_details = zip(*self.estimators) return _VisualBlock( "parallel", estimators, names=names, name_details=name_details ) diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index b3c6820faefc2..ae1569d85f5de 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -255,6 +255,7 @@ def _generate_meta_estimator_instances_with_pipeline(): "BaggingClassifier", "BaggingRegressor", "ClassifierChain", # data validation is necessary + "ColumnwiseNB", "IterativeImputer", "OneVsOneClassifier", # input validation can't be avoided "RANSACRegressor", diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 56629a9617bbb..6f94bf4fe0066 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1000,7 +1000,7 @@ def test_predict_joint_proba(Estimator, global_random_seed): def test_cwnb_union_gnb(): # A union of GaussianNB's yields the same prediction as a single GaussianNB clf1 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [0]), ("g2", GaussianNB(), [1])] + estimators=[("g1", GaussianNB(), [0]), ("g2", GaussianNB(), [1])] ) clf2 = GaussianNB() clf1.fit(X, y) @@ -1015,7 +1015,7 @@ def test_cwnb_union_prior_gnb(): # when class priors are provided by user priors = np.array([1 / 3, 2 / 3]) clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(priors=priors), [0]), ("g2", GaussianNB(priors=priors), [1]), ], @@ -1034,7 +1034,7 @@ def test_cwnb_union_bnb_fit(global_random_seed): # (fit) X1, y1 = get_random_normal_x_binary_y(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] + estimators=[("b1", BernoulliNB(), [0]), ("b2", 
BernoulliNB(), [1, 2])] ) clf2 = BernoulliNB() clf1.fit(X1, y1) @@ -1053,7 +1053,7 @@ def test_cwnb_union_bnb_partial_fit(global_random_seed): # (partial_fit) X1, y1 = get_random_normal_x_binary_y(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] + estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] ) clf2 = BernoulliNB() clf1.partial_fit(X1[:5], y1[:5], classes=[0, 1]) @@ -1072,7 +1072,7 @@ def test_cwnb_union_permutation(global_random_seed): # A union of several different NB's is permutation-invariant X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("b1", BernoulliNB(binarize=2), [3]), ("g1", GaussianNB(), [0]), ("m1", MultinomialNB(), [0, 2]), @@ -1081,7 +1081,7 @@ def test_cwnb_union_permutation(global_random_seed): ) # permute (0, 1, 2, 3, 4) -> (1, 2, 0, 3, 4) both estimator specs and column numbers clf2 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("b1", BernoulliNB(binarize=2), [3]), ("g1", GaussianNB(), [1]), ("m1", MultinomialNB(), [1, 0]), @@ -1112,10 +1112,10 @@ def test_cwnb_estimators_pandas(): # Subestimators spec: cols can be lists of int or lists of str, if DataFrame clf1 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] ) clf2 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(), ["col1"]), ("g2", GaussianNB(), ["col0", "col1"]), ] @@ -1134,14 +1134,14 @@ def test_cwnb_estimators_pandas(): # when callable columns produce the empty set. 
select_none = make_column_selector(pattern="qwerasdf") clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(), [1]), ("g2", GaussianNB(), select_none), ("g3", GaussianNB(), [0, 1]), ] ) clf2 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] + estimators=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] ) clf1.fit(Xdf, y) clf2.fit(Xdf, y) @@ -1149,10 +1149,10 @@ def test_cwnb_estimators_pandas(): clf1.predict_log_proba(Xdf), clf2.predict_log_proba(Xdf), 8 ) # Empty-columns estimators are passed to estimators_ and the numbers match - assert len(clf1.nb_estimators) == len(clf1.estimators_) == 3 - assert len(clf2.nb_estimators) == len(clf2.estimators_) == 2 + assert len(clf1.estimators) == len(clf1.estimators_) == 3 + assert len(clf2.estimators) == len(clf2.estimators_) == 2 # No cloning of the empty-columns estimators took place: - assert id(clf1.nb_estimators[1][1]) == id(clf1.named_estimators_["g2"]) + assert id(clf1.estimators[1][1]) == id(clf1.named_estimators_["g2"]) # Subestimators spec: test callable columns select_int = make_column_selector(dtype_include=np.int_) @@ -1160,13 +1160,13 @@ def test_cwnb_estimators_pandas(): Xdf2 = Xdf Xdf2["col3"] = np.exp(Xdf["col0"]) - 0.5 * Xdf["col1"] clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(), ["col3"]), ("m1", BernoulliNB(), ["col0", "col1"]), ] ) clf2 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(), select_float), ("g2", BernoulliNB(), select_int), ] @@ -1182,13 +1182,13 @@ def test_cwnb_repeated_columns(global_random_seed): # Subestimators spec: repeated col ints have the same effect as repeating data X1, y1 = get_random_normal_x_binary_y(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(), [1, 1]), ("b1", BernoulliNB(), [0, 0, 1, 1]), ] ) clf2 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(), [0, 1]), ("b1", BernoulliNB(), [2, 3, 4, 5]), ] @@ 
-1204,36 +1204,36 @@ def test_cwnb_empty_columns(global_random_seed): # Subestimators spec: empty cols have the same effect as an absent estimator X1, y1 = get_random_normal_x_binary_y(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(), [1]), ("g2", GaussianNB(), []), ("g3", GaussianNB(), [0, 1]), ] ) clf2 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] + estimators=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] ) clf1.fit(X1, y1) clf2.fit(X1, y1) assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) # Empty-columns estimators are passed to estimators_ and the numbers match - assert len(clf1.nb_estimators) == len(clf1.estimators_) == 3 - assert len(clf2.nb_estimators) == len(clf2.estimators_) == 2 + assert len(clf1.estimators) == len(clf1.estimators_) == 3 + assert len(clf2.estimators) == len(clf2.estimators_) == 2 # No cloning of the empty-columns estimators took place: - assert id(clf1.nb_estimators[1][1]) == id(clf1.named_estimators_["g2"]) + assert id(clf1.estimators[1][1]) == id(clf1.named_estimators_["g2"]) def test_cwnb_estimators_unique_names(): # Subestimators spec: error on repeated names clf1 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1]), ("g1", GaussianNB(), [0, 1])] + estimators=[("g1", GaussianNB(), [1]), ("g1", GaussianNB(), [0, 1])] ) msg = "Names provided are not unique" with pytest.raises(ValueError, match=msg): clf1.fit(X, y) clf1 = ColumnwiseNB( - nb_estimators=[["g1", GaussianNB(), [1]], ["g2", GaussianNB(), [0, 1]]] + estimators=[["g1", GaussianNB(), [1]], ["g2", GaussianNB(), [0, 1]]] ) clf1.fit(X, y) @@ -1242,7 +1242,7 @@ def test_cwnb_estimators_nonempty_list(global_random_seed): # Subestimators spec: error on empty list X1, y1 = get_random_normal_x_binary_y(global_random_seed) clf = ColumnwiseNB( - nb_estimators=[], + estimators=[], ) msg = "A list of naive Bayes estimators must be provided*" with 
pytest.raises(ValueError, match=msg): @@ -1250,17 +1250,17 @@ def test_cwnb_estimators_nonempty_list(global_random_seed): # Subestimators spec: error on None clf = ColumnwiseNB( - nb_estimators=None, + estimators=None, ) - msg = "The 'nb_estimators' parameter of ColumnwiseNB must be an instance of 'list'*" + msg = "The 'estimators' parameter of ColumnwiseNB must be an instance of 'list'*" with pytest.raises(ValueError, match=msg): clf.fit(X1, y1) # Subestimators spec: error on non-tuple clf = ColumnwiseNB( - nb_estimators=GaussianNB(), + estimators=GaussianNB(), ) - msg = "The 'nb_estimators' parameter of ColumnwiseNB must be an instance of 'list'*" + msg = "The 'estimators' parameter of ColumnwiseNB must be an instance of 'list'*" with pytest.raises(ValueError, match=msg): clf.fit(X1, y1) @@ -1281,7 +1281,7 @@ def partial_fit(self, X, y): def predict(self, X): pass - clf1 = ColumnwiseNB(nb_estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "Estimators must be .aive Bayes estimators implementing *" with pytest.raises(TypeError, match=msg): clf1.partial_fit(X, y) @@ -1303,13 +1303,13 @@ def predict_joint_log_proba(self, X): def predict(self, X): pass - clf1 = ColumnwiseNB(nb_estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "Estimators must be .aive Bayes estimators implementing *" with pytest.raises(TypeError, match=msg): clf1.fit(X, y) delattr(notNB, "predict_joint_log_proba") - clf1 = ColumnwiseNB(nb_estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "Estimators must be .aive Bayes estimators implementing *" with pytest.raises(TypeError, match=msg): clf1.fit(X, y) @@ -1331,7 +1331,7 @@ def predict_joint_log_proba(self, X): def predict(self, X): pass - clf1 = 
ColumnwiseNB(nb_estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) msg = "This 'ColumnwiseNB' has no attribute 'partial_fit'*" with pytest.raises(AttributeError, match=msg): clf1.partial_fit(X, y) @@ -1341,23 +1341,23 @@ def test_cwnb_estimators_setter(global_random_seed): # _estimators setter works X1, y1 = get_random_normal_x_binary_y(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [0]), ("b1", BernoulliNB(), [1])] + estimators=[("g1", GaussianNB(), [0]), ("b1", BernoulliNB(), [1])] ) clf1.fit(X1, y1) clf1._estimators = [ ("x1", clf1.named_estimators_["g1"]), ("x2", clf1.named_estimators_["g1"]), ] - assert clf1.nb_estimators[0][0] == "x1" - assert clf1.nb_estimators[0][1] is clf1.named_estimators_["g1"] - assert clf1.nb_estimators[1][0] == "x2" - assert clf1.nb_estimators[1][1] is clf1.named_estimators_["g1"] + assert clf1.estimators[0][0] == "x1" + assert clf1.estimators[0][1] is clf1.named_estimators_["g1"] + assert clf1.estimators[1][0] == "x2" + assert clf1.estimators[1][1] is clf1.named_estimators_["g1"] def test_cwnb_prior_valid_spec(): # prior spec: error when negative, sum!=1 or bad length clf1 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], priors=np.array([-0.25, 1.25]), ) msg = "Priors must be non-negative." @@ -1365,7 +1365,7 @@ def test_cwnb_prior_valid_spec(): clf1.fit(X, y) clf1 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], priors=np.array([0.25, 0.7]), ) msg = "The sum of the priors should be 1." 
@@ -1373,7 +1373,7 @@ def test_cwnb_prior_valid_spec(): clf1.fit(X, y) clf1 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], priors=np.array([0.25, 0.25, 0.25, 0.25]), ) msg = "Number of priors must match number of classes." @@ -1389,28 +1389,28 @@ def test_cwnb_prior_match(global_random_seed): # (3) nothing (ColumnwiseNB will calculate relative frequencies) X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(), [0, 1]), ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), ], priors=np.array([1 / 3, 1 / 3, 1 / 3]), # prior is provided by user ) clf2a = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(), [0, 1]), ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), ], priors="g1", # prior will be estimated by sub-estimator "g1" ) clf2b = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(), [0, 1]), ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), ], priors="m1", # prior will be estimated by sub-estimator "m1" ) clf3 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(), [0, 1]), ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), ], # prior will be estimated by the meta-estimator @@ -1442,7 +1442,7 @@ def fit(self, X, y, sample_weight=None): del self.class_prior_ clf = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(), [1]), ("g2", GaussianNB_hide_prior(), [0, 1]), ], @@ -1466,7 +1466,7 @@ def fit(self, X, y, sample_weight=None): del self.class_log_prior_ clf = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(), [0]), ("m1", MultinomialNB_hide_log_prior(), [1, 2, 3, 4, 5]), ], @@ -1487,7 +1487,7 @@ def test_cwnb_prior_nonzero(global_random_seed): X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) rng = np.random.RandomState(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", 
GaussianNB(), [1, 3, 5]), ("g2", GaussianNB(priors=np.array([0.5, 0, 0.5])), [0, 1]), ] @@ -1505,7 +1505,7 @@ def test_cwnb_prior_nonzero(global_random_seed): assert np.isfinite(p).all() clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(priors=np.array([0.6, 0, 0.4])), [1, 3, 5]), ("g2", GaussianNB(priors=np.array([0.5, 0.5, 0])), [0, 1]), ] @@ -1526,10 +1526,10 @@ def test_cwnb_fit_sample_weight_ones(): # weights in fit have no effect if all ones weights = [1, 1, 1, 1, 1, 1] clf1 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] ) clf2 = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] ) clf1.fit(X, y, sample_weight=weights) clf2.fit(X, y) @@ -1545,13 +1545,13 @@ def test_cwnb_partial_fit_sample_weight_ones(global_random_seed): X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) weights = [1, 1, 1, 1, 1, 1] clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] ) clf2 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] @@ -1571,13 +1571,13 @@ def test_cwnb_fit_sample_weight_repeated(): idx = list(chain(*([i] * w for i, w in enumerate(weights)))) # var_smoothing=0.0 is for maximum precision in dealing with a small sample clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(var_smoothing=0.0), [1]), ("g2", GaussianNB(var_smoothing=0.0), [0, 1]), ] ) clf2 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("g1", GaussianNB(var_smoothing=0.0), [1]), ("g2", GaussianNB(var_smoothing=0.0), [0, 1]), ] @@ -1599,13 +1599,13 @@ def test_cwnb_partial_fit_sample_weight_repeated(global_random_seed): weights = [1, 2, 3, 1, 4, 2] idx = list(chain(*([i] * w for i, w 
in enumerate(weights)))) clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] ) clf2 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] @@ -1625,10 +1625,10 @@ def test_cwnb_partial_fit(global_random_seed): # partial_fit: consecutive calls yield the same prediction as a single call X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] ) clf2 = ColumnwiseNB( - nb_estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] ) clf1.partial_fit(X2, y2, classes=np.unique(y2)) clf2.partial_fit(X2[:4], y2[:4], classes=np.unique(y2)) @@ -1646,10 +1646,10 @@ def test_cwnb_fit_refits(global_random_seed): # fit: re-fits the estimator de novo when called on a fitted estimator X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] ) clf2 = ColumnwiseNB( - nb_estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] ) clf1.fit(X2, y2) clf2.partial_fit(X2[:4], y2[:4], classes=np.unique(y2)) @@ -1667,7 +1667,7 @@ def test_cwnb_partial_fit_classes(global_random_seed): # partial_fit: error when classes are not provided at the first call X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 
3])] ) msg = ".lasses must be passed on the first call to partial_fit" with pytest.raises(ValueError, match=msg): @@ -1678,7 +1678,7 @@ def test_cwnb_class_attributes_consistency(global_random_seed): # class_count_, classes_, class_prior_ are consistent in meta-, sub-estimators X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), ] @@ -1699,7 +1699,7 @@ def test_cwnb_params(global_random_seed): # clone() works on ColumnwiseNB X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("b1", BernoulliNB(alpha=0.2, binarize=2), [1]), ("m1", MultinomialNB(class_prior=[0.2, 0.2, 0.6]), [0, 2, 3]), ] @@ -1710,8 +1710,8 @@ def test_cwnb_params(global_random_seed): assert p["b1__binarize"] == 2 assert p["m1__class_prior"] == [0.2, 0.2, 0.6] clf1.set_params(b1__alpha=123, m1__class_prior=[0.3, 0.3, 0.4]) - assert clf1.nb_estimators[0][1].alpha == 123 - assert_array_equal(clf1.nb_estimators[1][1].class_prior, [0.3, 0.3, 0.4]) + assert clf1.estimators[0][1].alpha == 123 + assert_array_equal(clf1.estimators[1][1].class_prior, [0.3, 0.3, 0.4]) # After cloning and fitting, we can check through named_estimators, which # maps to fitted estimators_: clf2 = clone(clf1).fit(X2, y2) @@ -1724,7 +1724,7 @@ def test_cwnb_n_jobs(global_random_seed): # n_jobs: same result whether with it or without X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf1 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("b2", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), @@ -1733,7 +1733,7 @@ def test_cwnb_n_jobs(global_random_seed): n_jobs=4, ) clf2 = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("b1", BernoulliNB(binarize=2), [1]), ("b2", BernoulliNB(binarize=2), [1]), ("m1", MultinomialNB(), [0, 2, 3]), @@ -1757,7 +1757,7 @@ def 
test_cwnb_example(): y = np.array([0, 0, 1, 1, 2, 2]) clf = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("mnb1", MultinomialNB(), [0, 1]), ("mnb2", MultinomialNB(), [3, 4]), ("gnb1", GaussianNB(), [5]), @@ -1772,7 +1772,7 @@ def test_cwnb_verbose(capsys, global_random_seed): # This DOES NOT test if the desired output is generated. X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) clf = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("mnb1", MultinomialNB(), [0, 1]), ("mnb2", MultinomialNB(), [3, 4]), ("gnb1", GaussianNB(), [5]), @@ -1788,7 +1788,7 @@ def test_cwnb_sk_visual_block(capsys): # visual block representation correctly extracts names, cols and estimators estimators = (MultinomialNB(), MultinomialNB(), GaussianNB()) clf = ColumnwiseNB( - nb_estimators=[ + estimators=[ ("mnb1", estimators[0], [0, 1]), ("mnb2", estimators[1], [3, 4]), ("gnb1", estimators[2], [5]), @@ -1806,6 +1806,6 @@ def test_cwnb_check_param_validation(): # create an instance of ColumnwiseNB (also of some other estimators, such as # ColumnTransformer and Pipeline). 
clf = ColumnwiseNB( - nb_estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] ) check_param_validation("ColumnwiseNB", clf) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index a4a20a7ee0f05..afae45598ac31 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -40,7 +40,7 @@ from ..metrics.pairwise import linear_kernel, pairwise_distances, rbf_kernel from ..model_selection import ShuffleSplit, train_test_split from ..model_selection._validation import _safe_split -from ..naive_bayes import GaussianNB, _select_half +from ..naive_bayes import ColumnwiseNB, GaussianNB, _select_half from ..pipeline import make_pipeline from ..preprocessing import StandardScaler, scale from ..random_projection import BaseRandomProjection @@ -425,8 +425,20 @@ def _construct_instance(Estimator): else: estimator = Estimator(LogisticRegression(C=1)) elif required_parameters in (["estimators"],): + if issubclass(Estimator, ColumnwiseNB): + # ColumnwiseNB (naive Bayes meta-classifier) + estimator = Estimator( + estimators=[ + ( + "gnb1", + GaussianNB(var_smoothing=1e-13), + _select_half("first"), + ), + ("gnb2", GaussianNB(), _select_half("second")), + ] + ) # Heterogeneous ensemble classes (i.e. 
stacking, voting) - if issubclass(Estimator, RegressorMixin): + elif issubclass(Estimator, RegressorMixin): estimator = Estimator( estimators=[("est1", Ridge(alpha=0.1)), ("est2", Ridge(alpha=1))] ) @@ -437,14 +449,6 @@ def _construct_instance(Estimator): ("est2", LogisticRegression(C=1)), ] ) - elif required_parameters in (["nb_estimators"],): - # ColumnwiseNB (naive Bayes meta-classifier) - estimator = Estimator( - nb_estimators=[ - ("gnb1", GaussianNB(var_smoothing=1e-13), _select_half("first")), - ("gnb2", GaussianNB(), _select_half("second")), - ] - ) else: msg = ( f"Can't instantiate estimator {Estimator.__name__} " From 3c3315c46d5d9264aeaa6c79c906da5302f590d2 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 12 Sep 2023 07:28:39 -0400 Subject: [PATCH 100/102] Fix rename 'nb_estimators' into 'estimators' --- examples/miscellaneous/plot_combining_naive_bayes.py | 4 ++-- sklearn/naive_bayes.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py index a864cde64c7f3..6441ab1293a9d 100644 --- a/examples/miscellaneous/plot_combining_naive_bayes.py +++ b/examples/miscellaneous/plot_combining_naive_bayes.py @@ -90,7 +90,7 @@ # of other hyperparameters with the help of :class:`~.model_selection.GridSearchCV`. 
param_grid = { - "classifier__nb_estimators": [ + "classifier__estimators": [ [ ("gnb", GaussianNB(), ["age", "fare"]), ("cnb", CategoricalNB(), categorical_features), @@ -119,7 +119,7 @@ cv_results = pd.DataFrame(grid_search.cv_results_) cv_results = cv_results.sort_values("mean_test_score", ascending=False) -cv_results["Columns dictionary"] = cv_results["param_classifier__nb_estimators"].map( +cv_results["Columns dictionary"] = cv_results["param_classifier__estimators"].map( lambda l: {e[0]: e[-1] for e in l} ) cv_results["'gnb' columns"] = cv_results["Columns dictionary"].map(lambda d: d["gnb"]) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index e4f6e3121a124..f7ff1a676ab79 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1564,8 +1564,8 @@ def __call__(self, X): return list(range((X.shape[1] + 1) // 2, X.shape[1])) -def _nb_estimators_have(attr): - """Check if all self.estimators or self.nb_estimators_ have attr. +def _estimators_have(attr): + """Check if all self.estimators or self.estimators_ have attr. Used together with `available_if` in `ColumnwiseNB`.""" @@ -1574,8 +1574,8 @@ def _nb_estimators_have(attr): # self.estimators. Informative errors are raised at validation elsewhere. def chk(obj): try: - if hasattr(obj, "nb_estimators_"): - out = all(hasattr(triplet[1], attr) for triplet in obj.nb_estimators_) + if hasattr(obj, "estimators_"): + out = all(hasattr(triplet[1], attr) for triplet in obj.estimators_) else: out = all(hasattr(triplet[1], attr) for triplet in obj.estimators) except (TypeError, IndexError, AttributeError): @@ -2041,7 +2041,7 @@ def fit(self, X, y, sample_weight=None): X, y, partial=False, classes=None, sample_weight=sample_weight ) - @available_if(_nb_estimators_have("partial_fit")) + @available_if(_estimators_have("partial_fit")) def partial_fit(self, X, y, classes=None, sample_weight=None): """Fit incrementally the naive Bayes meta-estimator on a batch of samples. 
From ca78f52972e6a349acfa2da26985cc9d818b9ac7 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Tue, 12 Sep 2023 16:54:54 -0400 Subject: [PATCH 101/102] Fix _fit_context decorator. --- sklearn/naive_bayes.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index f7ff1a676ab79..dd55cfef03fbe 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1959,7 +1959,6 @@ def _partial_fit(self, X, y, partial=False, classes=None, sample_weight=None): X = self._check_array_if_not_pandas(X) first_call = not hasattr(self, "classes_") if first_call: # in fit() or the first call of partial_fit() - self._validate_params() self._check_feature_names(X, reset=True) self._check_n_features(X, reset=True) self._validate_estimators(check_partial=partial) @@ -2042,6 +2041,10 @@ def fit(self, X, y, sample_weight=None): ) @available_if(_estimators_have("partial_fit")) + @_fit_context( + # estimators in ColumnwiseNB.estimators are not validated yet + prefer_skip_nested_validation=False + ) def partial_fit(self, X, y, classes=None, sample_weight=None): """Fit incrementally the naive Bayes meta-estimator on a batch of samples. 
From 7d0ad34285ab8dec7c67b399e1493d0b709b3827 Mon Sep 17 00:00:00 2001 From: avm19 <52547519avm19@users.noreply.github.com> Date: Wed, 13 Sep 2023 11:03:29 -0400 Subject: [PATCH 102/102] Change _iter signature to mirror ColumnTransformer changes in #27005 --- sklearn/naive_bayes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index dd55cfef03fbe..eb7b8ee7de2bd 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -1846,7 +1846,7 @@ def _validate_column_callables(self, X): self._columns = all_columns self._estimator_to_input_indices = estimator_to_input_indices - def _iter(self, *, fitted=False, replace_strings=False): + def _iter(self, *, fitted, replace_strings): """Generate `(name, naive_bayes_estimator, columns)` tuples. This is a private method, similar to ColumnTransformer._iter. @@ -1942,7 +1942,7 @@ def _update_fitted_estimators(self, fitted_estimators): estimators_ = [] fitted_estimators = iter(fitted_estimators) - for name, nb_estimator, cols in self._iter(): + for name, nb_estimator, cols in self._iter(fitted=False, replace_strings=False): if not _is_empty_column_selection(cols): updated_nb_estimator = next(fitted_estimators) else: # don't advance fitted_estimators; use original