Skip to content

[WIP] New __repr__ and/or pretty printing of estimators #7618

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 85 additions & 4 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import copy
import warnings
from uuid import uuid4

import numpy as np
from scipy import sparse
Expand Down Expand Up @@ -121,7 +122,7 @@ def clone(estimator, safe=True):


###############################################################################
def _pprint(params, offset=0, printer=repr):
def _pprint(params, offset=0, printer=repr, cutoff=500):
"""Pretty print the dictionary 'params'

Parameters
Expand Down Expand Up @@ -150,9 +151,9 @@ def _pprint(params, offset=0, printer=repr):
# architectures and versions.
this_repr = '%s=%s' % (k, str(v))
else:
# use repr of the rest
# use printer of the rest
this_repr = '%s=%s' % (k, printer(v))
if len(this_repr) > 500:
if cutoff is not None and len(this_repr) > cutoff:
this_repr = this_repr[:300] + '...' + this_repr[-100:]
if i > 0:
if (this_line_length + len(this_repr) >= 75 or '\n' in this_repr):
Expand All @@ -171,6 +172,16 @@ def _pprint(params, offset=0, printer=repr):
return lines


def _html_repr(thing):
if hasattr(thing, "_repr_html_"):
return thing._repr_html_()
elif isinstance(thing, tuple):
return "({})".format(", ". join([_html_repr(vv) for vv in thing]))
elif isinstance(thing, list):
return "[{}]".format(", ". join([_html_repr(vv) for vv in thing]))
return repr(thing)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want IPython to handle the fallback case, you can use IPython.display.HTML(thing).



###############################################################################
class BaseEstimator(object):
"""Base class for all estimators in scikit-learn
Expand Down Expand Up @@ -284,11 +295,61 @@ def set_params(self, **params):
setattr(self, key, value)
return self

def _changed_params(self):
params = self.get_params(deep=False)
filtered_params = {}
default_params = {}
init_params = signature(self.__init__).parameters
for k, v in params.items():
if v == init_params[k].default:
default_params[k] = v
else:
filtered_params[k] = v
return filtered_params, default_params

def __repr__(self):
    """Return ``ClassName(...)`` with the shallow parameters pretty-printed.

    The parameters come from ``get_params(deep=False)`` and are formatted
    by ``_pprint``, offset by the length of the class name.
    """
    class_name = self.__class__.__name__
    params = self.get_params(deep=False)
    return '%s(%s)' % (class_name,
                       _pprint(params, offset=len(class_name)))

def _repr_html_(self):
    """Return an HTML representation of the estimator.

    Parameters that differ from their ``__init__`` default are always
    shown. Parameters still at their default are placed in a ``<span>``
    that the emitted script hides on load and toggles when the ``...``
    link is clicked. The script uses jQuery-style ``$`` calls.

    Returns
    -------
    html : str
        The script followed by ``<b>ClassName</b>(...)``.
    """
    class_name = self.__class__.__name__
    changed_params, default_params = self._changed_params()
    # Fresh id per call, used to build the element ids the script targets.
    this_id = uuid4()
    js = """
<script type="text/javascript">

$(document).ready(function(){{
$("#default_params_{0}").hide();
$("#more_params_{0}").click(function(){{
$("#default_params_{0}").toggle();
}});
}});

</script>""".format(this_id)
    if default_params:
        more_params_str = (
            "<a id='more_params_{0}'>...</a>"
            "<span id='default_params_{0}'>, {1}</span>".format(
                this_id,
                _pprint(default_params, printer=_html_repr, cutoff=None)))
        if changed_params:
            # Separator between the visible parameters and the toggle link.
            # It is only added when there is something to separate, so an
            # estimator without default-valued parameters does not end up
            # with a dangling ", " before the closing parenthesis.
            more_params_str = ", " + more_params_str
    else:
        more_params_str = ""
    # TODO: optionally append fitted attributes (classes_, n_outputs_ and
    # the value of get_n_features()) once their display format is settled.
    return "{}<b>{}</b>({}{})".format(
        js, class_name,
        _pprint(changed_params, printer=_html_repr, cutoff=None),
        more_params_str)

def __getstate__(self):
if type(self).__module__.startswith('sklearn.'):
return dict(self.__dict__.items(), _sklearn_version=__version__)
Expand All @@ -307,6 +368,26 @@ def __setstate__(self, state):
UserWarning)
self.__dict__.update(state)

