-
-
Notifications
You must be signed in to change notification settings - Fork 26k
[MRG] Add n_features_in_ attribute to BaseEstimator #13603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7d9dcc4
f117745
95b330c
e56592b
3bdcb5c
8ecc690
60e4cea
ff19f22
a44318b
42249fb
abdc94e
a50e76f
62fc42e
6845788
3246436
ee2598b
6a14e4b
3f2d44f
25fda0f
9bdfb65
be76ef4
b464f86
70dc4ed
988f9c4
fd9b72c
4f3d6ff
08f7192
5a41275
f0e7b41
193fda1
5b20a4c
a49e5ea
968fbff
e4faf13
908aea6
a88a4c5
f3fb539
4b7b758
53027d3
c5dfbbd
a1aea70
9ecc396
e9c3104
9292c84
e11b0bb
6846bea
60c5108
615140e
fe052e6
9a205dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ | |
|
||
from . import __version__ | ||
from .utils import _IS_32BIT | ||
from .utils.validation import check_X_y | ||
from .utils.validation import check_array | ||
|
||
_DEFAULT_TAGS = { | ||
'non_deterministic': False, | ||
|
@@ -323,6 +325,31 @@ def _get_tags(self): | |
collected_tags.update(more_tags) | ||
return collected_tags | ||
|
||
def _validate_n_features(self, X, check_n_features): | ||
if check_n_features: | ||
if not hasattr(self, 'n_features_in_'): | ||
raise RuntimeError( | ||
"check_n_features is True but there is no n_features_in_ " | ||
"attribute." | ||
) | ||
if X.shape[1] != self.n_features_in_: | ||
raise ValueError( | ||
'X has {} features, but this {} is expecting {} features ' | ||
'as input.'.format(X.shape[1], self.__class__.__name__, | ||
self.n_features_in_) | ||
) | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
self.n_features_in_ = X.shape[1] | ||
|
||
def _validate_X(self, X, check_n_features=False, **check_array_params): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does using Is there an issue with making all parameters explicit here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure I understand your concern. I can definitely make all the parameters explicit. The only downside is that we have to keep the signature synchronized with that of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That would only mean the user hasn't properly specified the minimum sklearn version requirement properly. I don't think that's something we should worry about. |
||
X = check_array(X, **check_array_params) | ||
self._validate_n_features(X, check_n_features) | ||
return X | ||
|
||
def _validate_X_y(self, X, y, check_n_features=False, **check_X_y_params): | ||
X, y = check_X_y(X, y, **check_X_y_params) | ||
self._validate_n_features(X, check_n_features) | ||
return X, y | ||
|
||
class ClassifierMixin: | ||
"""Mixin class for all classifiers in scikit-learn.""" | ||
|
Uh oh!
There was an error while loading. Please reload this page.