
[MRG] Add pprint for estimators - continued #11705


Merged
merged 52 commits on Dec 20, 2018

Conversation

NicolasHug
Member

@NicolasHug NicolasHug commented Jul 29, 2018

Reference Issues/PRs

This is the continuation of #9099, which aims to give estimators better printing representations.

Close #9099
Close #9039
Close #7618
Fix #6323

What does this implement/fix? Explain your changes.

  • Fixed some line length issues
  • Factored code for Estimator and Pipeline into a single method
  • Completely refactored the code to extend the built-in PrettyPrinter class
  • Added some tests
  • Added some doc

Any other comments?

I marked this as WIP but a first review may be useful. I tried to make the code easier to understand, but it's always a bit hard with this kind of functionality.

The line length limit is still not perfect; see e.g. the test with RFE(...) that goes beyond 80 chars because the printed estimator is itself an argument.

Questions

  • Should the file _pprint.py be in utils instead of at the root? (test file would follow)
  • Should I remove the _pprint() method in base.py and use the Formatter class in __repr__ instead?

@jnothman
Member

jnothman commented Jul 31, 2018 via email

@amueller
Member

Can you add a test with a pipeline inside a grid search, please? And the pipeline maybe has a ColumnTransformer? ;)

@jnothman
Member

jnothman commented Sep 4, 2018 via email

@NicolasHug
Member Author

NicolasHug commented Sep 11, 2018

So due to some limitations of the current method, I completely changed the implementation and relied on pprint.PrettyPrinter instead.

Here are some output examples:

LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, max_iter=100, multi_class='ovr',
                   n_jobs=1, penalty='l2', random_state=None,
                   solver='liblinear', tol=0.0001, verbose=0, warm_start=False)

Pipeline(memory=None,
         steps=['reduce_dim': PCA(copy=True, iterated_power='auto',
                                  n_components=None, random_state=None,
                                  svd_solver='auto', tol=0.0, whiten=False),
                'classify': LinearSVC(C=1.0, class_weight=None, dual=True,
                                      fit_intercept=True, intercept_scaling=1,
                                      loss='squared_hinge', max_iter=1000,
                                      multi_class='ovr', penalty='l2',
                                      random_state=None, tol=0.0001,
                                      verbose=0)])

RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=LogisticRegression(C=1.0,
                                                                                                                     class_weight=None,
                                                                                                                     dual=False,
                                                                                                                     fit_intercept=True,
                                                                                                                     intercept_scaling=1,
                                                                                                                     max_iter=100,
                                                                                                                     multi_class='ovr',
                                                                                                                     n_jobs=1,
                                                                                                                     penalty='l2',
                                                                                                                     random_state=None,
                                                                                                                     solver='liblinear',
                                                                                                                     tol=0.0001,
                                                                                                                     verbose=0,
                                                                                                                     warm_start=False),
                                                                                        n_features_to_select=None,
                                                                                        step=1,
                                                                                        verbose=0),
                                                                          n_features_to_select=None,
                                                                          step=1,
                                                                          verbose=0),
                                                            n_features_to_select=None,
                                                            step=1, verbose=0),
                                              n_features_to_select=None, step=1,
                                              verbose=0),
                                n_features_to_select=None, step=1, verbose=0),
                  n_features_to_select=None, step=1, verbose=0),
    n_features_to_select=None, step=1, verbose=0)

GridSearchCV(cv=5, error_score='raise-deprecating',
             estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
                           decision_function_shape='ovr', degree=3,
                           gamma='auto_deprecated', kernel='rbf', max_iter=-1,
                           probability=False, random_state=None, shrinking=True,
                           tol=0.001, verbose=False),
             fit_params=None, iid='warn', n_jobs=1,
             param_grid=[{'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001],
                          'kernel': ['rbf']},
                         {'C': [1, 10, 100, 1000], 'kernel': ['linear']}],
             pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
             scoring=None, verbose=0)

GridSearchCV(cv=3, error_score='raise-deprecating',
             estimator=Pipeline(memory=None,
                                steps=['reduce_dim': PCA(copy=True,
                                                         iterated_power='auto',
                                                         n_components=None,
                                                         random_state=None,
                                                         svd_solver='auto',
                                                         tol=0.0,
                                                         whiten=False),
                                       'classify': LinearSVC(C=1.0,
                                                             class_weight=None,
                                                             dual=True,
                                                             fit_intercept=True,
                                                             intercept_scaling=1,
                                                             loss='squared_hinge',
                                                             max_iter=1000,
                                                             multi_class='ovr',
                                                             penalty='l2',
                                                             random_state=None,
                                                             tol=0.0001,
                                                             verbose=0)]),
             fit_params=None, iid='warn', n_jobs=1,
             param_grid=[{'classify__C': [1, 10, 100, 1000],
                          'reduce_dim': [PCA(copy=True, iterated_power=7,
                                             n_components=None,
                                             random_state=None,
                                             svd_solver='auto', tol=0.0,
                                             whiten=False),
                                         NMF(alpha=0.0, beta_loss='frobenius',
                                             init=None, l1_ratio=0.0,
                                             max_iter=200, n_components=None,
                                             random_state=None, shuffle=False,
                                             solver='cd', tol=0.0001,
                                             verbose=0)],
                          'reduce_dim__n_components': [2, 4, 8]},
                         {'classify__C': [1, 10, 100, 1000],
                          'reduce_dim': [SelectKBest(k=10,
                                                     score_func=<function chi2 at 0x7f77d68c18c8>)],
                          'reduce_dim__k': [2, 4, 8]}],
             pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
             scoring=None, verbose=0)

The default is to indent after the name of the estimator, but it's expensive in line width. Unfortunately, PrettyPrinter does not really support other forms of indentation:

The amount of indentation added for each recursive level is specified by indent; the default is one. Other values can cause output to look a little odd

From the docs. Example of discrepancy:

from pprint import PrettyPrinter
pp = PrettyPrinter(indent=4, compact=True)
pp.pprint(list(range(30)))

will give

[   0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
    21, 22, 23, 24, 25, 26, 27, 28, 29]

Here is what the output looks like on the same examples when indenting with 4 characters:

LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
    intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
    penalty='l2', random_state=None, solver='liblinear', tol=0.0001, verbose=0,
    warm_start=False)

Pipeline(memory=None,
    steps=[   'reduce_dim': PCA(copy=True, iterated_power='auto',
                                n_components=None, random_state=None,
                                svd_solver='auto', tol=0.0, whiten=False),
              'classify': LinearSVC(C=1.0, class_weight=None, dual=True,
                              fit_intercept=True, intercept_scaling=1,
                              loss='squared_hinge', max_iter=1000,
                              multi_class='ovr', penalty='l2',
                              random_state=None, tol=0.0001, verbose=0)])

RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=LogisticRegression(C=1.0,
                                                                                                      class_weight=None,
                                                                                                      dual=False,
                                                                                                      fit_intercept=True,
                                                                                                      intercept_scaling=1,
                                                                                                      max_iter=100,
                                                                                                      multi_class='ovr',
                                                                                                      n_jobs=1,
                                                                                                      penalty='l2',
                                                                                                      random_state=None,
                                                                                                      solver='liblinear',
                                                                                                      tol=0.0001,
                                                                                                      verbose=0,
                                                                                                      warm_start=False),
                                                                                        n_features_to_select=None,
                                                                                        step=1,
                                                                                        verbose=0),
                                                                          n_features_to_select=None,
                                                                          step=1,
                                                                          verbose=0),
                                                            n_features_to_select=None,
                                                            step=1, verbose=0),
                                              n_features_to_select=None, step=1,
                                              verbose=0),
                                n_features_to_select=None, step=1, verbose=0),
                  n_features_to_select=None, step=1, verbose=0),
    n_features_to_select=None, step=1, verbose=0)

GridSearchCV(cv=5, error_score='raise-deprecating',
    estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
                  decision_function_shape='ovr', degree=3,
                  gamma='auto_deprecated', kernel='rbf', max_iter=-1,
                  probability=False, random_state=None, shrinking=True,
                  tol=0.001, verbose=False),
    fit_params=None, iid='warn', n_jobs=1,
    param_grid=[   {   'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001],
                       'kernel': ['rbf']},
                   {'C': [1, 10, 100, 1000], 'kernel': ['linear']}],
    pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
    scoring=None, verbose=0)

GridSearchCV(cv=3, error_score='raise-deprecating',
    estimator=Pipeline(memory=None,
                  steps=[   'reduce_dim': PCA(copy=True, iterated_power='auto',
                                              n_components=None,
                                              random_state=None,
                                              svd_solver='auto', tol=0.0,
                                              whiten=False),
                            'classify': LinearSVC(C=1.0, class_weight=None, dual=True,
                                            fit_intercept=True,
                                            intercept_scaling=1,
                                            loss='squared_hinge', max_iter=1000,
                                            multi_class='ovr', penalty='l2',
                                            random_state=None, tol=0.0001,
                                            verbose=0)]),
    fit_params=None, iid='warn', n_jobs=1,
    param_grid=[   {   'classify__C': [1, 10, 100, 1000],
                       'reduce_dim': [   PCA(copy=True, iterated_power=7,
                                             n_components=None,
                                             random_state=None,
                                             svd_solver='auto', tol=0.0,
                                             whiten=False),
                                         NMF(alpha=0.0, beta_loss='frobenius',
                                             init=None, l1_ratio=0.0,
                                             max_iter=200, n_components=None,
                                             random_state=None, shuffle=False,
                                             solver='cd', tol=0.0001,
                                             verbose=0)],
                       'reduce_dim__n_components': [2, 4, 8]},
                   {   'classify__C': [1, 10, 100, 1000],
                       'reduce_dim': [   SelectKBest(k=10,
                                             score_func=<function chi2 at 0x7f4a124bd8c8>)],
                       'reduce_dim__k': [2, 4, 8]}],
    pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
    scoring=None, verbose=0)

I personally like the first output. It might render quite long strings when estimators are nested, but I don't think it's a major issue, and it's very easy to read.

ping @amueller

@jnothman
Member

Yes, I much prefer the first. Good idea to reuse PrettyPrinter. Will that be configurable for those of us wanting to hide default parameter settings, etc.?

@NicolasHug
Member Author

Everything is encapsulated in BaseEstimator.__repr__() so we can hide the default parameters, if that's what you mean.
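A rough sketch of the idea behind a changed-only repr, assuming the hypothetical helper and toy class below (the real logic lives behind BaseEstimator.__repr__ and the get_params machinery): read the defaults off the `__init__` signature and keep only the parameters that differ.

```python
import inspect

def changed_params(estimator):
    """Return only the parameters whose values differ from the
    __init__ defaults (hypothetical helper sketching the idea)."""
    sig = inspect.signature(type(estimator).__init__)
    changed = {}
    for name, param in sig.parameters.items():
        if name == 'self' or param.default is inspect.Parameter.empty:
            continue
        if getattr(estimator, name) != param.default:
            changed[name] = getattr(estimator, name)
    return changed

class ToyEstimator:
    def __init__(self, C=1.0, tol=1e-4):
        self.C = C
        self.tol = tol

changed_params(ToyEstimator(C=10.0))  # only C differs from its default
```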

I noticed that the steps parameter is formatted as a list of 'key': value pairs instead of a list of tuples. Is it OK like this or do we want to change it? I guess it's slightly clearer this way, but it prevents the string from being evaluated with eval, for example.

@jnothman
Member

jnothman commented Sep 13, 2018 via email

@NicolasHug
Member Author

I fixed it, here is the new look:

LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, max_iter=100, multi_class='ovr',
                   n_jobs=1, penalty='l2', random_state=None,
                   solver='liblinear', tol=0.0001, verbose=0, warm_start=False)

Pipeline(memory=None,
         steps=[('reduce_dim',
                 PCA(copy=True, iterated_power='auto', n_components=None,
                     random_state=None, svd_solver='auto', tol=0.0,
                     whiten=False)),
                ('classify',
                 LinearSVC(C=1.0, class_weight=None, dual=True,
                           fit_intercept=True, intercept_scaling=1,
                           loss='squared_hinge', max_iter=1000,
                           multi_class='ovr', penalty='l2', random_state=None,
                           tol=0.0001, verbose=0))])

RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=RFE(estimator=LogisticRegression(C=1.0,
                                                                                                                     class_weight=None,
                                                                                                                     dual=False,
                                                                                                                     fit_intercept=True,
                                                                                                                     intercept_scaling=1,
                                                                                                                     max_iter=100,
                                                                                                                     multi_class='ovr',
                                                                                                                     n_jobs=1,
                                                                                                                     penalty='l2',
                                                                                                                     random_state=None,
                                                                                                                     solver='liblinear',
                                                                                                                     tol=0.0001,
                                                                                                                     verbose=0,
                                                                                                                     warm_start=False),
                                                                                        n_features_to_select=None,
                                                                                        step=1,
                                                                                        verbose=0),
                                                                          n_features_to_select=None,
                                                                          step=1,
                                                                          verbose=0),
                                                            n_features_to_select=None,
                                                            step=1, verbose=0),
                                              n_features_to_select=None, step=1,
                                              verbose=0),
                                n_features_to_select=None, step=1, verbose=0),
                  n_features_to_select=None, step=1, verbose=0),
    n_features_to_select=None, step=1, verbose=0)

GridSearchCV(cv=5, error_score='raise-deprecating',
             estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
                           decision_function_shape='ovr', degree=3,
                           gamma='auto_deprecated', kernel='rbf', max_iter=-1,
                           probability=False, random_state=None, shrinking=True,
                           tol=0.001, verbose=False),
             fit_params=None, iid='warn', n_jobs=1,
             param_grid=[{'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001],
                          'kernel': ['rbf']},
                         {'C': [1, 10, 100, 1000], 'kernel': ['linear']}],
             pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
             scoring=None, verbose=0)

GridSearchCV(cv=3, error_score='raise-deprecating',
             estimator=Pipeline(memory=None,
                                steps=[('reduce_dim',
                                        PCA(copy=True, iterated_power='auto',
                                            n_components=None,
                                            random_state=None,
                                            svd_solver='auto', tol=0.0,
                                            whiten=False)),
                                       ('classify',
                                        LinearSVC(C=1.0, class_weight=None,
                                                  dual=True, fit_intercept=True,
                                                  intercept_scaling=1,
                                                  loss='squared_hinge',
                                                  max_iter=1000,
                                                  multi_class='ovr',
                                                  penalty='l2',
                                                  random_state=None, tol=0.0001,
                                                  verbose=0))]),
             fit_params=None, iid='warn', n_jobs=1,
             param_grid=[{'classify__C': [1, 10, 100, 1000],
                          'reduce_dim': [PCA(copy=True, iterated_power=7,
                                             n_components=None,
                                             random_state=None,
                                             svd_solver='auto', tol=0.0,
                                             whiten=False),
                                         NMF(alpha=0.0, beta_loss='frobenius',
                                             init=None, l1_ratio=0.0,
                                             max_iter=200, n_components=None,
                                             random_state=None, shuffle=False,
                                             solver='cd', tol=0.0001,
                                             verbose=0)],
                          'reduce_dim__n_components': [2, 4, 8]},
                         {'classify__C': [1, 10, 100, 1000],
                          'reduce_dim': [SelectKBest(k=10,
                                                     score_func=<function chi2 at 0x7f046321a950>)],
                          'reduce_dim__k': [2, 4, 8]}],
             pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
             scoring=None, verbose=0)

I'll put this as MRG for reviews, and if it gets approved I'll start updating the docs with the new formatting to fix the tests.

Member

@jnothman jnothman left a comment

I've realised the change is a bit weak on documentation. Not sure where to put it in the user guide, though. Or whether print_changed_only should be illustrated in the example gallery.

Otherwise, I'm keen to see this road-tested.

@@ -191,6 +191,10 @@ Support for Python 3.4 and below has been officially dropped.
Multiple modules
................

- The `__repr__()` method of all estimators (used when calling
`print(estimator)`) has been entirely re-written. :issue:`11705` by
:user:`Nicolas Hug <NicolasHug>`.
Member

print_changed_only needs to be mentioned in what's new. I think it should also be mentioned somewhere in the documentation more explicitly.

@@ -191,6 +191,10 @@ Support for Python 3.4 and below has been officially dropped.
Multiple modules
................

- The `__repr__()` method of all estimators (used when calling
`print(estimator)`) has been entirely re-written. :issue:`11705` by
Member

"building on Python's pretty printing standard library."

@@ -185,12 +185,19 @@ Support for Python 3.4 and below has been officially dropped.
``max_depth`` by 1 while expanding the tree if ``max_leaf_nodes`` and
``max_depth`` were both specified by the user. Please note that this also
affects all ensemble methods using decision trees.
:pr:`12344` by :user:`Adrin Jalali <adrinjalali>`.
:issue:`12344` by :user:`Adrin Jalali <adrinjalali>`.
Member Author

@adrinjalali I took the liberty to fix this

@NicolasHug
Member Author

NicolasHug commented Dec 18, 2018

I've added an example in set_config(). There's one in config_context() about assume_finite so I figured why not... But it's not necessarily ideal.

EDIT: well, actually, I removed it. It broke the other doctests. I could have made it work by calling set_config(print_changed_only=False) again, but I guess that's an indication that this example does not belong here.

@jnothman
Member

jnothman commented Dec 19, 2018 via email

@NicolasHug
Member Author

Where do I reference this example from in the doc?

@amueller
Member

Link it from the set_config docs, maybe? I don't think there's a single section on that in the user guide, though...

@NicolasHug
Member Author

Isn't it automatically linked?

@amueller
Member

to the API docs, yes.

@NicolasHug
Member Author

set_config is referenced in doc/modules/computing.rst but I doubt it makes sense there.

Member

@jnothman jnothman left a comment

User guide could consider adding a section on working with scikit-learn in a terminal? But not in this PR.

@jnothman
Member

Note that the example is automatically referenced at the bottom of https://41116-843222-gh.circle-artifacts.com/0/doc/modules/generated/sklearn.set_config.html

@amueller amueller merged commit a50c03f into scikit-learn:master Dec 20, 2018
@jnothman
Member

jnothman commented Dec 20, 2018 via email

@GaelVaroquaux
Member

Hurray! Thanks heaps!! Sorry for not giving feedback on this. Anyhow, I love it.

Any plans to add back the color work that @amueller had originally implemented? If it could be enabled by default, it would help usability.

@amueller
Member

@GaelVaroquaux I'd love to see that, but I figured "small steps" would be better. This is already a big improvement IMHO.
I guess one question is whether we want something Jupyter-specific or something that also works in the console. I would probably go with Jupyter and use HTML.
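For context on the Jupyter route (an assumption about a possible direction, not part of this PR; the class name and markup are made up): Jupyter's rich display machinery picks up a `_repr_html_` method when one exists, so an HTML representation could be added without touching `__repr__`.

```python
class MiniEstimator:
    def __init__(self, C=1.0):
        self.C = C

    def __repr__(self):
        # Plain-text repr, used by print() and the console.
        return f"MiniEstimator(C={self.C})"

    def _repr_html_(self):
        # Jupyter looks for this method to render a rich display.
        return f"<pre>{self!r}</pre>"
```

In a console `repr()` is used; in a notebook the HTML version would be shown instead.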

@GaelVaroquaux
Member

GaelVaroquaux commented Dec 20, 2018 via email

@NicolasHug NicolasHug deleted the pr/9099 branch December 20, 2018 20:42
adrinjalali pushed a commit to adrinjalali/scikit-learn that referenced this pull request Jan 7, 2019
* add pprint for estimators

* strip color from length, add color option

* Minor cleaning, fixes, factoring and docs

* Added some basic tests

* Fixed line length issue

* fixed flake8 and added visual test for review

* Fixed test

* Fixed Python 2 issues (inspect.signature import)

* Trying to fix flake8 again

* Added special repr for functions

* Added some other visual tests

* Changed _format_function into _format_callable

because callable() returns True also for class objects (which we want to
represent with their name as well anyway)

* Consistent output in Python 2 and 3

* WIP

* Now using the builtin pprint module

* pep8

* Added changed_only param

* Fixed printing when string would fit in less than line width

* Fixed printing of steps parameter

* Fixed changed_only param for short estimators

* fixed pep8

* Added some more description in docstring

* changed_only is now an option from set_config()

* Put _pprint.py into sklearn/utils, added tests

* Added doctest NORMALIZE_WHITESPACE where needed

* Fixed tests

* fix test-doc

* fixing test that passed before....

* Fixed tests

* Added test for changed_only and long lines

* typo

* Added authors names

* Added license file

* Added ellipsis based on number of elements in sequence + added increasingly aggressive repr strategies

* Updated whatsnew

* don't use increasingly aggressive strategy

* Fixed tests

* Removed LICENSE file and put license text in _pprint.py

* fixed test_base

* Sorted parameters dictionary for consistent output in 3.5

* Actually using OrderedDict...

* Addressed comments

* Added test for NaN changed parameter

* Update whatsnew

* Added example to set_config()

* Removed example

* Added example in gallery

* Spelling
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
koenvandevelde pushed a commit to koenvandevelde/scikit-learn that referenced this pull request Jul 12, 2019

Successfully merging this pull request may close these issues.

__repr__ not that helpful
5 participants