def get_n_features(self):
    """Return number of features of a fitted estimator.

    ``n_features_`` is returned as-is when set. Otherwise the count is read
    from the shape of the first available attribute among ``components_``,
    ``coef_`` and ``scale_``. ``None`` is returned when none of them is set.
    """
    explicit = getattr(self, 'n_features_', None)
    if explicit is not None:
        return explicit

    # (attribute name, index into ``shape`` that is returned), in lookup order.
    shape_sources = (('components_', 1), ('coef_', -1), ('scale_', 0))
    for attr_name, axis in shape_sources:
        fitted = getattr(self, attr_name, None)
        if fitted is not None:
            return fitted.shape[axis]

    return None


###############################################################################
class ClassifierMixin(object):
Expand Down
5 changes: 5 additions & 0 deletions sklearn/feature_selection/univariate_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,11 @@ def fit(self, X, y):

return self

def get_n_features(self):
    """Return the length of ``scores_``, or None when it is not set."""
    try:
        scores = self.scores_
    except AttributeError:
        return None
    return scores.shape[0]

def _check_params(self, X, y):
pass

Expand Down
19 changes: 10 additions & 9 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def __init__(self, estimator, scoring=None,
self.scoring = scoring
self.estimator = estimator
self.n_jobs = n_jobs
self.fit_params = fit_params if fit_params is not None else {}
self.fit_params = fit_params
self.iid = iid
self.refit = refit
self.cv = cv
Expand Down Expand Up @@ -550,6 +550,7 @@ def fit(self, X, y=None, groups=None):
Group labels for the samples used while splitting the dataset into
train/test set.
"""
fit_params = self.fit_params if self.fit_params is not None else {}
estimator = self.estimator
cv = check_cv(self.cv, y, classifier=is_classifier(estimator))
self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)
Expand All @@ -572,7 +573,7 @@ def fit(self, X, y=None, groups=None):
pre_dispatch=pre_dispatch
)(delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
train, test, self.verbose, parameters,
fit_params=self.fit_params,
fit_params=fit_params,
return_train_score=self.return_train_score,
return_n_test_samples=True,
return_times=True, return_parameters=False,
Expand Down Expand Up @@ -655,9 +656,9 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
best_estimator = clone(base_estimator).set_params(
**best_parameters)
if y is not None:
best_estimator.fit(X, y, **self.fit_params)
best_estimator.fit(X, y, **fit_params)
else:
best_estimator.fit(X, **self.fit_params)
best_estimator.fit(X, **fit_params)
self.best_estimator_ = best_estimator
return self

Expand Down Expand Up @@ -808,7 +809,7 @@ class GridSearchCV(BaseSearchCV):
kernel='rbf', max_iter=-1, probability=False,
random_state=None, shrinking=True, tol=...,
verbose=False),
fit_params={}, iid=..., n_jobs=1,
fit_params=None, iid=..., n_jobs=1,
param_grid=..., pre_dispatch=..., refit=..., return_train_score=...,
scoring=..., verbose=...)
>>> sorted(clf.cv_results_.keys())
Expand Down Expand Up @@ -1159,10 +1160,10 @@ def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
self.n_iter = n_iter
self.random_state = random_state
super(RandomizedSearchCV, self).__init__(
estimator=estimator, scoring=scoring, fit_params=fit_params,
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
pre_dispatch=pre_dispatch, error_score=error_score,
return_train_score=return_train_score)
estimator=estimator, scoring=scoring, fit_params=fit_params,
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
pre_dispatch=pre_dispatch, error_score=error_score,
return_train_score=return_train_score)

def _get_param_iterator(self):
"""Return ParameterSampler instance for the given distributions"""
Expand Down
5 changes: 5 additions & 0 deletions sklearn/neighbors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,11 @@ def _pairwise(self):
# For cross-validation routines to split data correctly
return self.metric == 'precomputed'

def get_n_features(self):
    """Return the second dimension of ``_fit_X``, or None when it is unset."""
    fit_X = getattr(self, '_fit_X', None)
    return None if fit_X is None else fit_X.shape[1]


class KNeighborsMixin(object):
"""Mixin for k-neighbors searches"""
Expand Down
3 changes: 2 additions & 1 deletion sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,7 +1403,8 @@ def check_no_fit_attributes_set_in_init(name, Estimator):
"""Check that Estimator.__init__ doesn't set trailing-_ attributes."""
estimator = Estimator()
for attr in dir(estimator):
if attr.endswith("_") and not attr.startswith("__"):
if (attr.endswith("_") and not attr.startswith("__") and not
hasattr(Estimator, attr)):
# This check is for properties, they can be listed in dir
# while at the same time have hasattr return False as long
# as the property getter raises an AttributeError
Expand Down