From 0fa60b76fe917a7d09ff0e69756200ec7501448f Mon Sep 17 00:00:00 2001 From: John Moeller Date: Thu, 16 Jun 2016 22:49:36 -0600 Subject: [PATCH 01/16] Make KernelCenterer a _pairwise operation Replicate solution to https://github.com/scikit-learn/scikit-learn/commit/9a520779c233dfeff466870c0b7cb04b705e61af except that `_pairwise` should always be `True` for `KernelCenterer` because it's supposed to receive a Gram matrix. This should make `KernelCenterer` usable in `Pipeline`s. Happy to add tests, just tell me what should be covered. --- sklearn/preprocessing/data.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 42957133b654c..74099dc51b153 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1584,6 +1584,10 @@ def transform(self, K, y=None, copy=True): K += self.K_fit_all_ return K + + @property + def _pairwise(self): + return True def add_dummy_feature(X, value=1.0): From 00438850977a677d8dd71bab4257cfe83817894f Mon Sep 17 00:00:00 2001 From: John Moeller Date: Fri, 17 Jun 2016 17:28:36 -0600 Subject: [PATCH 02/16] Adding test for PR #6900 --- sklearn/preprocessing/tests/test_data.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index b1ef18a8ebc45..9caa1818fcf66 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -52,6 +52,11 @@ from sklearn.preprocessing.data import PolynomialFeatures from sklearn.exceptions import DataConversionWarning +from sklearn.pipeline import Pipeline +from sklearn.cross_validation import cross_val_score +from sklearn.cross_validation import LeaveOneOut +from sklearn.svm import SVR + from sklearn import datasets iris = datasets.load_iris() @@ -1369,6 +1374,23 @@ def test_center_kernel(): K_pred_centered2 = centerer.transform(K_pred) assert_array_almost_equal(K_pred_centered, K_pred_centered2) +def test_cv_pipeline_precomputed(): + """Cross-validate a regression on four coplanar points with the same + value. Use precomputed kernel to ensure Pipeline with KernelCenterer + is treated as a _pairwise operation.""" + X = np.array([[3,0,0],[0,3,0],[0,0,3],[1,1,1]]) + y = np.ones((4,)) + K = X.dot(X.T) + kcent = KernelCenterer() + pipeline = Pipeline([("kernel_centerer", kcent), ("svr", SVR())]) + + # did the pipeline set the _pairwise attribute? 
+ assert_true(pipeline._pairwise) + + # test cross-validation, score should be almost perfect + score = cross_val_score(pipeline,K,y,cv=LeaveOneOut(4)) + assert_array_almost_equal(score, np.ones_like(score)) + def test_fit_transform(): rng = np.random.RandomState(0) From 069336ebf592fcb6c49c17ec5aee9150d78d9ca3 Mon Sep 17 00:00:00 2001 From: John Moeller Date: Fri, 17 Jun 2016 17:44:51 -0600 Subject: [PATCH 03/16] Simplifying imports and test --- sklearn/preprocessing/tests/test_data.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 9caa1818fcf66..e2e080562400d 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -53,8 +53,7 @@ from sklearn.exceptions import DataConversionWarning from sklearn.pipeline import Pipeline -from sklearn.cross_validation import cross_val_score -from sklearn.cross_validation import LeaveOneOut +from sklearn.cross_validation import cross_val_predict from sklearn.svm import SVR from sklearn import datasets @@ -1379,7 +1378,7 @@ def test_cv_pipeline_precomputed(): value. Use precomputed kernel to ensure Pipeline with KernelCenterer is treated as a _pairwise operation.""" X = np.array([[3,0,0],[0,3,0],[0,0,3],[1,1,1]]) - y = np.ones((4,)) + y_true = np.ones((4,)) K = X.dot(X.T) kcent = KernelCenterer() pipeline = Pipeline([("kernel_centerer", kcent), ("svr", SVR())]) @@ -1388,8 +1387,10 @@ def test_cv_pipeline_precomputed(): assert_true(pipeline._pairwise) # test cross-validation, score should be almost perfect - score = cross_val_score(pipeline,K,y,cv=LeaveOneOut(4)) - assert_array_almost_equal(score, np.ones_like(score)) + # NB: this test is pretty vacuous -- it's mainly to test integration + # of Pipeline and KernelCenterer + y_pred = cross_val_predict(pipeline,K,y_true,cv=4) + assert_array_almost_equal(y_true, y_pred) def test_fit_transform(): From 039b6f347f8cef1785a10577b60ae46d4dd7e0cb Mon Sep 17 00:00:00 2001 From: Jeremy Hintz Date: Sat, 18 Jun 2016 03:03:36 -0700 Subject: [PATCH 04/16] updating changelog links on homepage (#6901) --- doc/index.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/index.rst b/doc/index.rst index e03dd07b4c927..da5ff17de7d3f 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -205,13 +205,13 @@

[The body of this hunk was raw HTML and was lost in extraction; only the
rendered text survived. From that text, the hunk updates the "News" list on
the scikit-learn homepage, repointing each changelog link: on-going
development ("What's new"), November 2015 (0.17.0), March 2015 (0.16.0),
July 2014 (0.15.0), and August 2013 (0.14). The July 14-20th, 2014
international sprint item, which notes that 18 core contributors gathered
and thanks the sprint organizers and sponsors (including Inria and
tinyclues), is unchanged context.]
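
[Note: the next patch, [PATCH 05/16], moves documentation and tests from
average='weighted' to average='macro' F1 scoring. 'weighted' averages the
per-class F1 scores weighted by class support, so a dominant class can hide
a model that never predicts the minority class; 'macro' weights every class
equally. The snippet below is a standalone illustration on made-up data, not
part of any patch in this series; the printed scores are approximate.]

import numpy as np
from sklearn.metrics import f1_score

y_true = np.array([0] * 8 + [1] * 2)  # imbalanced: 8 vs. 2 samples
y_pred = np.zeros(10, dtype=int)      # degenerate model: always predicts 0

# class 0 scores F1 ~0.89; class 1 is never predicted, so its F1 is 0
# (scikit-learn emits an UndefinedMetricWarning for the zero division)
print(f1_score(y_true, y_pred, average='weighted'))  # ~0.71: failure hidden
print(f1_score(y_true, y_pred, average='macro'))     # ~0.44: failure exposed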
From f69fb7e3b5c59fa611f3a15c5be753a5d9e21517 Mon Sep 17 00:00:00 2001 From: hashcode55 Date: Sun, 19 Jun 2016 02:53:34 +0530 Subject: [PATCH 05/16] first commit --- doc/datasets/twenty_newsgroups.rst | 8 ++++---- examples/model_selection/grid_search_digits.py | 2 +- sklearn/metrics/tests/test_classification.py | 8 ++++---- sklearn/svm/tests/test_svm.py | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/doc/datasets/twenty_newsgroups.rst b/doc/datasets/twenty_newsgroups.rst index 01c2a53ff77e5..e0e845a04f539 100644 --- a/doc/datasets/twenty_newsgroups.rst +++ b/doc/datasets/twenty_newsgroups.rst @@ -132,8 +132,8 @@ which is fast to train and achieves a decent F-score:: >>> clf = MultinomialNB(alpha=.01) >>> clf.fit(vectors, newsgroups_train.target) >>> pred = clf.predict(vectors_test) - >>> metrics.f1_score(newsgroups_test.target, pred, average='weighted') - 0.88251152461278892 + >>> metrics.f1_score(newsgroups_test.target, pred, average='macro') + 0.88213592402729568 (The example :ref:`example_text_document_classification_20newsgroups.py` shuffles the training and test data, instead of segmenting by time, and in that case @@ -182,8 +182,8 @@ blocks, and quotation blocks respectively. ... categories=categories) >>> vectors_test = vectorizer.transform(newsgroups_test.data) >>> pred = clf.predict(vectors_test) - >>> metrics.f1_score(pred, newsgroups_test.target, average='weighted') - 0.78409163025839435 + >>> metrics.f1_score(pred, newsgroups_test.target, average='macro') + 0.77310350681274775 This classifier lost over a lot of its F-score, just because we removed metadata that has little to do with topic classification. diff --git a/examples/model_selection/grid_search_digits.py b/examples/model_selection/grid_search_digits.py index 40ed573247efd..13755b0bc8c10 100644 --- a/examples/model_selection/grid_search_digits.py +++ b/examples/model_selection/grid_search_digits.py @@ -51,7 +51,7 @@ print() clf = GridSearchCV(SVC(C=1), tuned_parameters, cv=5, - scoring='%s_weighted' % score) + scoring='%s_macro' % score) clf.fit(X_train, y_train) print("Best parameters set found on development set:") diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index f28e6cc77093b..2794948ce93d5 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -469,7 +469,7 @@ def test_precision_recall_f1_score_multiclass_pos_label_none(): # compute scores with default labels introspection p, r, f, s = precision_recall_fscore_support(y_true, y_pred, pos_label=None, - average='weighted') + average='binary') def test_zero_precision_recall(): @@ -482,10 +482,10 @@ def test_zero_precision_recall(): y_pred = np.array([2, 0, 1, 1, 2, 0]) assert_almost_equal(precision_score(y_true, y_pred, - average='weighted'), 0.0, 2) - assert_almost_equal(recall_score(y_true, y_pred, average='weighted'), + average='macro'), 0.0, 2) + assert_almost_equal(recall_score(y_true, y_pred, average='macro'), 0.0, 2) - assert_almost_equal(f1_score(y_true, y_pred, average='weighted'), + assert_almost_equal(f1_score(y_true, y_pred, average='macro'), 0.0, 2) finally: diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py index 670180695e452..df9f6f988c0c5 100644 --- a/sklearn/svm/tests/test_svm.py +++ b/sklearn/svm/tests/test_svm.py @@ -439,9 +439,9 @@ def test_auto_weight(): y_pred = clf.fit(X[unbalanced], y[unbalanced]).predict(X) clf.set_params(class_weight='balanced') y_pred_balanced = 
clf.fit(X[unbalanced], y[unbalanced],).predict(X) - assert_true(metrics.f1_score(y, y_pred, average='weighted') + assert_true(metrics.f1_score(y, y_pred, average='macro') <= metrics.f1_score(y, y_pred_balanced, - average='weighted')) + average='macro')) def test_bad_input(): From 2d7929df374cbbf464798e991263bb60f74fbb79 Mon Sep 17 00:00:00 2001 From: hashcode55 Date: Sun, 19 Jun 2016 14:40:48 +0530 Subject: [PATCH 06/16] changed binary average back to macro --- doc/datasets/twenty_newsgroups.rst | 4 ++-- sklearn/metrics/tests/test_classification.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/datasets/twenty_newsgroups.rst b/doc/datasets/twenty_newsgroups.rst index e0e845a04f539..55fe227682190 100644 --- a/doc/datasets/twenty_newsgroups.rst +++ b/doc/datasets/twenty_newsgroups.rst @@ -197,8 +197,8 @@ It loses even more if we also strip this metadata from the training data: >>> clf.fit(vectors, newsgroups_train.target) >>> vectors_test = vectorizer.transform(newsgroups_test.data) >>> pred = clf.predict(vectors_test) - >>> metrics.f1_score(newsgroups_test.target, pred, average='weighted') - 0.73160869205141166 + >>> metrics.f1_score(newsgroups_test.target, pred, average='macro') + 0.65437545099490202 Some other classifiers cope better with this harder version of the task. Try running :ref:`example_model_selection_grid_search_text_feature_extraction.py` with and without diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 2794948ce93d5..5f93333e585cf 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -469,7 +469,7 @@ def test_precision_recall_f1_score_multiclass_pos_label_none(): # compute scores with default labels introspection p, r, f, s = precision_recall_fscore_support(y_true, y_pred, pos_label=None, - average='binary') + average='macro') def test_zero_precision_recall(): From 1267f6da156bc506954100bca29107bb34475da5 Mon Sep 17 00:00:00 2001 From: hashcode55 Date: Sun, 19 Jun 2016 14:54:18 +0530 Subject: [PATCH 07/16] changed binomialNB to multinomialNB --- doc/datasets/twenty_newsgroups.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/datasets/twenty_newsgroups.rst b/doc/datasets/twenty_newsgroups.rst index 55fe227682190..2850b244eb12b 100644 --- a/doc/datasets/twenty_newsgroups.rst +++ b/doc/datasets/twenty_newsgroups.rst @@ -193,12 +193,12 @@ It loses even more if we also strip this metadata from the training data: ... remove=('headers', 'footers', 'quotes'), ... categories=categories) >>> vectors = vectorizer.fit_transform(newsgroups_train.data) - >>> clf = BernoulliNB(alpha=.01) + >>> clf = MultinomialNB(alpha=.01) >>> clf.fit(vectors, newsgroups_train.target) >>> vectors_test = vectorizer.transform(newsgroups_test.data) >>> pred = clf.predict(vectors_test) >>> metrics.f1_score(newsgroups_test.target, pred, average='macro') - 0.65437545099490202 + 0.76995175184521725 Some other classifiers cope better with this harder version of the task. Try running :ref:`example_model_selection_grid_search_text_feature_extraction.py` with and without From f911bb6eea0ac2582cdaf4ff37326df456898edc Mon Sep 17 00:00:00 2001 From: Yoav Ram Date: Sun, 19 Jun 2016 15:46:13 +0300 Subject: [PATCH 08/16] emphasis on "higher return values are better..." 
(#6909) --- doc/modules/model_evaluation.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 3bcf1142e410b..798666bfb6901 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -48,8 +48,8 @@ Common cases: predefined values For the most common use cases, you can designate a scorer object with the ``scoring`` parameter; the table below shows all possible values. -All scorer objects follow the convention that higher return values are better -than lower return values. Thus the returns from mean_absolute_error +All scorer objects follow the convention that **higher return values are better +than lower return values**. Thus the returns from mean_absolute_error and mean_squared_error, which measure the distance between the model and the data, are negated. From 1534d0c28e50d158a1d9bde35b777e1af0c1ac8e Mon Sep 17 00:00:00 2001 From: Brandon Carter Date: Tue, 21 Jun 2016 00:39:47 -0400 Subject: [PATCH 09/16] fix typo in comment of hierarchical clustering (#6912) --- sklearn/cluster/hierarchical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/cluster/hierarchical.py b/sklearn/cluster/hierarchical.py index b799be3bd7ad7..75f6914a13171 100644 --- a/sklearn/cluster/hierarchical.py +++ b/sklearn/cluster/hierarchical.py @@ -741,7 +741,7 @@ def fit(self, X, y=None): labels = _hierarchical.hc_get_heads(parents, copy=False) # copy to avoid holding a reference on the original array labels = np.copy(labels[:n_samples]) - # Reasign cluster numbers + # Reassign cluster numbers self.labels_ = np.searchsorted(np.unique(labels), labels) return self From 3c34fb3450aead4a9540152b3c9f631df330c24f Mon Sep 17 00:00:00 2001 From: Yen Date: Tue, 21 Jun 2016 13:26:31 +0800 Subject: [PATCH 10/16] [MRG] Allows KMeans/MiniBatchKMeans to use float32 internally by using cython fused types (#6846) --- doc/whats_new.rst | 13 ++- sklearn/cluster/_k_means.pyx | 102 ++++++++++++++------ sklearn/cluster/_k_means_elkan.pyx | 50 +++++----- sklearn/cluster/k_means_.py | 51 +++++----- sklearn/cluster/tests/test_k_means.py | 119 +++++++++++++++++------ sklearn/src/cblas/cblas_sdot.c | 132 ++++++++++++++++++++++++++ 6 files changed, 359 insertions(+), 108 deletions(-) create mode 100644 sklearn/src/cblas/cblas_sdot.c diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 3d0dab18bb3ee..1c5621d7eb9cb 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -209,6 +209,11 @@ Enhancements (`#6697 `_) by `Raghav R V`_. + - :class:`cluster.KMeans` and :class:`cluster.MiniBatchKMeans` now works + with ``np.float32`` and ``np.float64`` input data without converting it. + This allows to reduce the memory consumption by using ``np.float32``. + (`#6846 `_) + By `Sebastian Säger`_ and `YenChen Lin`_. Bug fixes ......... @@ -1769,7 +1774,7 @@ List of contributors for release 0.15 by number of commits. * 4 Alexis Metaireau * 4 Ignacio Rossi * 4 Virgile Fritsch -* 4 Sebastian Saeger +* 4 Sebastian Säger * 4 Ilambharathi Kanniah * 4 sdenton4 * 4 Robert Layton @@ -4266,4 +4271,8 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson. .. _Wenhua Yang: https://github.com/geekoala -.. _Arnaud Fouchet: https://github.com/afouchet \ No newline at end of file +.. _Arnaud Fouchet: https://github.com/afouchet + +.. _Sebastian Säger: https://github.com/ssaeger + +.. 
_YenChen Lin: https://github.com/yenchenlin diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index e2f558897c6df..925c4df46cddf 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -13,6 +13,7 @@ import numpy as np import scipy.sparse as sp cimport numpy as np cimport cython +from cython cimport floating from ..utils.extmath import norm from sklearn.utils.sparsefuncs_fast import assign_rows_csr @@ -23,6 +24,7 @@ ctypedef np.int32_t INT cdef extern from "cblas.h": double ddot "cblas_ddot"(int N, double *X, int incX, double *Y, int incY) + float sdot "cblas_sdot"(int N, float *X, int incX, float *Y, int incY) np.import_array() @@ -30,11 +32,11 @@ np.import_array() @cython.boundscheck(False) @cython.wraparound(False) @cython.cdivision(True) -cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X, - np.ndarray[DOUBLE, ndim=1] x_squared_norms, - np.ndarray[DOUBLE, ndim=2] centers, +cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X, + np.ndarray[floating, ndim=1] x_squared_norms, + np.ndarray[floating, ndim=2] centers, np.ndarray[INT, ndim=1] labels, - np.ndarray[DOUBLE, ndim=1] distances): + np.ndarray[floating, ndim=1] distances): """Compute label assignment and inertia for a dense array Return the inertia (sum of squared distances to the centers). @@ -43,24 +45,39 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X, unsigned int n_clusters = centers.shape[0] unsigned int n_features = centers.shape[1] unsigned int n_samples = X.shape[0] - unsigned int x_stride = X.strides[1] / sizeof(DOUBLE) - unsigned int center_stride = centers.strides[1] / sizeof(DOUBLE) + unsigned int x_stride + unsigned int center_stride unsigned int sample_idx, center_idx, feature_idx unsigned int store_distances = 0 unsigned int k + np.ndarray[floating, ndim=1] center_squared_norms + # the following variables are always double cause make them floating + # does not save any memory, but makes the code much bigger DOUBLE inertia = 0.0 DOUBLE min_dist DOUBLE dist - np.ndarray[DOUBLE, ndim=1] center_squared_norms = np.zeros( - n_clusters, dtype=np.float64) + + if floating is float: + center_squared_norms = np.zeros(n_clusters, dtype=np.float32) + x_stride = X.strides[1] / sizeof(float) + center_stride = centers.strides[1] / sizeof(float) + else: + center_squared_norms = np.zeros(n_clusters, dtype=np.float64) + x_stride = X.strides[1] / sizeof(DOUBLE) + center_stride = centers.strides[1] / sizeof(DOUBLE) if n_samples == distances.shape[0]: store_distances = 1 for center_idx in range(n_clusters): - center_squared_norms[center_idx] = ddot( - n_features, ¢ers[center_idx, 0], center_stride, - ¢ers[center_idx, 0], center_stride) + if floating is float: + center_squared_norms[center_idx] = sdot( + n_features, ¢ers[center_idx, 0], center_stride, + ¢ers[center_idx, 0], center_stride) + else: + center_squared_norms[center_idx] = ddot( + n_features, ¢ers[center_idx, 0], center_stride, + ¢ers[center_idx, 0], center_stride) for sample_idx in range(n_samples): min_dist = -1 @@ -68,8 +85,12 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X, dist = 0.0 # hardcoded: minimize euclidean distance to cluster center: # ||a - b||^2 = ||a||^2 + ||b||^2 -2 - dist += ddot(n_features, &X[sample_idx, 0], x_stride, - ¢ers[center_idx, 0], center_stride) + if floating is float: + dist += sdot(n_features, &X[sample_idx, 0], x_stride, + ¢ers[center_idx, 0], center_stride) + else: + dist += ddot(n_features, &X[sample_idx, 0], x_stride, + ¢ers[center_idx, 
0], center_stride) dist *= -2 dist += center_squared_norms[center_idx] dist += x_squared_norms[sample_idx] @@ -88,15 +109,15 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X, @cython.wraparound(False) @cython.cdivision(True) cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, - np.ndarray[DOUBLE, ndim=2] centers, + np.ndarray[floating, ndim=2] centers, np.ndarray[INT, ndim=1] labels, - np.ndarray[DOUBLE, ndim=1] distances): + np.ndarray[floating, ndim=1] distances): """Compute label assignment and inertia for a CSR input Return the inertia (sum of squared distances to the centers). """ cdef: - np.ndarray[DOUBLE, ndim=1] X_data = X.data + np.ndarray[floating, ndim=1] X_data = X.data np.ndarray[INT, ndim=1] X_indices = X.indices np.ndarray[INT, ndim=1] X_indptr = X.indptr unsigned int n_clusters = centers.shape[0] @@ -105,18 +126,28 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, unsigned int store_distances = 0 unsigned int sample_idx, center_idx, feature_idx unsigned int k + np.ndarray[floating, ndim=1] center_squared_norms + # the following variables are always double cause make them floating + # does not save any memory, but makes the code much bigger DOUBLE inertia = 0.0 DOUBLE min_dist DOUBLE dist - np.ndarray[DOUBLE, ndim=1] center_squared_norms = np.zeros( - n_clusters, dtype=np.float64) + + if floating is float: + center_squared_norms = np.zeros(n_clusters, dtype=np.float32) + else: + center_squared_norms = np.zeros(n_clusters, dtype=np.float64) if n_samples == distances.shape[0]: store_distances = 1 for center_idx in range(n_clusters): - center_squared_norms[center_idx] = ddot( - n_features, ¢ers[center_idx, 0], 1, ¢ers[center_idx, 0], 1) + if floating is float: + center_squared_norms[center_idx] = sdot( + n_features, ¢ers[center_idx, 0], 1, ¢ers[center_idx, 0], 1) + else: + center_squared_norms[center_idx] = ddot( + n_features, ¢ers[center_idx, 0], 1, ¢ers[center_idx, 0], 1) for sample_idx in range(n_samples): min_dist = -1 @@ -143,17 +174,17 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, @cython.wraparound(False) @cython.cdivision(True) def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, - np.ndarray[DOUBLE, ndim=2] centers, + np.ndarray[floating, ndim=2] centers, np.ndarray[INT, ndim=1] counts, np.ndarray[INT, ndim=1] nearest_center, - np.ndarray[DOUBLE, ndim=1] old_center, + np.ndarray[floating, ndim=1] old_center, int compute_squared_diff): """Incremental update of the centers for sparse MiniBatchKMeans. Parameters ---------- - X: CSR matrix, dtype float64 + X: CSR matrix, dtype float The complete (pre allocated) training set as a CSR matrix. centers: array, shape (n_clusters, n_features) @@ -179,7 +210,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, of the algorithm. 
""" cdef: - np.ndarray[DOUBLE, ndim=1] X_data = X.data + np.ndarray[floating, ndim=1] X_data = X.data np.ndarray[int, ndim=1] X_indices = X.indices np.ndarray[int, ndim=1] X_indptr = X.indptr unsigned int n_samples = X.shape[0] @@ -245,9 +276,9 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, @cython.boundscheck(False) @cython.wraparound(False) @cython.cdivision(True) -def _centers_dense(np.ndarray[DOUBLE, ndim=2] X, +def _centers_dense(np.ndarray[floating, ndim=2] X, np.ndarray[INT, ndim=1] labels, int n_clusters, - np.ndarray[DOUBLE, ndim=1] distances): + np.ndarray[floating, ndim=1] distances): """M step of the K-means EM algorithm Computation of cluster centers / means. @@ -275,7 +306,12 @@ def _centers_dense(np.ndarray[DOUBLE, ndim=2] X, n_samples = X.shape[0] n_features = X.shape[1] cdef int i, j, c - cdef np.ndarray[DOUBLE, ndim=2] centers = np.zeros((n_clusters, n_features)) + cdef np.ndarray[floating, ndim=2] centers + if floating is float: + centers = np.zeros((n_clusters, n_features), dtype=np.float32) + else: + centers = np.zeros((n_clusters, n_features), dtype=np.float64) + n_samples_in_cluster = bincount(labels, minlength=n_clusters) empty_clusters = np.where(n_samples_in_cluster == 0)[0] # maybe also relocate small clusters? @@ -303,7 +339,7 @@ def _centers_dense(np.ndarray[DOUBLE, ndim=2] X, @cython.wraparound(False) @cython.cdivision(True) def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters, - np.ndarray[DOUBLE, ndim=1] distances): + np.ndarray[floating, ndim=1] distances): """M step of the K-means EM algorithm Computation of cluster centers / means. @@ -329,12 +365,11 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters, cdef int n_features = X.shape[1] cdef int curr_label - cdef np.ndarray[DOUBLE, ndim=1] data = X.data + cdef np.ndarray[floating, ndim=1] data = X.data cdef np.ndarray[int, ndim=1] indices = X.indices cdef np.ndarray[int, ndim=1] indptr = X.indptr - cdef np.ndarray[DOUBLE, ndim=2, mode="c"] centers = \ - np.zeros((n_clusters, n_features)) + cdef np.ndarray[floating, ndim=2, mode="c"] centers cdef np.ndarray[np.npy_intp, ndim=1] far_from_centers cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] n_samples_in_cluster = \ bincount(labels, minlength=n_clusters) @@ -342,6 +377,11 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters, np.where(n_samples_in_cluster == 0)[0] cdef int n_empty_clusters = empty_clusters.shape[0] + if floating is float: + centers = np.zeros((n_clusters, n_features), dtype=np.float32) + else: + centers = np.zeros((n_clusters, n_features), dtype=np.float64) + # maybe also relocate small clusters? 
if n_empty_clusters > 0: diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 09c3852e55000..f662402feb850 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -10,6 +10,7 @@ import numpy as np cimport numpy as np cimport cython +from cython cimport floating from libc.math cimport sqrt @@ -18,8 +19,8 @@ from ._k_means import _centers_dense from ..utils.fixes import partition -cdef double euclidian_dist(double* a, double* b, int n_features) nogil: - cdef double result, tmp +cdef floating euclidian_dist(floating* a, floating* b, int n_features) nogil: + cdef floating result, tmp result = 0 cdef int i for i in range(n_features): @@ -29,8 +30,8 @@ cdef double euclidian_dist(double* a, double* b, int n_features) nogil: cdef update_labels_distances_inplace( - double* X, double* centers, double[:, :] center_half_distances, - int[:] labels, double[:, :] lower_bounds, double[:] upper_bounds, + floating* X, floating* centers, floating[:, :] center_half_distances, + int[:] labels, floating[:, :] lower_bounds, floating[:] upper_bounds, int n_samples, int n_features, int n_clusters): """ Calculate upper and lower bounds for each sample. @@ -81,9 +82,9 @@ cdef update_labels_distances_inplace( """ # assigns closest center to X # uses triangle inequality - cdef double* x - cdef double* c - cdef double d_c, dist + cdef floating* x + cdef floating* c + cdef floating d_c, dist cdef int c_x, j, sample for sample in range(n_samples): # assign first cluster center @@ -103,8 +104,8 @@ cdef update_labels_distances_inplace( upper_bounds[sample] = d_c -def k_means_elkan(np.ndarray[np.float64_t, ndim=2, mode='c'] X_, int n_clusters, - np.ndarray[np.float64_t, ndim=2, mode='c'] init, +def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters, + np.ndarray[floating, ndim=2, mode='c'] init, float tol=1e-4, int max_iter=30, verbose=False): """Run Elkan's k-means. @@ -128,22 +129,27 @@ def k_means_elkan(np.ndarray[np.float64_t, ndim=2, mode='c'] X_, int n_clusters, Whether to be verbose. """ - #initialize - cdef np.ndarray[np.float64_t, ndim=2, mode='c'] centers_ = init - cdef double* centers_p = centers_.data - cdef double* X_p = X_.data - cdef double* x_p + if floating is float: + dtype = np.float32 + else: + dtype = np.float64 + + #initialize + cdef np.ndarray[floating, ndim=2, mode='c'] centers_ = init + cdef floating* centers_p = centers_.data + cdef floating* X_p = X_.data + cdef floating* x_p cdef Py_ssize_t n_samples = X_.shape[0] cdef Py_ssize_t n_features = X_.shape[1] cdef int point_index, center_index, label - cdef float upper_bound, distance - cdef double[:, :] center_half_distances = euclidean_distances(centers_) / 2. - cdef double[:, :] lower_bounds = np.zeros((n_samples, n_clusters)) - cdef double[:] distance_next_center + cdef floating upper_bound, distance + cdef floating[:, :] center_half_distances = euclidean_distances(centers_) / 2. + cdef floating[:, :] lower_bounds = np.zeros((n_samples, n_clusters), dtype=dtype) + cdef floating[:] distance_next_center labels_ = np.empty(n_samples, dtype=np.int32) cdef int[:] labels = labels_ - upper_bounds_ = np.empty(n_samples, dtype=np.float) - cdef double[:] upper_bounds = upper_bounds_ + upper_bounds_ = np.empty(n_samples, dtype=dtype) + cdef floating[:] upper_bounds = upper_bounds_ # Get the inital set of upper bounds and lower bounds for each sample. 
update_labels_distances_inplace(X_p, centers_p, center_half_distances, @@ -151,7 +157,7 @@ def k_means_elkan(np.ndarray[np.float64_t, ndim=2, mode='c'] X_, int n_clusters, n_samples, n_features, n_clusters) cdef np.uint8_t[:] bounds_tight = np.ones(n_samples, dtype=np.uint8) cdef np.uint8_t[:] points_to_update = np.zeros(n_samples, dtype=np.uint8) - cdef np.ndarray[np.float64_t, ndim=2, mode='c'] new_centers + cdef np.ndarray[floating, ndim=2, mode='c'] new_centers if max_iter <= 0: raise ValueError('Number of iterations should be a positive number' @@ -226,7 +232,7 @@ def k_means_elkan(np.ndarray[np.float64_t, ndim=2, mode='c'] X_, int n_clusters, # reassign centers centers_ = new_centers - centers_p = new_centers.data + centers_p = new_centers.data # update between-center distances center_half_distances = euclidean_distances(centers_) / 2. diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 34970ea5317b1..d1c5e5c34e0f4 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -77,7 +77,7 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None): """ n_samples, n_features = X.shape - centers = np.empty((n_clusters, n_features)) + centers = np.empty((n_clusters, n_features), dtype=X.dtype) assert x_squared_norms is not None, 'x_squared_norms None in _k_init' @@ -305,7 +305,7 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', X -= X_mean if hasattr(init, '__array__'): - init = check_array(init, dtype=np.float64, copy=True) + init = check_array(init, dtype=X.dtype.type, copy=True) _validate_center_shape(X, n_clusters, init) init -= X_mean @@ -396,7 +396,7 @@ def _kmeans_single_elkan(X, n_clusters, max_iter=300, init='k-means++', print('Initialization complete') centers, labels, n_iter = k_means_elkan(X, n_clusters, centers, tol=tol, max_iter=max_iter, verbose=verbose) - inertia = np.sum((X - centers[labels]) ** 2) + inertia = np.sum((X - centers[labels]) ** 2, dtype=np.float64) return labels, inertia, centers, n_iter @@ -478,7 +478,7 @@ def _kmeans_single_lloyd(X, n_clusters, max_iter=300, init='k-means++', # Allocate memory to store the distances for each sample to its # closer center for reallocation in case of ties - distances = np.zeros(shape=(X.shape[0],), dtype=np.float64) + distances = np.zeros(shape=(X.shape[0],), dtype=X.dtype) # iterations for i in range(max_iter): @@ -586,13 +586,13 @@ def _labels_inertia(X, x_squared_norms, centers, Precomputed squared euclidean norm of each data point, to speed up computations. - centers: float64 array, shape (k, n_features) + centers: float array, shape (k, n_features) The cluster centers. precompute_distances : boolean, default: True Precompute distances (faster but takes more memory). - distances: float64 array, shape (n_samples,) + distances: float array, shape (n_samples,) Pre-allocated array to be filled in with each sample's distance to the closest center. 
@@ -609,7 +609,7 @@ def _labels_inertia(X, x_squared_norms, centers, # easily labels = -np.ones(n_samples, np.int32) if distances is None: - distances = np.zeros(shape=(0,), dtype=np.float64) + distances = np.zeros(shape=(0,), dtype=X.dtype) # distances will be changed in-place if sp.issparse(X): inertia = _k_means._assign_labels_csr( @@ -685,9 +685,12 @@ def _init_centroids(X, k, init, random_state=None, x_squared_norms=None, seeds = random_state.permutation(n_samples)[:k] centers = X[seeds] elif hasattr(init, '__array__'): - centers = init + # ensure that the centers have the same dtype as X + # this is a requirement of fused types of cython + centers = np.array(init, dtype=X.dtype) elif callable(init): centers = init(X, k, random_state=random_state) + centers = np.asarray(centers, dtype=X.dtype) else: raise ValueError("the init parameter for the k-means should " "be 'k-means++' or 'random' or an ndarray, " @@ -834,7 +837,7 @@ def __init__(self, n_clusters=8, init='k-means++', n_init=10, def _check_fit_data(self, X): """Verify that the number of samples given is larger than k""" - X = check_array(X, accept_sparse='csr', dtype=np.float64) + X = check_array(X, accept_sparse='csr', dtype=[np.float64, np.float32]) if X.shape[0] < self.n_clusters: raise ValueError("n_samples=%d should be >= n_clusters=%d" % ( X.shape[0], self.n_clusters)) @@ -864,11 +867,12 @@ def fit(self, X, y=None): self.cluster_centers_, self.labels_, self.inertia_, self.n_iter_ = \ k_means( - X, n_clusters=self.n_clusters, init=self.init, n_init=self.n_init, - max_iter=self.max_iter, verbose=self.verbose, + X, n_clusters=self.n_clusters, init=self.init, + n_init=self.n_init, max_iter=self.max_iter, verbose=self.verbose, precompute_distances=self.precompute_distances, tol=self.tol, random_state=random_state, copy_x=self.copy_x, - n_jobs=self.n_jobs, algorithm=self.algorithm, return_n_iter=True) + n_jobs=self.n_jobs, algorithm=self.algorithm, + return_n_iter=True) return self def fit_predict(self, X, y=None): @@ -983,7 +987,7 @@ def _mini_batch_step(X, x_squared_norms, centers, counts, The vector in which we keep track of the numbers of elements in a cluster. This array is MODIFIED IN PLACE - distances : array, dtype float64, shape (n_samples), optional + distances : array, dtype float, shape (n_samples), optional If not None, should be a pre-allocated array that will be used to store the distances of each sample to its closest center. May not be None when random_reassign is True. 
@@ -1084,7 +1088,9 @@ def _mini_batch_step(X, x_squared_norms, centers, counts, counts[center_idx] += count # inplace rescale to compute mean of all points (old and new) - centers[center_idx] /= counts[center_idx] + # Note: numpy >= 1.10 does not support '/=' for the following + # expression for a mixture of int and float (see numpy issue #6464) + centers[center_idx] = centers[center_idx] / counts[center_idx] # update the squared diff if necessary if compute_squared_diff: @@ -1282,7 +1288,8 @@ def fit(self, X, y=None): Coordinates of the data points to cluster """ random_state = check_random_state(self.random_state) - X = check_array(X, accept_sparse="csr", order='C', dtype=np.float64) + X = check_array(X, accept_sparse="csr", order='C', + dtype=[np.float64, np.float32]) n_samples, n_features = X.shape if n_samples < self.n_clusters: raise ValueError("Number of samples smaller than number " @@ -1290,7 +1297,7 @@ def fit(self, X, y=None): n_init = self.n_init if hasattr(self.init, '__array__'): - self.init = np.ascontiguousarray(self.init, dtype=np.float64) + self.init = np.ascontiguousarray(self.init, dtype=X.dtype) if n_init != 1: warnings.warn( 'Explicit initial center position passed: ' @@ -1307,14 +1314,14 @@ def fit(self, X, y=None): # using tol-based early stopping needs the allocation of a # dedicated before which can be expensive for high dim data: # hence we allocate it outside of the main loop - old_center_buffer = np.zeros(n_features, np.double) + old_center_buffer = np.zeros(n_features, dtype=X.dtype) else: tol = 0.0 # no need for the center buffer if tol-based early stopping is # disabled - old_center_buffer = np.zeros(0, np.double) + old_center_buffer = np.zeros(0, dtype=X.dtype) - distances = np.zeros(self.batch_size, dtype=np.float64) + distances = np.zeros(self.batch_size, dtype=X.dtype) n_batches = int(np.ceil(float(n_samples) / self.batch_size)) n_iter = int(self.max_iter * n_batches) @@ -1446,7 +1453,7 @@ def partial_fit(self, X, y=None): X = check_array(X, accept_sparse="csr") n_samples, n_features = X.shape if hasattr(self.init, '__array__'): - self.init = np.ascontiguousarray(self.init, dtype=np.float64) + self.init = np.ascontiguousarray(self.init, dtype=X.dtype) if n_samples == 0: return self @@ -1472,10 +1479,10 @@ def partial_fit(self, X, y=None): # reassignment too often, to allow for building up counts random_reassign = self.random_state_.randint( 10 * (1 + self.counts_.min())) == 0 - distances = np.zeros(X.shape[0], dtype=np.float64) + distances = np.zeros(X.shape[0], dtype=X.dtype) _mini_batch_step(X, x_squared_norms, self.cluster_centers_, - self.counts_, np.zeros(0, np.double), 0, + self.counts_, np.zeros(0, dtype=X.dtype), 0, random_reassign=random_reassign, distances=distances, random_state=self.random_state_, reassignment_ratio=self.reassignment_ratio, diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 0b79baeb9e46a..9ee3ef616bf0b 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -16,7 +16,6 @@ from sklearn.utils.testing import assert_less from sklearn.utils.testing import assert_warns from sklearn.utils.testing import if_safe_multiprocessing_with_blas -from sklearn.utils.testing import if_not_mac_os from sklearn.utils.testing import assert_raise_message @@ -64,7 +63,8 @@ def test_elkan_results(): for X in [X_normal, X_blobs]: km_full.fit(X) km_elkan.fit(X) - assert_array_almost_equal(km_elkan.cluster_centers_, km_full.cluster_centers_) + 
assert_array_almost_equal(km_elkan.cluster_centers_, + km_full.cluster_centers_) assert_array_equal(km_elkan.labels_, km_full.labels_) @@ -287,14 +287,18 @@ def test_k_means_explicit_init_shape(): msg = "does not match the number of features of the data" assert_raises_regex(ValueError, msg, km.fit, X) # for callable init - km = Class(n_init=1, init=lambda X_, k, random_state: X_[:, :2], n_clusters=len(X)) + km = Class(n_init=1, + init=lambda X_, k, random_state: X_[:, :2], + n_clusters=len(X)) assert_raises_regex(ValueError, msg, km.fit, X) # mismatch of number of clusters msg = "does not match the number of clusters" km = Class(n_init=1, init=X[:2, :], n_clusters=3) assert_raises_regex(ValueError, msg, km.fit, X) # for callable init - km = Class(n_init=1, init=lambda X_, k, random_state: X_[:2, :], n_clusters=3) + km = Class(n_init=1, + init=lambda X_, k, random_state: X_[:2, :], + n_clusters=3) assert_raises_regex(ValueError, msg, km.fit, X) @@ -640,33 +644,33 @@ def test_predict_minibatch_random_init_sparse_input(): assert_array_equal(mb_k_means.predict(X), mb_k_means.labels_) -def test_input_dtypes(): +def test_int_input(): X_list = [[0, 0], [10, 10], [12, 9], [-1, 1], [2, 0], [8, 10]] - X_int = np.array(X_list, dtype=np.int32) - X_int_csr = sp.csr_matrix(X_int) - init_int = X_int[:2] - - fitted_models = [ - KMeans(n_clusters=2).fit(X_list), - KMeans(n_clusters=2).fit(X_int), - KMeans(n_clusters=2, init=init_int, n_init=1).fit(X_list), - KMeans(n_clusters=2, init=init_int, n_init=1).fit(X_int), - # mini batch kmeans is very unstable on such a small dataset hence - # we use many inits - MiniBatchKMeans(n_clusters=2, n_init=10, batch_size=2).fit(X_list), - MiniBatchKMeans(n_clusters=2, n_init=10, batch_size=2).fit(X_int), - MiniBatchKMeans(n_clusters=2, n_init=10, batch_size=2).fit(X_int_csr), - MiniBatchKMeans(n_clusters=2, batch_size=2, - init=init_int, n_init=1).fit(X_list), - MiniBatchKMeans(n_clusters=2, batch_size=2, - init=init_int, n_init=1).fit(X_int), - MiniBatchKMeans(n_clusters=2, batch_size=2, - init=init_int, n_init=1).fit(X_int_csr), - ] - expected_labels = [0, 1, 1, 0, 0, 1] - scores = np.array([v_measure_score(expected_labels, km.labels_) - for km in fitted_models]) - assert_array_equal(scores, np.ones(scores.shape[0])) + for dtype in [np.int32, np.int64]: + X_int = np.array(X_list, dtype=dtype) + X_int_csr = sp.csr_matrix(X_int) + init_int = X_int[:2] + + fitted_models = [ + KMeans(n_clusters=2).fit(X_int), + KMeans(n_clusters=2, init=init_int, n_init=1).fit(X_int), + # mini batch kmeans is very unstable on such a small dataset hence + # we use many inits + MiniBatchKMeans(n_clusters=2, n_init=10, batch_size=2).fit(X_int), + MiniBatchKMeans(n_clusters=2, n_init=10, batch_size=2).fit(X_int_csr), + MiniBatchKMeans(n_clusters=2, batch_size=2, + init=init_int, n_init=1).fit(X_int), + MiniBatchKMeans(n_clusters=2, batch_size=2, + init=init_int, n_init=1).fit(X_int_csr), + ] + + for km in fitted_models: + assert_equal(km.cluster_centers_.dtype, np.float64) + + expected_labels = [0, 1, 1, 0, 0, 1] + scores = np.array([v_measure_score(expected_labels, km.labels_) + for km in fitted_models]) + assert_array_equal(scores, np.ones(scores.shape[0])) def test_transform(): @@ -771,4 +775,57 @@ def test_x_squared_norms_init_centroids(): def test_max_iter_error(): km = KMeans(max_iter=-1) - assert_raise_message(ValueError, 'Number of iterations should be', km.fit, X) + assert_raise_message(ValueError, 'Number of iterations should be', + km.fit, X) + + +def test_float_precision(): + km = 
KMeans(n_init=1, random_state=30) + mb_km = MiniBatchKMeans(n_init=1, random_state=30) + + inertia = {} + X_new = {} + centers = {} + + for estimator in [km, mb_km]: + for is_sparse in [False, True]: + for dtype in [np.float64, np.float32]: + if is_sparse: + X_test = sp.csr_matrix(X_csr, dtype=dtype) + else: + X_test = dtype(X) + estimator.fit(X_test) + # dtype of cluster centers has to be the dtype of the input data + assert_equal(estimator.cluster_centers_.dtype, dtype) + inertia[dtype] = estimator.inertia_ + X_new[dtype] = estimator.transform(X_test) + centers[dtype] = estimator.cluster_centers_ + # make sure predictions correspond to the correct label + assert_equal(estimator.predict(X_test[0]), estimator.labels_[0]) + if hasattr(estimator, 'partial_fit'): + estimator.partial_fit(X_test[0:3]) + # dtype of cluster centers has to stay the same after partial_fit + assert_equal(estimator.cluster_centers_.dtype, dtype) + + # compare arrays with low precision since the difference between + # 32 and 64 bit sometimes makes a difference up to the 4th decimal place + assert_array_almost_equal(inertia[np.float32], inertia[np.float64], + decimal=4) + assert_array_almost_equal(X_new[np.float32], X_new[np.float64], + decimal=4) + assert_array_almost_equal(centers[np.float32], centers[np.float64], + decimal=4) + + +def test_KMeans_init_centers(): + # This test is used to check KMeans won't mutate the user provided input array silently + # even if input data and init centers have the same type + X_small = np.array([[1.1, 1.1], [-7.5, -7.5], [-1.1, -1.1], [7.5, 7.5]]) + init_centers = np.array([[0.0, 0.0], [5.0, 5.0], [-5.0, -5.0]]) + for dtype in [np.int32, np.int64, np.float32, np.float64]: + X_test = dtype(X_small) + init_centers_test = dtype(init_centers) + assert_array_equal(init_centers, init_centers_test) + km = KMeans(init=init_centers_test, n_clusters=3) + km.fit(X_test) + assert_equal(False, np.may_share_memory(km.cluster_centers_, init_centers)) diff --git a/sklearn/src/cblas/cblas_sdot.c b/sklearn/src/cblas/cblas_sdot.c new file mode 100644 index 0000000000000..e385b4484adce --- /dev/null +++ b/sklearn/src/cblas/cblas_sdot.c @@ -0,0 +1,132 @@ +/* --------------------------------------------------------------------- + * + * -- Automatically Tuned Linear Algebra Software (ATLAS) + * (C) Copyright 2000 All Rights Reserved + * + * -- ATLAS routine -- Version 3.2 -- December 25, 2000 + * + * Author : Antoine P. Petitet + * Originally developed at the University of Tennessee, + * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA. + * + * --------------------------------------------------------------------- + * + * -- Copyright notice and Licensing terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions, and the following disclaimer in + * the documentation and/or other materials provided with the distri- + * bution. + * 3. The name of the University, the ATLAS group, or the names of its + * contributors may not be used to endorse or promote products deri- + * ved from this software without specific written permission. 
+ * + * -- Disclaimer: + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- + * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO- + * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN- + * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * --------------------------------------------------------------------- + */ +/* + * Include files + */ +#include "atlas_refmisc.h" + +float cblas_sdot +( + const int N, + const float * X, + const int INCX, + const float * Y, + const int INCY +) +{ +/* + * Purpose + * ======= + * + * ATL_srefdot returns the dot product x^T * y of two n-vectors x and y. + * + * Arguments + * ========= + * + * N (input) const int + * On entry, N specifies the length of the vector x. N must be + * at least zero. Unchanged on exit. + * + * X (input) const float * + * On entry, X points to the first entry to be accessed of an + * incremented array of size equal to or greater than + * ( 1 + ( n - 1 ) * abs( INCX ) ) * sizeof( float ), + * that contains the vector x. Unchanged on exit. + * + * INCX (input) const int + * On entry, INCX specifies the increment for the elements of X. + * INCX must not be zero. Unchanged on exit. + * + * Y (input) const float * + * On entry, Y points to the first entry to be accessed of an + * incremented array of size equal to or greater than + * ( 1 + ( n - 1 ) * abs( INCY ) ) * sizeof( float ), + * that contains the vector y. Unchanged on exit. + * + * INCY (input) const int + * On entry, INCY specifies the increment for the elements of Y. + * INCY must not be zero. Unchanged on exit. + * + * --------------------------------------------------------------------- + */ +/* + * .. Local Variables .. + */ + register float dot = ATL_sZERO, x0, x1, x2, x3, + y0, y1, y2, y3; + float * StX; + register int i; + int nu; + const int incX2 = 2 * INCX, incY2 = 2 * INCY, + incX3 = 3 * INCX, incY3 = 3 * INCY, + incX4 = 4 * INCX, incY4 = 4 * INCY; +/* .. + * .. Executable Statements .. 
+ * + */ + if( N > 0 ) + { + if( ( nu = ( N >> 2 ) << 2 ) != 0 ) + { + StX = (float *)X + nu * INCX; + + do + { + x0 = (*X); y0 = (*Y); x1 = X[INCX ]; y1 = Y[INCY ]; + x2 = X[incX2]; y2 = Y[incY2]; x3 = X[incX3]; y3 = Y[incY3]; + dot += x0 * y0; dot += x1 * y1; dot += x2 * y2; dot += x3 * y3; + X += incX4; Y += incY4; + } while( X != StX ); + } + + for( i = N - nu; i != 0; i-- ) + { x0 = (*X); y0 = (*Y); dot += x0 * y0; X += INCX; Y += INCY; } + } + return( dot ); +/* + * End of ATL_srefdot + */ +} From 2accd0ce28a550724957023297169ea40d6a8a52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Tue, 21 Jun 2016 15:24:42 +0200 Subject: [PATCH 11/16] Fix sklearn.base.clone for all scipy.sparse formats (#6910) --- sklearn/base.py | 21 ++++++++++++++++----- sklearn/tests/test_base.py | 20 ++++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 8c3a9a8eba4da..f98f4bf7e8cc5 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -21,6 +21,18 @@ class ChangedBehaviorWarning(_ChangedBehaviorWarning): ############################################################################## +def _first_and_last_element(arr): + """Returns first and last element of numpy array or sparse matrix.""" + if isinstance(arr, np.ndarray) or hasattr(arr, 'data'): + # numpy array or sparse matrix with .data attribute + data = arr.data if sparse.issparse(arr) else arr + return data.flat[0], data.flat[-1] + else: + # Sparse matrices without .data attribute. Only dok_matrix at + # the time of writing, in this case indexing is fast + return arr[0, 0], arr[-1, -1] + + def clone(estimator, safe=True): """Constructs a new estimator with the same parameters. @@ -73,9 +85,8 @@ def clone(estimator, safe=True): equality_test = ( param1.shape == param2.shape and param1.dtype == param2.dtype - # We have to use '.flat' for 2D arrays - and param1.flat[0] == param2.flat[0] - and param1.flat[-1] == param2.flat[-1] + and (_first_and_last_element(param1) == + _first_and_last_element(param2)) ) else: equality_test = np.all(param1 == param2) @@ -92,8 +103,8 @@ def clone(estimator, safe=True): else: equality_test = ( param1.__class__ == param2.__class__ - and param1.data[0] == param2.data[0] - and param1.data[-1] == param2.data[-1] + and (_first_and_last_element(param1) == + _first_and_last_element(param2)) and param1.nnz == param2.nnz and param1.shape == param2.shape ) diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index 873808ff914af..6f4be0dcc8ab7 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -1,6 +1,8 @@ # Author: Gael Varoquaux # License: BSD 3 clause +import sys + import numpy as np import scipy.sparse as sp @@ -143,6 +145,24 @@ def test_clone_nan(): assert_true(clf.empty is clf2.empty) +def test_clone_sparse_matrices(): + sparse_matrix_classes = [ + getattr(sp, name) + for name in dir(sp) if name.endswith('_matrix')] + + PY26 = sys.version_info[:2] == (2, 6) + if PY26: + # sp.dok_matrix can not be deepcopied in Python 2.6 + sparse_matrix_classes.remove(sp.dok_matrix) + + for cls in sparse_matrix_classes: + sparse_matrix = cls(np.eye(5)) + clf = MyEstimator(empty=sparse_matrix) + clf_cloned = clone(clf) + assert_true(clf.empty.__class__ is clf_cloned.empty.__class__) + assert_array_equal(clf.empty.toarray(), clf_cloned.empty.toarray()) + + def test_repr(): # Smoke test the repr of the base estimator. 
my_estimator = MyEstimator() From a08a1fdba33d8bba4a7c824454ccf20e64783dd8 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Tue, 21 Jun 2016 23:30:37 +1000 Subject: [PATCH 12/16] DOC If git is not installed, need to catch OSError Fixes #6860 --- doc/sphinxext/github_link.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/sphinxext/github_link.py b/doc/sphinxext/github_link.py index ba0dd434e6ff6..38d0486870456 100644 --- a/doc/sphinxext/github_link.py +++ b/doc/sphinxext/github_link.py @@ -11,7 +11,7 @@ def _get_git_revision(): try: revision = subprocess.check_output(REVISION_CMD.split()).strip() - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, OSError): print('Failed to execute git to get revision') return None return revision.decode('utf-8') From 943836ceffbf9894ee994c7737771103ec63c3ac Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Tue, 21 Jun 2016 23:40:51 +1000 Subject: [PATCH 13/16] DOC add what's new for clone fix --- doc/whats_new.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 1c5621d7eb9cb..3a3ddf932a828 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -279,6 +279,10 @@ Bug fixes (`#6817 `_). By `Tom Dupre la Tour`_. + - Fix a bug where some formats of ``scipy.sparse`` matrix, and estimators + with them as parameters, could not be passed to :func:`base.clone`. + By `Loic Eseve`_. + API changes summary ------------------- From 478614a41658d35b9010a0893d9aa3b4ef251c57 Mon Sep 17 00:00:00 2001 From: Ryan Yu Date: Wed, 22 Jun 2016 03:31:14 -0400 Subject: [PATCH 14/16] fix a typo in ridge.py (#6917) --- sklearn/linear_model/ridge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/ridge.py b/sklearn/linear_model/ridge.py index b3d14bd2e16b2..e49d3b40e8c57 100644 --- a/sklearn/linear_model/ridge.py +++ b/sklearn/linear_model/ridge.py @@ -248,7 +248,7 @@ def ridge_regression(X, y, alpha, sample_weight=None, solver='auto', (possibility to set `tol` and `max_iter`). - 'lsqr' uses the dedicated regularized least-squares routine - scipy.sparse.linalg.lsqr. It is the fatest but may not be available + scipy.sparse.linalg.lsqr. It is the fastest but may not be available in old scipy versions. It also uses an iterative procedure. - 'sag' uses a Stochastic Average Gradient descent. 
It also uses an From 41000d54963da613e85341f8040b2fb587c3b3fd Mon Sep 17 00:00:00 2001 From: John Moeller Date: Wed, 22 Jun 2016 02:10:41 -0600 Subject: [PATCH 15/16] pep8 --- sklearn/preprocessing/data.py | 2 +- sklearn/preprocessing/tests/test_data.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 74099dc51b153..d81c382fa78bd 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1584,7 +1584,7 @@ def transform(self, K, y=None, copy=True): K += self.K_fit_all_ return K - + @property def _pairwise(self): return True diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index e2e080562400d..f35fc274edc2e 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -1373,11 +1373,12 @@ def test_center_kernel(): K_pred_centered2 = centerer.transform(K_pred) assert_array_almost_equal(K_pred_centered, K_pred_centered2) + def test_cv_pipeline_precomputed(): - """Cross-validate a regression on four coplanar points with the same + """Cross-validate a regression on four coplanar points with the same value. Use precomputed kernel to ensure Pipeline with KernelCenterer is treated as a _pairwise operation.""" - X = np.array([[3,0,0],[0,3,0],[0,0,3],[1,1,1]]) + X = np.array([[3, 0, 0], [0, 3, 0], [0, 0, 3], [1, 1, 1]]) y_true = np.ones((4,)) K = X.dot(X.T) kcent = KernelCenterer() @@ -1389,7 +1390,7 @@ def test_cv_pipeline_precomputed(): # test cross-validation, score should be almost perfect # NB: this test is pretty vacuous -- it's mainly to test integration # of Pipeline and KernelCenterer - y_pred = cross_val_predict(pipeline,K,y_true,cv=4) + y_pred = cross_val_predict(pipeline, K, y_true, cv=4) assert_array_almost_equal(y_true, y_pred) From 3dfb2825316694cf4fa4cae01200c1afae177099 Mon Sep 17 00:00:00 2001 From: Gael Varoquaux Date: Wed, 22 Jun 2016 15:39:53 +0200 Subject: [PATCH 16/16] TST: Speed up: cv=2 This is a smoke test. Hence there is no point having cv=4 --- sklearn/preprocessing/tests/test_data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index f35fc274edc2e..5d81a24358d0e 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -10,7 +10,6 @@ import numpy.linalg as la from scipy import sparse from distutils.version import LooseVersion -from sklearn.externals.six import u from sklearn.utils import gen_batches @@ -1390,7 +1389,7 @@ def test_cv_pipeline_precomputed(): # test cross-validation, score should be almost perfect # NB: this test is pretty vacuous -- it's mainly to test integration # of Pipeline and KernelCenterer - y_pred = cross_val_predict(pipeline, K, y_true, cv=4) + y_pred = cross_val_predict(pipeline, K, y_true, cv=2) assert_array_almost_equal(y_true, y_pred)
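
[Note: a minimal standalone sketch, not part of the series, of what the
_pairwise property added in [PATCH 01/16] signals. When an estimator, or a
Pipeline wrapping one, reports _pairwise as True, scikit-learn's
cross-validation utilities treat X as a precomputed kernel and slice it on
both axes instead of selecting rows only; this is what the
test_cv_pipeline_precomputed test above relies on.]

import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(6, 3)
K = X.dot(X.T)                      # precomputed Gram matrix, shape (6, 6)

train = np.array([0, 1, 2, 3])
test = np.array([4, 5])

# pairwise split: rows follow the fold, columns are always the train points
K_train = K[np.ix_(train, train)]   # (4, 4): kernel among training samples
K_test = K[np.ix_(test, train)]     # (2, 4): test samples vs. train samples
assert K_train.shape == (4, 4) and K_test.shape == (2, 4)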
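
[Note: a standalone sketch of the behavior [PATCH 10/16] introduces,
assuming a scikit-learn release that includes the fused-type change (0.18 or
later); earlier versions silently upcast float32 input to float64.]

import numpy as np
from sklearn.cluster import KMeans

rng = np.random.RandomState(42)
X64 = rng.randn(200, 4)
X32 = X64.astype(np.float32)        # half the memory of the float64 copy

km64 = KMeans(n_clusters=3, random_state=0).fit(X64)
km32 = KMeans(n_clusters=3, random_state=0).fit(X32)

# cluster centers now keep the dtype of the input data
assert km64.cluster_centers_.dtype == np.float64
assert km32.cluster_centers_.dtype == np.float32

# as test_float_precision above checks, 32- and 64-bit runs agree only to
# limited precision (differences can appear around the 4th decimal place)
print(np.abs(km64.cluster_centers_ - km32.cluster_centers_).max())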