diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 7f6f5c910a3fc..ea6f44be5dcf4 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -1317,6 +1317,7 @@ Visualization
    naive_bayes.ComplementNB
    naive_bayes.GaussianNB
    naive_bayes.MultinomialNB
+   naive_bayes.ColumnwiseNB
 
 .. _neighbors_ref:
 
diff --git a/doc/modules/naive_bayes.rst b/doc/modules/naive_bayes.rst
index b0b32c28e455a..b30989a1e194b 100644
--- a/doc/modules/naive_bayes.rst
+++ b/doc/modules/naive_bayes.rst
@@ -259,6 +259,41 @@
 that all categories for each feature :math:`i` are represented with numbers
 :math:`0, ..., n_i - 1` where :math:`n_i` is the number of available categories
 of feature :math:`i`.
 
+.. _columnwise_naive_bayes:
+
+Mix and match naive Bayes models
+--------------------------------
+
+A naive Bayes model that assumes different distribution families for different
+features (or subsets of features) can be constructed using :class:`ColumnwiseNB`.
+It is a meta-estimator that combines any number of naive Bayes sub-estimators,
+chosen in any combination from :class:`GaussianNB`, :class:`MultinomialNB`,
+:class:`ComplementNB`, :class:`BernoulliNB`, :class:`CategoricalNB`, and
+user-defined models (provided they implement the necessary methods).
+
+When creating a :class:`ColumnwiseNB` estimator, one specifies the
+sub-estimators and their respective column subsets as a list of tuples.
+Each sub-estimator is fitted and evaluated independently of the
+others and "sees" only the features assigned to it. The predictions of the
+sub-estimators are combined via
+
+.. math::
+
+   \log P(x,y)=\log P(x_{1},y) + \dots + \log P(x_{M},y) - (M - 1)\log P(y),
+
+where :math:`\log P(x,y)` is the joint log-probability predicted by the
+meta-estimator, :math:`\log P(x_{m},y)` is the joint log-probability predicted
+by the :math:`m`-th sub-estimator, :math:`\log P(y)` is the class prior used by
+the meta-estimator, and :math:`M \geq 1` is the total number of sub-estimators.
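+For instance, the following toy model (the same as in the :class:`ColumnwiseNB`
+API example) applies two :class:`MultinomialNB` sub-estimators to columns
+``[0, 1]`` and ``[3, 4]``, and a :class:`GaussianNB` to column ``[5]``::
+
+    >>> import numpy as np
+    >>> rng = np.random.RandomState(1)
+    >>> X = rng.randint(5, size=(6, 100))
+    >>> y = np.array([0, 0, 1, 1, 2, 2])
+    >>> from sklearn.naive_bayes import ColumnwiseNB, GaussianNB, MultinomialNB
+    >>> clf = ColumnwiseNB(estimators=[('mnb1', MultinomialNB(), [0, 1]),
+    ...                                ('mnb2', MultinomialNB(), [3, 4]),
+    ...                                ('gnb1', GaussianNB(), [5])])
+    >>> clf.fit(X, y)
+    ColumnwiseNB(estimators=[('mnb1', MultinomialNB(), [0, 1]),
+                             ('mnb2', MultinomialNB(), [3, 4]),
+                             ('gnb1', GaussianNB(), [5])])
+    >>> print(clf.predict(X))
+    [0 0 1 0 2 2]
+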
+See :ref:`sphx_glr_auto_examples_miscellaneous_plot_combining_naive_bayes.py`
+for a worked example of a mixed naive Bayes model.
+
+See also :ref:`voting_classifier` for a way of combining general classifiers.
+An introduction to processing datasets with heterogeneous features is available
+at :ref:`column_transformer`.
+
 Out-of-core naive Bayes model fitting
 -------------------------------------
 
diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst
index f85e5604d2622..244cf8a7521ea 100644
--- a/doc/whats_new/v1.4.rst
+++ b/doc/whats_new/v1.4.rst
@@ -237,6 +237,13 @@ Changelog
 - |Enhancement| :func:`sklearn.model_selection.train_test_split` now supports
   Array API compatible inputs. :pr:`26855` by `Tim Head`_.
 
+:mod:`sklearn.naive_bayes`
+..........................
+
+- |Feature| Add :class:`naive_bayes.ColumnwiseNB`, a meta-estimator that allows
+  existing naive Bayes classifiers to be combined and applied to different
+  columns of `X`. :pr:`22574` by :user:`Andrey Melnik `.
+
 :mod:`sklearn.neighbors`
 ........................
 
diff --git a/examples/miscellaneous/plot_combining_naive_bayes.py b/examples/miscellaneous/plot_combining_naive_bayes.py
new file mode 100644
index 0000000000000..6441ab1293a9d
--- /dev/null
+++ b/examples/miscellaneous/plot_combining_naive_bayes.py
@@ -0,0 +1,135 @@
+"""
+===================================================
+Combining Naive Bayes Estimators using ColumnwiseNB
+===================================================
+
+.. currentmodule:: sklearn
+
+This example shows how to use the :class:`~naive_bayes.ColumnwiseNB`
+meta-estimator to construct a naive Bayes model from base naive Bayes
+estimators. The resulting model is applied to a dataset with a mixture of
+discrete and continuous features.
+
+We consider the Titanic dataset, in which:
+
+- numerical (continuous) features "age" and "fare" are handled by
+  :class:`~naive_bayes.GaussianNB`;
+- categorical (discrete) features "embarked", "sex", and "pclass" are handled
+  by :class:`~naive_bayes.CategoricalNB`.
+"""
+
+# Author: Andrey V. Melnik
+#         Pedro Morales
+#
+# License: BSD 3 clause
+
+# %%
+import pandas as pd
+
+from sklearn import set_config
+from sklearn.datasets import fetch_openml
+
+set_config(transform_output="pandas")
+
+X, y = fetch_openml(
+    "titanic", version=1, as_frame=True, return_X_y=True, n_retries=10, parser="auto"
+)
+X["pclass"] = X["pclass"].astype("category")
+# Add a category for NaNs to the "embarked" feature:
+X["embarked"] = X["embarked"].cat.add_categories("N/A").fillna("N/A")
+
+# %%
+# Build and use a pipeline around ``ColumnwiseNB``
+# ------------------------------------------------
+
+from sklearn.compose import ColumnTransformer
+from sklearn.impute import SimpleImputer
+from sklearn.metrics import accuracy_score
+from sklearn.model_selection import GridSearchCV, train_test_split
+from sklearn.naive_bayes import CategoricalNB, ColumnwiseNB, GaussianNB
+from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import OrdinalEncoder
+
+numeric_features = ["age", "fare"]
+numeric_transformer = SimpleImputer(strategy="median")
+
+categorical_features = ["embarked", "sex", "pclass"]
+categories = [X[c].unique().to_list() for c in X[categorical_features]]
+categorical_transformer = OrdinalEncoder(categories=categories)
+
+preprocessor = ColumnTransformer(
+    transformers=[
+        ("num", numeric_transformer, numeric_features),
+        ("cat", categorical_transformer, categorical_features),
+    ],
+    verbose_feature_names_out=False,
+)
+
+classifier = ColumnwiseNB(
+    estimators=[
+        ("gnb", GaussianNB(), numeric_features),
+        ("cnb", CategoricalNB(), categorical_features),
+    ]
+)
+
+pipe = Pipeline(steps=[("preprocessor", preprocessor), ("classifier", classifier)])
+pipe
+
+# %%
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
+
+pipe.fit(X_train, y_train)
+y_pred = pipe.predict(X_test)
+print(f"Test accuracy: {accuracy_score(y_test, y_pred)}")
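+
+# %%
+# A fitted ``ColumnwiseNB`` exposes its fitted sub-estimators through the
+# ``named_estimators_`` attribute. As a brief illustration (using the step and
+# estimator names defined above), we inspect the per-class means that the
+# Gaussian sub-estimator learned for "age" and "fare", which
+# :class:`~naive_bayes.GaussianNB` stores in its ``theta_`` attribute:
+
+fitted_nb = pipe.named_steps["classifier"]
+print(fitted_nb.named_estimators_["gnb"].theta_)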
+
+# %%
+# Compare choices of columns using ``GridSearchCV``
+# --------------------------------------------------
+#
+# The allocation of columns to the constituent sub-estimators can be regarded
+# as a hyperparameter. We can explore combinations of column allocations and
+# values of other hyperparameters with the help of
+# :class:`~.model_selection.GridSearchCV`.
+
+param_grid = {
+    "classifier__estimators": [
+        [
+            ("gnb", GaussianNB(), ["age", "fare"]),
+            ("cnb", CategoricalNB(), categorical_features),
+        ],
+        [("gnb", GaussianNB(), []), ("cnb", CategoricalNB(), ["pclass"])],
+        [("gnb", GaussianNB(), ["embarked"]), ("cnb", CategoricalNB(), [])],
+    ],
+    "preprocessor__num__strategy": ["mean", "most_frequent"],
+}
+
+grid_search = GridSearchCV(pipe, param_grid, cv=10)
+grid_search
+
+# %%
+# Calling `fit` triggers the cross-validated search for the best
+# combination of hyperparameters:
+
+grid_search.fit(X_train, y_train)
+
+print("Best params:")
+print(grid_search.best_params_)
+
+# %%
+# The table below summarises the search results, sorted by the mean test
+# score, and shows how the choice of columns affects the performance:
+
+cv_results = pd.DataFrame(grid_search.cv_results_)
+cv_results = cv_results.sort_values("mean_test_score", ascending=False)
+cv_results["Columns dictionary"] = cv_results["param_classifier__estimators"].map(
+    lambda estimators: {e[0]: e[-1] for e in estimators}
+)
+cv_results["'gnb' columns"] = cv_results["Columns dictionary"].map(lambda d: d["gnb"])
+cv_results["'cnb' columns"] = cv_results["Columns dictionary"].map(lambda d: d["cnb"])
+cv_results[
+    [
+        "mean_test_score",
+        "std_test_score",
+        "param_preprocessor__num__strategy",
+        "'gnb' columns",
+        "'cnb' columns",
+    ]
+]
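+
+# %%
+# Finally, ``grid_search.best_estimator_`` holds the best pipeline, refitted on
+# the whole training set. As a brief sanity check, we look at its predicted
+# class probabilities for the first few passengers of the test set:
+
+print(grid_search.best_estimator_.predict_proba(X_test[:5]))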
diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py
index 9ee664bf8b3a4..eb7b8ee7de2bd 100644
--- a/sklearn/naive_bayes.py
+++ b/sklearn/naive_bayes.py
@@ -11,6 +11,7 @@
 # Lars Buitinck
 # Jan Hendrik Metzen
 # (parts based on earlier work by Mathieu Blondel)
+# Andrey V. Melnik
 #
 # License: BSD 3 clause
 import warnings
@@ -20,12 +21,24 @@
 import numpy as np
 from scipy.special import logsumexp
 
-from .base import BaseEstimator, ClassifierMixin, _fit_context
+from .base import BaseEstimator, ClassifierMixin, _fit_context, clone
+from .compose._column_transformer import _is_empty_column_selection
 from .preprocessing import LabelBinarizer, binarize, label_binarize
+from .utils import Bunch, _get_column_indices, _print_elapsed_time, _safe_indexing
+from .utils._encode import _unique
+from .utils._estimator_html_repr import _VisualBlock
 from .utils._param_validation import Hidden, Interval, StrOptions
 from .utils.extmath import safe_sparse_dot
+from .utils.metaestimators import _BaseComposition, available_if
 from .utils.multiclass import _check_partial_fit_first_call
-from .utils.validation import _check_sample_weight, check_is_fitted, check_non_negative
+from .utils.parallel import Parallel, delayed
+from .utils.validation import (
+    _check_sample_weight,
+    check_array,
+    check_is_fitted,
+    check_non_negative,
+    column_or_1d,
+)
 
 __all__ = [
     "BernoulliNB",
@@ -33,6 +46,7 @@
     "MultinomialNB",
     "ComplementNB",
     "CategoricalNB",
+    "ColumnwiseNB",
 ]
 
@@ -1526,3 +1540,609 @@ def _joint_log_likelihood(self, X):
             jll += self.feature_log_prob_[i][:, indices].T
         total_ll = jll + self.class_log_prior_
         return total_ll
+
+
+class _select_half:
+    """Column selector that selects the first or the second half of the columns.
+
+    Used for testing purposes only.
+    """
+
+    def __init__(self, half="first"):
+        self.half = half
+
+    def __repr__(self):
+        # Only required when using pytest-xdist to get an id not associated
+        # with the memory location. See:
+        # https://github.com/scikit-learn/scikit-learn/pull/18811#issuecomment-727226988
+        return f'_select_half("{str(self.half)}")'
+
+    def __call__(self, X):
+        if self.half == "first":
+            return list(range((X.shape[1] + 1) // 2))
+        else:
+            return list(range((X.shape[1] + 1) // 2, X.shape[1]))
+
+
+def _estimators_have(attr):
+    """Check if all self.estimators or self.estimators_ have attr.
+
+    Used together with `available_if` in `ColumnwiseNB`."""
+
+    # This function is used with `available_if` before validation.
+    # The try statement suppresses errors caused by an incorrect specification
+    # of self.estimators. Informative errors are raised at validation elsewhere.
+    def chk(obj):
+        try:
+            if hasattr(obj, "estimators_"):
+                out = all(hasattr(triplet[1], attr) for triplet in obj.estimators_)
+            else:
+                out = all(hasattr(triplet[1], attr) for triplet in obj.estimators)
+        except (TypeError, IndexError, AttributeError):
+            return False
+        return out
+
+    return chk
+
+
+def _fit_one(estimator, X, y, message_clsname="", message=None, **fit_params):
+    """Call ``estimator.fit`` and print elapsed time message.
+
+    See :func:`sklearn.pipeline._fit_one`.
+    """
+    # The `classes` parameter is accepted only so that `_partial_fit` can share
+    # one code path for `fit` and `partial_fit`; it is dropped here when None.
+    if fit_params["classes"] is None:
+        fit_params.pop("classes")
+    with _print_elapsed_time(message_clsname, message):
+        return estimator.fit(X, y, **fit_params)
+
+
+def _partial_fit_one(estimator, X, y, message_clsname="", message=None, **fit_params):
+    """Call ``estimator.partial_fit`` and print elapsed time message.
+
+    See :func:`sklearn.pipeline._fit_one`.
+    """
+    with _print_elapsed_time(message_clsname, message):
+        return estimator.partial_fit(X, y, **fit_params)
+
+
+def _jll_one(estimator, X):
+    """Call ``estimator.predict_joint_log_proba``.
+
+    See :func:`sklearn.pipeline._transform_one`.
+    """
+    return estimator.predict_joint_log_proba(X)
+
+
+class ColumnwiseNB(_BaseNB, _BaseComposition):
+    """Column-wise Naive Bayes meta-estimator.
+
+    This estimator combines various naive Bayes estimators by applying them
+    to different column subsets of the input and joining their predictions
+    according to the naive Bayes assumption. This is useful when features are
+    heterogeneous and follow different kinds of distributions.
+
+    Read more in the :ref:`User Guide <columnwise_naive_bayes>`.
+
+    .. versionadded:: 1.4
+
+    Parameters
+    ----------
+    estimators : list of tuples
+        List of `(name, naive_bayes_estimator, columns)` tuples specifying the
+        naive Bayes estimators to be combined into a single naive Bayes
+        meta-estimator.
+
+        name : str
+            Name of the naive Bayes estimator, by which the subestimator and
+            its parameters can be set using :term:`set_params` and searched in
+            grid search.
+        naive_bayes_estimator : estimator
+            The estimator must support :term:`fit` or :term:`partial_fit`,
+            depending on how the meta-estimator is fitted. In addition, the
+            estimator must support the `predict_joint_log_proba` method, which
+            returns a numpy array of shape (n_samples, n_classes) containing
+            the joint log-probabilities `log P(x,y)` for each sample point and
+            class.
+        columns : str, array-like of str, int, array-like of int, \
+                array-like of bool, slice or callable
+            Indexes the data on its second axis. Integers are interpreted as
+            positional columns, while strings can reference DataFrame columns
+            by name. A scalar string or int should be used where
+            `naive_bayes_estimator` expects `X` to be a 1d array-like (vector);
+            otherwise a 2d array will be passed to the estimator.
+            A callable is passed the input data `X` and can return any of the
+            above. To select multiple columns by name or dtype, you can use
+            :obj:`~sklearn.compose.make_column_selector`. The callable is
+            evaluated on the first batch, but not on subsequent calls of
+            `partial_fit`.
+
+    priors : array-like of shape (n_classes,) or str, default=None
+        Prior probabilities of the classes. If unspecified, the priors are
+        calculated as relative frequencies of classes in the training data.
+        If str, the priors are taken from the subestimator with the given name.
+        If array-like, the same priors might have to be specified manually in
+        each subestimator, in order to ensure consistent predictions.
+
+    n_jobs : int, default=None
+        Number of jobs to run in parallel. The appropriate fit or predict
+        methods of the subestimators are invoked in parallel.
+        `None` means 1 unless in a :obj:`joblib.parallel_backend` context.
+        `-1` means using all processors. See :term:`Glossary <n_jobs>`
+        for more details.
+
+    verbose : bool, default=False
+        If True, the time elapsed while fitting each estimator will be
+        printed as it is completed.
+
+    Attributes
+    ----------
+    estimators_ : list of tuples
+        List of `(name, fitted_estimator, columns)` tuples, which follow
+        the order of `estimators`. Here, `fitted_estimator` is a fitted naive
+        Bayes estimator, except when `columns` specifies an empty selection of
+        columns, in which case it is the original unfitted `naive_bayes_estimator`.
+        If the original specification of `columns` in `estimators` was a
+        callable, then `columns` is converted to a list of column indices.
+
+    named_estimators_ : :class:`~sklearn.utils.Bunch`
+        Read-only attribute to access any subestimator by the given name.
+        Keys are estimator names and values are the fitted estimators, except
+        when a subestimator does not require fitting (i.e., when `columns` is
+        an empty set of indices).
+
+    class_prior_ : ndarray of shape (n_classes,)
+        Prior probabilities of the classes used in the naive Bayes
+        meta-estimator, which are calculated as relative frequencies, extracted
+        from a subestimator, or provided by the user, according to the value of
+        `priors` at initialization.
+
+    class_count_ : ndarray of shape (n_classes,)
+        Number of samples encountered for each class during fitting. This
+        value is weighted by the sample weight when provided.
+
+    classes_ : ndarray of shape (n_classes,)
+        Class labels known to the classifier.
+
+    n_features_in_ : int
+        Number of features seen during :term:`fit`.
+
+    feature_names_in_ : ndarray of shape (`n_features_in_`,)
+        Names of features seen during :term:`fit`. Only defined if `X` has
+        feature names that are all strings.
+
+    See Also
+    --------
+    BernoulliNB : Naive Bayes classifier for multivariate Bernoulli models.
+    CategoricalNB : Naive Bayes classifier for categorical features.
+    ComplementNB : Complement Naive Bayes classifier.
+    MultinomialNB : Naive Bayes classifier for multinomial models.
+    GaussianNB : Gaussian Naive Bayes.
+    :class:`~sklearn.compose.ColumnTransformer` : Applies transformers to columns.
+
+    Notes
+    -----
+    ColumnwiseNB combines multiple naive Bayes estimators by expressing the
+    overall joint probability `P(x,y)` through `P(x_i,y)`, the joint
+    probabilities of the subestimators::
+
+        Log P(x,y) = Log P(x_1,y) + ... + Log P(x_N,y) - (N - 1) Log P(y),
+
+    where `N` denotes `n_estimators`, the number of estimators.
+    It is implicitly assumed that the class log-priors are finite and agree
+    between the meta-estimator and the subestimators::
+
+        - inf < Log P(y) = Log P(y|1) = ... = Log P(y|N).
+
+    The meta-estimator does not check whether this condition holds. Meaningless
+    results, including `NaN`, may be produced by ColumnwiseNB if the class
+    priors differ or contain a zero probability.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> rng = np.random.RandomState(1)
+    >>> X = rng.randint(5, size=(6, 100))
+    >>> y = np.array([0, 0, 1, 1, 2, 2])
+    >>> from sklearn.naive_bayes import MultinomialNB, GaussianNB, ColumnwiseNB
+    >>> clf = ColumnwiseNB(estimators=[('mnb1', MultinomialNB(), [0, 1]),
+    ...                                ('mnb2', MultinomialNB(), [3, 4]),
+    ...                                ('gnb1', GaussianNB(), [5])])
+    >>> clf.fit(X, y)
+    ColumnwiseNB(estimators=[('mnb1', MultinomialNB(), [0, 1]),
+                             ('mnb2', MultinomialNB(), [3, 4]),
+                             ('gnb1', GaussianNB(), [5])])
+    >>> print(clf.predict(X))
+    [0 0 1 0 2 2]
+    """
+
+    _required_parameters = ["estimators"]
+
+    _parameter_constraints = {
+        "estimators": [list],
+        "priors": ["array-like", str, None],
+        "n_jobs": [Integral, None],
+        "verbose": ["verbose"],
+    }
+
+    def _log_message(self, name, idx, total):
+        if not self.verbose:
+            return None
+        return f"({idx} of {total}) Processing {name}"
+
+    def __init__(self, estimators, *, priors=None, n_jobs=None, verbose=False):
+        self.estimators = estimators
+        self.priors = priors
+        self.n_jobs = n_jobs
+        self.verbose = verbose
+
+    def _check_X(self, X):
+        """Validate X, used only in predict* methods."""
+        # Defer conversion and validation of a pandas DataFrame to the
+        # subestimators, in order to allow column indexing by str or int
+        # (if DataFrame). Convert other kinds here to allow column indexing
+        # by int (otherwise). Note that subestimators may modify (a copy of) X.
+        # For example, BernoulliNB._check_X binarises the input.
+        X = self._check_array_if_not_pandas(X)
+        self._check_feature_names(X, reset=False)
+        self._check_n_features(X, reset=False)
+        return X
+
+    def _check_array_if_not_pandas(self, array):
+        """Convert to ndarray, unless it is a pandas DataFrame."""
+        if hasattr(array, "dtypes") and hasattr(array.dtypes, "__array__"):
+            return array
+        else:
+            return check_array(array)
+
+    def _joint_log_likelihood(self, X):
+        """Calculate the meta-estimator's joint log-probability `log P(x,y)`."""
+        estimators = self._iter(fitted=True, replace_strings=True)
+        all_jlls = Parallel(n_jobs=self.n_jobs)(
+            delayed(_jll_one)(estimator=nb_estimator, X=_safe_indexing(X, cols, axis=1))
+            for (_, nb_estimator, cols) in estimators
+        )
+        n_estimators = len(all_jlls)
+        log_prior = np.log(self.class_prior_)
+        # Guard against zero class priors: `log_prior` is then -inf and the
+        # expression below would produce NaN (0 * inf or inf - inf); force the
+        # joint log-probability of such classes to -inf instead.
+        return np.where(
+            np.isinf(log_prior),
+            -np.inf,
+            np.sum(all_jlls, axis=0) - (n_estimators - 1) * log_prior,
+        )
+
+    def _validate_estimators(self, check_partial=False):
+        try:
+            names, estimators, _ = zip(*self.estimators)
+        except (TypeError, AttributeError, ValueError) as exc:
+            raise ValueError(
+                "A list of naive Bayes estimators must be provided "
+                "in the form [(name, naive_bayes_estimator, columns), ... ]."
+            ) from exc
+        for e in estimators:
+            if (not check_partial) and (
+                not (hasattr(e, "fit") and hasattr(e, "predict_joint_log_proba"))
+            ):
+                raise TypeError(
+                    "Estimators must be naive Bayes estimators implementing "
+                    "`fit` and `predict_joint_log_proba` methods."
+                )
+            if check_partial and not hasattr(e, "predict_joint_log_proba"):
+                raise TypeError(
+                    "Estimators must be naive Bayes estimators implementing "
+                    "`partial_fit` and `predict_joint_log_proba` methods."
+                )
+        self._validate_names(names)
+
+    def _validate_column_callables(self, X):
+        """
+        Convert callable column specifications and store them in `self._columns`.
+
+        Empty-set columns do not enjoy any special treatment.
+        """
+        all_columns = []
+        estimator_to_input_indices = {}
+        for name, _, columns in self.estimators:
+            if callable(columns):
+                columns = columns(X)
+            all_columns.append(columns)
+            estimator_to_input_indices[name] = _get_column_indices(X, columns)
+        self._columns = all_columns
+        self._estimator_to_input_indices = estimator_to_input_indices
+
+    def _iter(self, *, fitted, replace_strings):
+        """Generate `(name, naive_bayes_estimator, columns)` tuples.
+
+        This is a private method, similar to ColumnTransformer._iter.
+        Must not be called before _validate_column_callables.
+
+        Parameters
+        ----------
+        fitted : bool
+            If False, yield tuples based on self.estimators (user-specified),
+            but with callable columns replaced by a list of column names or
+            indices. If True, yield tuples from self.estimators_ (fitted),
+            whose columns have already been converted.
+
+        replace_strings : bool
+            If True, omit the estimators that do not require fitting, i.e.
+            those with empty-set columns. The name `replace_strings` is a relic
+            of the ColumnTransformer implementation, where `passthrough` and
+            `drop` required replacement and omission, respectively.
+
+        Yields
+        ------
+        tuple
+            of the form `(name, naive_bayes_estimator, columns)`.
+
+        Notes
+        -----
+        Loop through estimators from this generator with the following
+        parameters, depending on the purpose:
+
+        self._iter(fitted=False, replace_strings=True) :
+            fit, 1st partial_fit
+        self._iter(fitted=True, replace_strings=True) :
+            further partial_fit, predict
+        self._iter(fitted=False, replace_strings=False) :
+            update fitted estimators. Note that special treatment is required
+            for unfitted estimators (those with empty-set columns)!
+        self._iter(fitted=True, replace_strings=False) :
+            not used here. The use case in ColumnTransformer would be sorting
+            out the transformed output and its column names.
+        do not use in:
+            the Bunch accessor named_estimators_;
+            input validation: _validate_estimators, _validate_column_callables;
+            parameter management: get_params, set_params, _estimators.
+        """
+        if fitted:
+            for name, estimator, cols in self.estimators_:
+                if replace_strings and _is_empty_column_selection(cols):
+                    continue
+                else:
+                    yield (name, estimator, cols)
+        else:  # fitted=False
+            for (name, estimator, _), cols in zip(self.estimators, self._columns):
+                if replace_strings and _is_empty_column_selection(cols):
+                    continue
+                else:
+                    yield (name, estimator, cols)
+
+    def _update_class_prior(self):
+        """Update the class prior after most of the fitting is done."""
+        if self.priors is None:  # calculate the empirical prior from counts
+            priors = self.class_count_ / self.class_count_.sum()
+        elif isinstance(self.priors, str):  # extract the prior from an estimator
+            name = self.priors
+            e = self.named_estimators_[name]
+            if getattr(e, "class_prior_", None) is not None:
+                priors = e.class_prior_
+            elif getattr(e, "class_log_prior_", None) is not None:
+                priors = np.exp(e.class_log_prior_)
+            else:
+                raise AttributeError(
+                    f"Unable to extract class prior from estimator {name}, as "
+                    "it does not have class_prior_ or class_log_prior_ "
+                    "attributes."
+                )
+        else:  # check the prior provided by the user
+            priors = np.asarray(self.priors)
+        # Check the prior in any case.
+ if len(priors) != len(self.classes_): + raise ValueError("Number of priors must match number of classes.") + if not np.isclose(priors.sum(), 1.0): + raise ValueError("The sum of the priors should be 1.") + if (priors < 0).any(): + raise ValueError("Priors must be non-negative.") + self.class_prior_ = priors + + def _update_fitted_estimators(self, fitted_estimators): + """Update tuples in self.estimators_ with fitted_estimators provided. + + Callable columns are replaced with sets of actual str or int indices. + Estimators that don't require fitting are passed as they were, + without cloning. + """ + estimators_ = [] + fitted_estimators = iter(fitted_estimators) + + for name, nb_estimator, cols in self._iter(fitted=False, replace_strings=False): + if not _is_empty_column_selection(cols): + updated_nb_estimator = next(fitted_estimators) + else: # don't advance fitted_estimators; use original + updated_nb_estimator = nb_estimator + estimators_.append((name, updated_nb_estimator, cols)) + self.estimators_ = estimators_ + self.named_estimators_ = Bunch(**{name: e for name, e, _ in estimators_}) + + def _partial_fit(self, X, y, partial=False, classes=None, sample_weight=None): + """ + partial : bool, default=False + True for partial_fit, False for fit. + """ + X = self._check_array_if_not_pandas(X) + first_call = not hasattr(self, "classes_") + if first_call: # in fit() or the first call of partial_fit() + self._check_feature_names(X, reset=True) + self._check_n_features(X, reset=True) + self._validate_estimators(check_partial=partial) + self._validate_column_callables(X) + else: + self._check_feature_names(X, reset=False) + self._check_n_features(X, reset=False) + + y_ = column_or_1d(y) + + if sample_weight is not None: + weights = _check_sample_weight(sample_weight, X=y_, copy=True) + + if not partial: + self.classes_, counts = _unique(y_, return_counts=True) + else: + _check_partial_fit_first_call(self, classes) + + if sample_weight is not None: + counts = np.zeros(len(self.classes_), dtype=np.float64) + for i, c in enumerate(self.classes_): + counts[i] = (weights * (y_ == c)).sum() + elif partial: + counts = np.zeros(len(self.classes_), dtype=np.float64) + for i, c in enumerate(self.classes_): + counts[i] = (y_ == c).sum() + + if not first_call: + self.class_count_ += counts + else: + self.class_count_ = counts.astype(np.float64, copy=False) + + estimators = list(self._iter(fitted=not first_call, replace_strings=True)) + fitted_estimators = Parallel(n_jobs=self.n_jobs)( + delayed(_partial_fit_one if partial else _fit_one)( + estimator=clone(nb_estimator) if first_call else nb_estimator, + X=_safe_indexing(X, cols, axis=1), + y=y, + message_clsname="ColumnwiseNB", + message=self._log_message(name, idx, len(estimators)), + classes=classes, + sample_weight=sample_weight, + ) + for idx, (name, nb_estimator, cols) in enumerate(estimators, 1) + ) + self._update_fitted_estimators(fitted_estimators) + self._update_class_prior() + return self + + @_fit_context( + # estimators in ColumnwiseNB.estimators are not validated yet + prefer_skip_nested_validation=False + ) + def fit(self, X, y, sample_weight=None): + """Fit the naive Bayes meta-estimator. + + Calls `fit` of each subestimator `naive_bayes_estimator`. + Only a corresponding subset of columns of `X` is passed to each subestimator; + `sample_weight` and `y` are passed to the subestimators as they are. 
+ + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training vectors, where `n_samples` is the number of samples + and `n_features` is the number of features. + y : array-like of shape (n_samples,) + Target values. + sample_weight : array-like of shape (n_samples,), default=None + Weights applied to individual samples (1. for unweighted). + + Returns + ------- + self : object + Returns the instance itself. + """ + if hasattr(self, "classes_"): + delattr(self, "classes_") + return self._partial_fit( + X, y, partial=False, classes=None, sample_weight=sample_weight + ) + + @available_if(_estimators_have("partial_fit")) + @_fit_context( + # estimators in ColumnwiseNB.estimators are not validated yet + prefer_skip_nested_validation=False + ) + def partial_fit(self, X, y, classes=None, sample_weight=None): + """Fit incrementally the naive Bayes meta-estimator on a batch of samples. + + Calls `partial_fit` of each subestimator. Only a corresponding + subset of columns of `X` is passed to each subestimator. `classes`, + `sample_weight` and 'y' are passed to the subestimators as they are. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training vectors, where `n_samples` is the number of samples and + `n_features` is the number of features. + + y : array-like of shape (n_samples,) + Target values. + + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + + sample_weight : array-like of shape (n_samples,), default=None + Weights applied to individual samples (1. for unweighted). + + Returns + ------- + self : object + Returns the instance itself. + """ + return self._partial_fit( + X, y, partial=True, classes=classes, sample_weight=sample_weight + ) + + @property + def _estimators(self): + """Internal list of subestimators. + + This is for the implementation of get_params via BaseComposition._get_params, + which expects lists of tuples of len 2. + """ + try: + return [(name, e) for name, e, _ in self.estimators] + except (TypeError, ValueError): + # This try-except clause is needed to pass the test from test_common.py: + # test_estimators_do_not_raise_errors_in_init_or_set_params(). + # ColumnTransformer does the same. See PR #21355 for details. + return self.estimators + + @_estimators.setter + def _estimators(self, value): + self.estimators = [ + (name, e, col) for ((name, e), (_, _, col)) in zip(value, self.estimators) + ] + + def get_params(self, deep=True): + """Get parameters for this estimator. + + Returns the parameters listed in the constructor as well as the + subestimators contained within the `estimators` of the `ColumnwiseNB` + instance. + + Parameters + ---------- + deep : bool, default=True + If True, will return the parameters for this estimator and + contained subobjects that are estimators. + + Returns + ------- + params : dict + Parameter names mapped to their values. + """ + return self._get_params("_estimators", deep=deep) + + def set_params(self, **kwargs): + """Set the parameters of this estimator. + + Valid parameter keys can be listed with `get_params()`. Note that you + can directly set the parameters of the estimators contained in + `estimators` of `ColumnwiseNB`. + + Parameters + ---------- + **kwargs : dict + Estimator parameters. + + Returns + ------- + self : ColumnwiseNB + This estimator. 
+ """ + self._set_params("_estimators", **kwargs) + return self + + def _sk_visual_block_(self): + """HTML representation of this estimator.""" + names, estimators, name_details = zip(*self.estimators) + return _VisualBlock( + "parallel", estimators, names=names, name_details=name_details + ) diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index b3c6820faefc2..ae1569d85f5de 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -255,6 +255,7 @@ def _generate_meta_estimator_instances_with_pipeline(): "BaggingClassifier", "BaggingRegressor", "ClassifierChain", # data validation is necessary + "ColumnwiseNB", "IterativeImputer", "OneVsOneClassifier", # input validation can't be avoided "RANSACRegressor", diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 4165aa7a668e6..2e980d00bbb46 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -1,15 +1,20 @@ import re import warnings +from itertools import chain import numpy as np import pytest from scipy.special import logsumexp +from sklearn.base import BaseEstimator, clone +from sklearn.compose import make_column_selector from sklearn.datasets import load_digits, load_iris +from sklearn.exceptions import DataConversionWarning from sklearn.model_selection import cross_val_score, train_test_split from sklearn.naive_bayes import ( BernoulliNB, CategoricalNB, + ColumnwiseNB, ComplementNB, GaussianNB, MultinomialNB, @@ -20,6 +25,7 @@ assert_array_almost_equal, assert_array_equal, ) +from sklearn.utils.estimator_checks import check_param_validation from sklearn.utils.fixes import CSR_CONTAINERS DISCRETE_NAIVE_BAYES_CLASSES = [BernoulliNB, CategoricalNB, ComplementNB, MultinomialNB] @@ -991,3 +997,817 @@ def test_predict_joint_proba(Estimator, global_random_seed): log_prob_x = logsumexp(jll, axis=1) log_prob_x_y = jll - np.atleast_2d(log_prob_x).T assert_allclose(est.predict_log_proba(X2), log_prob_x_y) + + +def test_cwnb_union_gnb(): + # A union of GaussianNB's yields the same prediction as a single GaussianNB + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [0]), ("g2", GaussianNB(), [1])] + ) + clf2 = GaussianNB() + clf1.fit(X, y) + clf2.fit(X, y) + assert_array_almost_equal(clf1.predict(X), clf2.predict(X), 8) + assert_array_almost_equal(clf1.predict_proba(X), clf2.predict_proba(X), 8) + assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) + + +def test_cwnb_union_prior_gnb(): + # A union of GaussianNB's yields the same prediction as a single GaussianNB + # when class priors are provided by user + priors = np.array([1 / 3, 2 / 3]) + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(priors=priors), [0]), + ("g2", GaussianNB(priors=priors), [1]), + ], + priors=priors, + ) + clf2 = GaussianNB(priors=priors) + clf1.fit(X, y) + clf2.fit(X, y) + assert_array_almost_equal(clf1.predict(X), clf2.predict(X), 8) + assert_array_almost_equal(clf1.predict_proba(X), clf2.predict_proba(X), 8) + assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) + + +def test_cwnb_union_bnb_fit(global_random_seed): + # A union of BernoulliNB's yields the same prediction as a single BernoulliNB + # (fit) + X1, y1 = get_random_normal_x_binary_y(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] + ) + clf2 = BernoulliNB() + clf1.fit(X1, y1) + clf2.fit(X1, y1) + 
assert_array_almost_equal(clf1.predict_proba(X1), clf2.predict_proba(X1), 8) + assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) + # BernoulliNB is likely to yield 1/1 odds that eventually round to different + # y_pred as a result of unavoidable log/exp transformations done by ColumnwiseNB. + # We don't want to test for these discretisation discrepancies. + ii = abs(clf1.predict_proba(X1)[:, 0] - clf1.predict_proba(X1)[:, 1]) > 1e-8 + assert_array_almost_equal(clf1.predict(X1)[ii], clf2.predict(X1)[ii], 8) + + +def test_cwnb_union_bnb_partial_fit(global_random_seed): + # A union of BernoulliNB's yields the same prediction as a single BernoulliNB + # (partial_fit) + X1, y1 = get_random_normal_x_binary_y(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[("b1", BernoulliNB(), [0]), ("b2", BernoulliNB(), [1, 2])] + ) + clf2 = BernoulliNB() + clf1.partial_fit(X1[:5], y1[:5], classes=[0, 1]) + clf1.partial_fit(X1[5:], y1[5:]) + clf2.fit(X1, y1) + assert_array_almost_equal(clf1.predict_proba(X1), clf2.predict_proba(X1), 8) + assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) + # BernoulliNB is likely to yield 1/1 odds that eventually round to different + # y_pred as a result of unavoidable log/exp transformations done by ColumnwiseNB. + # We don't want to test for these discretisation discrepancies. + ii = abs(clf1.predict_proba(X1)[:, 0] - clf1.predict_proba(X1)[:, 1]) > 1e-8 + assert_array_almost_equal(clf1.predict(X1)[ii], clf2.predict(X1)[ii], 8) + + +def test_cwnb_union_permutation(global_random_seed): + # A union of several different NB's is permutation-invariant + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [3]), + ("g1", GaussianNB(), [0]), + ("m1", MultinomialNB(), [0, 2]), + ("b2", BernoulliNB(), [1]), + ] + ) + # permute (0, 1, 2, 3, 4) -> (1, 2, 0, 3, 4) both estimator specs and column numbers + clf2 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [3]), + ("g1", GaussianNB(), [1]), + ("m1", MultinomialNB(), [1, 0]), + ("b2", BernoulliNB(), [2]), + ] + ) + clf1.fit(X2[:, [0, 1, 2, 3, 4]], y2) # (0, 1, 2, 3, 4) -> (1, 2, 0, 3, 4) + clf2.fit(X2[:, [2, 0, 1, 3, 4]], y2) # (0, 1, 2, 3, 4) <- (2, 0, 1, 3, 4) + assert_array_almost_equal( + clf1.predict_proba(X2[:, [0, 1, 2, 3, 4]]), + clf2.predict_proba(X2[:, [2, 0, 1, 3, 4]]), + 8, + ) + assert_array_almost_equal( + clf1.predict_log_proba(X2[:, [0, 1, 2, 3, 4]]), + clf2.predict_log_proba(X2[:, [2, 0, 1, 3, 4]]), + 8, + ) + assert_array_almost_equal( + clf1.predict(X2[:, [0, 1, 2, 3, 4]]), clf2.predict(X2[:, [2, 0, 1, 3, 4]]), 8 + ) + + +def test_cwnb_estimators_pandas(): + pd = pytest.importorskip("pandas") + Xdf = pd.DataFrame(data=X, columns=["col0", "col1"]) + ydf = pd.DataFrame({"target": y}) + + # Subestimators spec: cols can be lists of int or lists of str, if DataFrame + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + ) + clf2 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), ["col1"]), + ("g2", GaussianNB(), ["col0", "col1"]), + ] + ) + clf1.fit(X, y) + clf2.fit(Xdf, y) + assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(Xdf), 8) + msg = "A column-vector y was passed when a 1d array was expected" + with pytest.warns(DataConversionWarning, match=msg): + clf2.fit(Xdf, ydf) + assert_array_almost_equal( + clf1.predict_log_proba(X), clf2.predict_log_proba(Xdf), 8 + ) + + # 
Subestimators spec: empty cols have the same effect as an absent estimator + # when callable columns produce the empty set. + select_none = make_column_selector(pattern="qwerasdf") + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [1]), + ("g2", GaussianNB(), select_none), + ("g3", GaussianNB(), [0, 1]), + ] + ) + clf2 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] + ) + clf1.fit(Xdf, y) + clf2.fit(Xdf, y) + assert_array_almost_equal( + clf1.predict_log_proba(Xdf), clf2.predict_log_proba(Xdf), 8 + ) + # Empty-columns estimators are passed to estimators_ and the numbers match + assert len(clf1.estimators) == len(clf1.estimators_) == 3 + assert len(clf2.estimators) == len(clf2.estimators_) == 2 + # No cloning of the empty-columns estimators took place: + assert id(clf1.estimators[1][1]) == id(clf1.named_estimators_["g2"]) + + # Subestimators spec: test callable columns + select_int = make_column_selector(dtype_include=np.int_) + select_float = make_column_selector(dtype_include=np.float_) + Xdf2 = Xdf + Xdf2["col3"] = np.exp(Xdf["col0"]) - 0.5 * Xdf["col1"] + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), ["col3"]), + ("m1", BernoulliNB(), ["col0", "col1"]), + ] + ) + clf2 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), select_float), + ("g2", BernoulliNB(), select_int), + ] + ) + clf1.fit(Xdf, y) + clf2.fit(Xdf, y) + assert_array_almost_equal( + clf1.predict_log_proba(Xdf), clf2.predict_log_proba(Xdf), 8 + ) + + +def test_cwnb_repeated_columns(global_random_seed): + # Subestimators spec: repeated col ints have the same effect as repeating data + X1, y1 = get_random_normal_x_binary_y(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [1, 1]), + ("b1", BernoulliNB(), [0, 0, 1, 1]), + ] + ) + clf2 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [0, 1]), + ("b1", BernoulliNB(), [2, 3, 4, 5]), + ] + ) + clf1.fit(X1, y1) + clf2.fit(X1[:, [1, 1, 0, 0, 1, 1]], y1) + assert_array_almost_equal( + clf1.predict_log_proba(X1), clf2.predict_log_proba(X1[:, [1, 1, 0, 0, 1, 1]]), 8 + ) + + +def test_cwnb_empty_columns(global_random_seed): + # Subestimators spec: empty cols have the same effect as an absent estimator + X1, y1 = get_random_normal_x_binary_y(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [1]), + ("g2", GaussianNB(), []), + ("g3", GaussianNB(), [0, 1]), + ] + ) + clf2 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g3", GaussianNB(), [0, 1])] + ) + clf1.fit(X1, y1) + clf2.fit(X1, y1) + assert_array_almost_equal(clf1.predict_log_proba(X1), clf2.predict_log_proba(X1), 8) + # Empty-columns estimators are passed to estimators_ and the numbers match + assert len(clf1.estimators) == len(clf1.estimators_) == 3 + assert len(clf2.estimators) == len(clf2.estimators_) == 2 + # No cloning of the empty-columns estimators took place: + assert id(clf1.estimators[1][1]) == id(clf1.named_estimators_["g2"]) + + +def test_cwnb_estimators_unique_names(): + # Subestimators spec: error on repeated names + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g1", GaussianNB(), [0, 1])] + ) + msg = "Names provided are not unique" + with pytest.raises(ValueError, match=msg): + clf1.fit(X, y) + + clf1 = ColumnwiseNB( + estimators=[["g1", GaussianNB(), [1]], ["g2", GaussianNB(), [0, 1]]] + ) + clf1.fit(X, y) + + +def test_cwnb_estimators_nonempty_list(global_random_seed): + # Subestimators spec: error on empty list + X1, y1 = 
get_random_normal_x_binary_y(global_random_seed) + clf = ColumnwiseNB( + estimators=[], + ) + msg = "A list of naive Bayes estimators must be provided*" + with pytest.raises(ValueError, match=msg): + clf.fit(X1, y1) + + # Subestimators spec: error on None + clf = ColumnwiseNB( + estimators=None, + ) + msg = "The 'estimators' parameter of ColumnwiseNB must be an instance of 'list'*" + with pytest.raises(ValueError, match=msg): + clf.fit(X1, y1) + + # Subestimators spec: error on non-tuple + clf = ColumnwiseNB( + estimators=GaussianNB(), + ) + msg = "The 'estimators' parameter of ColumnwiseNB must be an instance of 'list'*" + with pytest.raises(ValueError, match=msg): + clf.fit(X1, y1) + + +def test_cwnb_estimators_support_jll(): + # Subestimators spec: error when some don't support predict_joint_log_proba + class notNB(BaseEstimator): + def __init__(self): + pass + + def fit(self, X, y): + pass + + def partial_fit(self, X, y): + pass + + # def predict_joint_log_proba(self, X): pass + def predict(self, X): + pass + + clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + msg = "Estimators must be .aive Bayes estimators implementing *" + with pytest.raises(TypeError, match=msg): + clf1.partial_fit(X, y) + + +def test_cwnb_estimators_support_fit(): + # Subestimators spec: error when some don't support fit + class notNB(BaseEstimator): + def __init__(self): + pass + + # def fit(self, X, y): pass + def partial_fit(self, X, y): + pass + + def predict_joint_log_proba(self, X): + pass + + def predict(self, X): + pass + + clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + msg = "Estimators must be .aive Bayes estimators implementing *" + with pytest.raises(TypeError, match=msg): + clf1.fit(X, y) + + delattr(notNB, "predict_joint_log_proba") + clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + msg = "Estimators must be .aive Bayes estimators implementing *" + with pytest.raises(TypeError, match=msg): + clf1.fit(X, y) + + +def test_cwnb_estimators_support_partial_fit(): + # Subestimators spec: error when some don't support partial_fit + class notNB(BaseEstimator): + def __init__(self): + pass + + def fit(self, X, y): + pass + + # def partial_fit(self, X, y): pass + def predict_joint_log_proba(self, X): + pass + + def predict(self, X): + pass + + clf1 = ColumnwiseNB(estimators=[["g1", notNB(), [1]], ["g2", GaussianNB(), [0]]]) + msg = "This 'ColumnwiseNB' has no attribute 'partial_fit'*" + with pytest.raises(AttributeError, match=msg): + clf1.partial_fit(X, y) + + +def test_cwnb_estimators_setter(global_random_seed): + # _estimators setter works + X1, y1 = get_random_normal_x_binary_y(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [0]), ("b1", BernoulliNB(), [1])] + ) + clf1.fit(X1, y1) + clf1._estimators = [ + ("x1", clf1.named_estimators_["g1"]), + ("x2", clf1.named_estimators_["g1"]), + ] + assert clf1.estimators[0][0] == "x1" + assert clf1.estimators[0][1] is clf1.named_estimators_["g1"] + assert clf1.estimators[1][0] == "x2" + assert clf1.estimators[1][1] is clf1.named_estimators_["g1"] + + +def test_cwnb_prior_valid_spec(): + # prior spec: error when negative, sum!=1 or bad length + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + priors=np.array([-0.25, 1.25]), + ) + msg = "Priors must be non-negative." 
+ with pytest.raises(ValueError, match=msg): + clf1.fit(X, y) + + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + priors=np.array([0.25, 0.7]), + ) + msg = "The sum of the priors should be 1." + with pytest.raises(ValueError, match=msg): + clf1.fit(X, y) + + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])], + priors=np.array([0.25, 0.25, 0.25, 0.25]), + ) + msg = "Number of priors must match number of classes." + with pytest.raises(ValueError, match=msg): + clf1.fit(X, y) + + +def test_cwnb_prior_match(global_random_seed): + # prior spec: all these ways work (and agree in our example) + # (1) an array of values + # (2a) a str name of a subestimator supporting class_prior_ + # (2b) a str name of a subestimator supporting class_log_prior_ + # (3) nothing (ColumnwiseNB will calculate relative frequencies) + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [0, 1]), + ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), + ], + priors=np.array([1 / 3, 1 / 3, 1 / 3]), # prior is provided by user + ) + clf2a = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [0, 1]), + ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), + ], + priors="g1", # prior will be estimated by sub-estimator "g1" + ) + clf2b = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [0, 1]), + ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), + ], + priors="m1", # prior will be estimated by sub-estimator "m1" + ) + clf3 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [0, 1]), + ("m1", MultinomialNB(), [2, 3, 4, 5, 6]), + ], # prior will be estimated by the meta-estimator + ) + clf1.fit(X2, y2) + clf2a.fit(X2, y2) + clf2b.fit(X2, y2) + clf3.fit(X2, y2) + assert clf3.priors is None + assert_array_almost_equal( + clf1.class_prior_, clf1.named_estimators_["g1"].class_prior_, 8 + ) + assert_array_almost_equal( + np.log(clf1.class_prior_), clf1.named_estimators_["m1"].class_log_prior_, 8 + ) + assert_array_almost_equal(clf1.class_prior_, clf2a.class_prior_, 8) + assert_array_almost_equal(clf1.class_prior_, clf2b.class_prior_, 8) + assert_array_almost_equal(clf1.class_prior_, clf3.class_prior_, 8) + + +def test_cwnb_estimators_support_class_prior_gnb(): + # prior spec: error message when can't extract prior from subestimator + # ColumnwiseNB tries both class_prior_ and class_log_prior, which is tested + # in test_cwnb_prior_match() + class GaussianNB_hide_prior(GaussianNB): + def fit(self, X, y, sample_weight=None): + super().fit(X, y, sample_weight=None) + self.qwerqwer = self.class_prior_ + del self.class_prior_ + + clf = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [1]), + ("g2", GaussianNB_hide_prior(), [0, 1]), + ], + priors="g2", + ) + msg = "Unable to extract class prior from estimator g2*" + with pytest.raises(AttributeError, match=msg): + clf.fit(X, y) + + +def test_cwnb_estimators_support_class_prior_mnb(global_random_seed): + # prior spec: error message when can't extract prior from subestimator + # ColumnwiseNB tries both class_prior_ and class_log_prior, which is tested + # in test_cwnb_prior_match() + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + + class MultinomialNB_hide_log_prior(MultinomialNB): + def fit(self, X, y, sample_weight=None): + super().fit(X, y, sample_weight=None) + self.qwerqwer = self.class_log_prior_ + del self.class_log_prior_ + + clf = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [0]), + ("m1", 
MultinomialNB_hide_log_prior(), [1, 2, 3, 4, 5]), + ], + priors="m1", + ) + msg = "Unable to extract class prior from estimator m1*" + with pytest.raises(AttributeError, match=msg): + clf.fit(X2, y2) + + +def test_cwnb_prior_nonzero(global_random_seed): + # P(y)=0 in one or two subestimators results in P(y|x)=0 of meta-estimator. + # Despite attempted Log[0], predicted class probabilities are all finite. + # On a related note, meaningless results (including NaNs) may be produced + # - if P(y)=0 in the meta-estimator, or/and + # - if class priors differ across subestimators, + # but this is not what is tested here. + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + rng = np.random.RandomState(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(), [1, 3, 5]), + ("g2", GaussianNB(priors=np.array([0.5, 0, 0.5])), [0, 1]), + ] + ) + clf1.fit(X2, y2) + msg = "divide by zero encountered in log" + with pytest.warns(RuntimeWarning, match=msg): + p = clf1.predict_proba(X2)[:, 1] + assert_almost_equal(np.abs(p).sum(), 0) + assert np.isfinite(p).all() + Xt = rng.randint(5, size=(6, 100)) + with pytest.warns(RuntimeWarning, match=msg): + p = clf1.predict_proba(Xt)[:, 1] + assert_almost_equal(np.abs(p).sum(), 0) + assert np.isfinite(p).all() + + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(priors=np.array([0.6, 0, 0.4])), [1, 3, 5]), + ("g2", GaussianNB(priors=np.array([0.5, 0.5, 0])), [0, 1]), + ] + ) + clf1.fit(X2, y2) + with pytest.warns(RuntimeWarning, match=msg): + p = clf1.predict_proba(X2)[:, 1:] + assert_almost_equal(np.abs(p).sum(), 0) + assert np.isfinite(p).all() + Xt = rng.randint(5, size=(6, 100)) + with pytest.warns(RuntimeWarning, match=msg): + p = clf1.predict_proba(Xt)[:, 1:] + assert_almost_equal(np.abs(p).sum(), 0) + assert np.isfinite(p).all() + + +def test_cwnb_fit_sample_weight_ones(): + # weights in fit have no effect if all ones + weights = [1, 1, 1, 1, 1, 1] + clf1 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + ) + clf2 = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + ) + clf1.fit(X, y, sample_weight=weights) + clf2.fit(X, y) + assert_array_almost_equal( + clf1.predict_joint_log_proba(X), clf2.predict_joint_log_proba(X), 8 + ) + assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) + assert_array_equal(clf1.predict(X), clf2.predict(X)) + + +def test_cwnb_partial_fit_sample_weight_ones(global_random_seed): + # weights in partial_fit have no effect if all ones + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + weights = [1, 1, 1, 1, 1, 1] + clf1 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ] + ) + clf2 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ] + ) + clf1.partial_fit(X2, y2, sample_weight=weights, classes=np.unique(y2)) + clf2.partial_fit(X2, y2, classes=np.unique(y2)) + assert_array_almost_equal( + clf1.predict_joint_log_proba(X2), clf2.predict_joint_log_proba(X2), 8 + ) + assert_array_almost_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2), 8) + assert_array_equal(clf1.predict(X2), clf2.predict(X2)) + + +def test_cwnb_fit_sample_weight_repeated(): + # weights in fit have the same effect as repeating data + weights = [1, 2, 3, 1, 4, 2] + idx = list(chain(*([i] * w for i, w in enumerate(weights)))) + # var_smoothing=0.0 is for maximum precision in 
dealing with a small sample + clf1 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(var_smoothing=0.0), [1]), + ("g2", GaussianNB(var_smoothing=0.0), [0, 1]), + ] + ) + clf2 = ColumnwiseNB( + estimators=[ + ("g1", GaussianNB(var_smoothing=0.0), [1]), + ("g2", GaussianNB(var_smoothing=0.0), [0, 1]), + ] + ) + clf1.fit(X, y, sample_weight=weights) + clf2.fit(X[idx], y[idx]) + assert_array_almost_equal( + clf1.predict_joint_log_proba(X), clf2.predict_joint_log_proba(X), 8 + ) + assert_array_almost_equal(clf1.predict_log_proba(X), clf2.predict_log_proba(X), 8) + assert_array_equal(clf1.predict(X), clf2.predict(X), 8) + for attr_name in ("class_count_", "class_prior_", "classes_"): + assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) + + +def test_cwnb_partial_fit_sample_weight_repeated(global_random_seed): + # weights in partial_fit have the same effect as repeating data + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + weights = [1, 2, 3, 1, 4, 2] + idx = list(chain(*([i] * w for i, w in enumerate(weights)))) + clf1 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ] + ) + clf2 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ] + ) + clf1.partial_fit(X2, y2, sample_weight=weights, classes=np.unique(y2)) + clf2.partial_fit(X2[idx], y2[idx], classes=np.unique(y2)) + assert_array_equal( + clf1.predict_joint_log_proba(X2), clf2.predict_joint_log_proba(X2) + ) + assert_array_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2)) + assert_array_equal(clf1.predict(X2), clf2.predict(X2)) + for attr_name in ("class_count_", "class_prior_", "classes_"): + assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) + + +def test_cwnb_partial_fit(global_random_seed): + # partial_fit: consecutive calls yield the same prediction as a single call + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + ) + clf2 = ColumnwiseNB( + estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + ) + clf1.partial_fit(X2, y2, classes=np.unique(y2)) + clf2.partial_fit(X2[:4], y2[:4], classes=np.unique(y2)) + clf2.partial_fit(X2[4:], y2[4:]) + assert_array_almost_equal( + clf1.predict_joint_log_proba(X2), clf2.predict_joint_log_proba(X2), 8 + ) + assert_array_almost_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2), 8) + assert_array_equal(clf1.predict(X2), clf2.predict(X2)) + for attr_name in ("class_count_", "class_prior_", "classes_"): + assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) + + +def test_cwnb_fit_refits(global_random_seed): + # fit: re-fits the estimator de novo when called on a fitted estimator + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + ) + clf2 = ColumnwiseNB( + estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + ) + clf1.fit(X2, y2) + clf2.partial_fit(X2[:4], y2[:4], classes=np.unique(y2)) + clf2.fit(X2, y2) + assert_array_almost_equal( + clf1.predict_joint_log_proba(X2), clf2.predict_joint_log_proba(X2), 8 + ) + assert_array_almost_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2), 8) + assert_array_equal(clf1.predict(X2), clf2.predict(X2)) + for attr_name in ("class_count_", "class_prior_", 
"classes_"): + assert_array_equal(getattr(clf1, attr_name), getattr(clf1, attr_name)) + + +def test_cwnb_partial_fit_classes(global_random_seed): + # partial_fit: error when classes are not provided at the first call + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[("b1", BernoulliNB(), [1]), ("m1", MultinomialNB(), [0, 2, 3])] + ) + msg = ".lasses must be passed on the first call to partial_fit" + with pytest.raises(ValueError, match=msg): + clf1.partial_fit(X2, y2) + + +def test_cwnb_class_attributes_consistency(global_random_seed): + # class_count_, classes_, class_prior_ are consistent in meta-, sub-estimators + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ] + ) + clf1.fit(X2, y2) + for se in clf1.named_estimators_: + assert_array_almost_equal( + clf1.class_count_, clf1.named_estimators_[se].class_count_, 8 + ) + assert_array_almost_equal(clf1.classes_, clf1.named_estimators_[se].classes_, 8) + assert_array_almost_equal( + np.log(clf1.class_prior_), clf1.named_estimators_[se].class_log_prior_, 8 + ) + + +def test_cwnb_params(global_random_seed): + # Can get and set subestimators' parameters through name__paramname + # clone() works on ColumnwiseNB + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(alpha=0.2, binarize=2), [1]), + ("m1", MultinomialNB(class_prior=[0.2, 0.2, 0.6]), [0, 2, 3]), + ] + ) + clf1.fit(X2, y2) + p = clf1.get_params(deep=True) + assert p["b1__alpha"] == 0.2 + assert p["b1__binarize"] == 2 + assert p["m1__class_prior"] == [0.2, 0.2, 0.6] + clf1.set_params(b1__alpha=123, m1__class_prior=[0.3, 0.3, 0.4]) + assert clf1.estimators[0][1].alpha == 123 + assert_array_equal(clf1.estimators[1][1].class_prior, [0.3, 0.3, 0.4]) + # After cloning and fitting, we can check through named_estimators, which + # maps to fitted estimators_: + clf2 = clone(clf1).fit(X2, y2) + assert clf2.named_estimators_["b1"].alpha == 123 + assert_array_equal(clf2.named_estimators_["m1"].class_prior, [0.3, 0.3, 0.4]) + assert id(clf2.named_estimators_["b1"]) != id(clf1.named_estimators_["b1"]) + + +def test_cwnb_n_jobs(global_random_seed): + # n_jobs: same result whether with it or without + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + clf1 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("b2", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ("m3", MultinomialNB(), slice(10, None)), + ], + n_jobs=4, + ) + clf2 = ColumnwiseNB( + estimators=[ + ("b1", BernoulliNB(binarize=2), [1]), + ("b2", BernoulliNB(binarize=2), [1]), + ("m1", MultinomialNB(), [0, 2, 3]), + ("m3", MultinomialNB(), slice(10, None)), + ] + ) + clf1.partial_fit(X2, y2, classes=np.unique(y2)) + clf2.partial_fit(X2, y2, classes=np.unique(y2)) + + assert_array_almost_equal( + clf1.predict_joint_log_proba(X2), clf2.predict_joint_log_proba(X2), 8 + ) + assert_array_almost_equal(clf1.predict_log_proba(X2), clf2.predict_log_proba(X2), 8) + assert_array_equal(clf1.predict(X2), clf2.predict(X2)) + + +def test_cwnb_example(): + # Test the Example from ColumnwiseNB docstring in naive_bayes.py + rng = np.random.RandomState(1) + X = rng.randint(5, size=(6, 100)) + y = np.array([0, 0, 1, 1, 2, 2]) + + clf = ColumnwiseNB( + estimators=[ + ("mnb1", MultinomialNB(), [0, 1]), + ("mnb2", MultinomialNB(), [3, 4]), + 
("gnb1", GaussianNB(), [5]), + ] + ) + clf.fit(X, y) + clf.predict(X) + + +def test_cwnb_verbose(capsys, global_random_seed): + # Setting verbose=True does not result in an error. + # This DOES NOT test if the desired output is generated. + X2, y2 = get_random_integer_x_three_classes_y(global_random_seed) + clf = ColumnwiseNB( + estimators=[ + ("mnb1", MultinomialNB(), [0, 1]), + ("mnb2", MultinomialNB(), [3, 4]), + ("gnb1", GaussianNB(), [5]), + ], + verbose=True, + n_jobs=4, + ) + clf.fit(X2, y2) + clf.predict(X2) + + +def test_cwnb_sk_visual_block(capsys): + # visual block representation correctly extracts names, cols and estimators + estimators = (MultinomialNB(), MultinomialNB(), GaussianNB()) + clf = ColumnwiseNB( + estimators=[ + ("mnb1", estimators[0], [0, 1]), + ("mnb2", estimators[1], [3, 4]), + ("gnb1", estimators[2], [5]), + ], + ) + visual_block = clf._sk_visual_block_() + assert visual_block.names == ("mnb1", "mnb2", "gnb1") + assert visual_block.name_details == ([0, 1], [3, 4], [5]) + assert visual_block.estimators == estimators + + +def test_cwnb_check_param_validation(): + # This test replaces test_common.py::test_check_param_validation and is + # needed because utils.estimator_checks._construct_instance() is unable to + # create an instance of ColumnwiseNB (also of some other estimators, such as + # ColumnTransformer and Pipeline). + clf = ColumnwiseNB( + estimators=[("g1", GaussianNB(), [1]), ("g2", GaussianNB(), [0, 1])] + ) + check_param_validation("ColumnwiseNB", clf) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 43b53f5101dce..afae45598ac31 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -40,6 +40,7 @@ from ..metrics.pairwise import linear_kernel, pairwise_distances, rbf_kernel from ..model_selection import ShuffleSplit, train_test_split from ..model_selection._validation import _safe_split +from ..naive_bayes import ColumnwiseNB, GaussianNB, _select_half from ..pipeline import make_pipeline from ..preprocessing import StandardScaler, scale from ..random_projection import BaseRandomProjection @@ -60,10 +61,7 @@ from ..utils.validation import check_is_fitted from . import IS_PYPY, is_scalar_nan, shuffle from ._param_validation import Interval -from ._tags import ( - _DEFAULT_TAGS, - _safe_tags, -) +from ._tags import _DEFAULT_TAGS, _safe_tags from ._testing import ( SkipTest, _array_api_for_tests, @@ -427,8 +425,20 @@ def _construct_instance(Estimator): else: estimator = Estimator(LogisticRegression(C=1)) elif required_parameters in (["estimators"],): + if issubclass(Estimator, ColumnwiseNB): + # ColumnwiseNB (naive Bayes meta-classifier) + estimator = Estimator( + estimators=[ + ( + "gnb1", + GaussianNB(var_smoothing=1e-13), + _select_half("first"), + ), + ("gnb2", GaussianNB(), _select_half("second")), + ] + ) # Heterogeneous ensemble classes (i.e. stacking, voting) - if issubclass(Estimator, RegressorMixin): + elif issubclass(Estimator, RegressorMixin): estimator = Estimator( estimators=[("est1", Ridge(alpha=0.1)), ("est2", Ridge(alpha=1))] )