Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,9 @@ Changelog
messages when optimizers produce non-finite parameter weights. :pr:`22150`
by :user:`Christian Ritter <chritter>` and :user:`Norbert Preining <norbusan>`.

- |Enhancement| Adds :term:`get_feature_names_out` to
:class:`neural_network.BernoulliRBM`. :pr:`22248` by `Thomas Fan`_.

:mod:`sklearn.pipeline`
.......................

Expand Down
5 changes: 4 additions & 1 deletion sklearn/neural_network/_rbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@

from ..base import BaseEstimator
from ..base import TransformerMixin
from ..base import _ClassNamePrefixFeaturesOutMixin
from ..utils import check_random_state
from ..utils import gen_even_slices
from ..utils.extmath import safe_sparse_dot
from ..utils.extmath import log_logistic
from ..utils.validation import check_is_fitted


class BernoulliRBM(TransformerMixin, BaseEstimator):
class BernoulliRBM(_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
"""Bernoulli Restricted Boltzmann Machine (RBM).

A Restricted Boltzmann Machine with binary visible units and
Expand Down Expand Up @@ -284,6 +285,7 @@ def partial_fit(self, X, y=None):
self.random_state_.normal(0, 0.01, (self.n_components, X.shape[1])),
order="F",
)
self._n_features_out = self.components_.shape[0]
if not hasattr(self, "intercept_hidden_"):
self.intercept_hidden_ = np.zeros(
self.n_components,
Expand Down Expand Up @@ -389,6 +391,7 @@ def fit(self, X, y=None):
order="F",
dtype=X.dtype,
)
self._n_features_out = self.components_.shape[0]
self.intercept_hidden_ = np.zeros(self.n_components, dtype=X.dtype)
self.intercept_visible_ = np.zeros(X.shape[1], dtype=X.dtype)
self.h_samples_ = np.zeros((self.batch_size, self.n_components), dtype=X.dtype)
Expand Down
12 changes: 12 additions & 0 deletions sklearn/neural_network/tests/test_rbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,15 @@ def test_convergence_dtype_consistency():
)
assert_allclose(rbm_64.components_, rbm_32.components_, rtol=1e-03, atol=0)
assert_allclose(rbm_64.h_samples_, rbm_32.h_samples_)


@pytest.mark.parametrize("method", ["fit", "partial_fit"])
def test_feature_names_out(method):
"""Check `get_feature_names_out` for `BernoulliRBM`."""
n_components = 10
rbm = BernoulliRBM(n_components=n_components)
getattr(rbm, method)(Xdigits)

names = rbm.get_feature_names_out()
expected_names = [f"bernoullirbm{i}" for i in range(n_components)]
assert_array_equal(expected_names, names)
1 change: 0 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ def test_pandas_column_name_consistency(estimator):
"ensemble",
"kernel_approximation",
"preprocessing",
"neural_network",
]


Expand Down