Skip to content

Commit 7ee31d0

Browse files
committed
Fixes #73, fix __repr__ for ClassifierKMeans
1 parent 3b0a276 commit 7ee31d0

File tree

5 files changed

+32
-7
lines changed

5 files changed

+32
-7
lines changed

_unittests/ut_mlmodel/test_classification_kmeans.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
@brief test log(time=2s)
3+
@brief test log(time=20s)
44
"""
55
import unittest
66
import numpy
77
from numpy.random import RandomState
88
from sklearn import datasets
99
from pyquickhelper.pycode import ExtTestCase
10-
from mlinsights.mlmodel import ClassifierAfterKMeans
11-
from mlinsights.mlmodel import test_sklearn_pickle, test_sklearn_clone, test_sklearn_grid_search_cv
10+
from mlinsights.mlmodel import (
11+
ClassifierAfterKMeans, test_sklearn_pickle, test_sklearn_clone,
12+
test_sklearn_grid_search_cv
13+
)
1214

1315

1416
class TestClassifierAfterKMeans(ExtTestCase):
@@ -73,6 +75,16 @@ def test_classification_kmeans_relevance(self):
7375
score = clk.score(X, Y)
7476
self.assertGreater(score, 0.95)
7577

78+
def test_issue(self):
79+
X, labels_true = datasets.make_blobs(
80+
n_samples=750, centers=6, cluster_std=0.4)
81+
labels_true = labels_true % 3
82+
clcl = ClassifierAfterKMeans(e_max_iter=1000)
83+
clcl.fit(X, labels_true)
84+
r = repr(clcl)
85+
self.assertIn('ClassifierAfterKMeans(', r)
86+
self.assertIn("c_init='k-means++'", r)
87+
7688

7789
if __name__ == "__main__":
7890
unittest.main()

_unittests/ut_mlmodel/test_kmeans_l1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
@brief test log(time=2s)
3+
@brief test log(time=6s)
44
"""
55
import unittest
66
import numpy

_unittests/ut_mlmodel/test_kmeans_sklearn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
@brief test log(time=2s)
3+
@brief test log(time=5s)
44
"""
55
import unittest
66
import numpy as np

mlinsights/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
@brief Module *mlinsights*.
55
Look for insights for machine learned models.
66
"""
7-
__version__ = "0.2.420"
7+
__version__ = "0.2.422"
88
__author__ = "Xavier Dupré"
99
__github__ = "https://github.com/sdpython/mlinsights"
1010
__url__ = "http://www.xavierdupre.fr/app/mlinsights/helpsphinx/index.html"

mlinsights/mlmodel/classification_kmeans.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
@file
33
@brief Combines a *k-means* followed by a predictor.
44
"""
5+
import textwrap
56
import inspect
67
import numpy
78
from sklearn.linear_model import LogisticRegression
@@ -134,7 +135,8 @@ def get_params(self, deep=True):
134135
@param deep unused here
135136
@return dict
136137
137-
:meth:`set_params <mlinsights.mlmodel.classification_kmeans.ClassifierAfterKMeans.set_params>`
138+
:meth:`set_params <mlinsights.mlmodel.classification_kmeans.
139+
ClassifierAfterKMeans.set_params>`
138140
describes the pattern parameters names follow.
139141
"""
140142
res = {}
@@ -164,3 +166,14 @@ def set_params(self, **values):
164166
raise ValueError("Unexpected parameter name '{0}'".format(k))
165167
self.clus.set_params(**pc)
166168
self.estimator.set_params(**pe)
169+
170+
def __repr__(self):
171+
"""
172+
Overloads `repr` as *scikit-learn* now relies
173+
on the constructor signature.
174+
"""
175+
el = ', '.join(['%s=%r' % (k, v)
176+
for k, v in self.get_params().items()])
177+
text = "%s(%s)" % (self.__class__.__name__, el)
178+
lines = textwrap.wrap(text, subsequent_indent=' ')
179+
return "\n".join(lines)

0 commit comments

Comments
 (0)