From 1176d66146d0d24cb5a21a7320c9b64fd7ac6364 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 26 Jan 2019 20:58:31 +0100 Subject: [PATCH 01/23] EHN: add support for non numeric values in MissingIndicator --- sklearn/impute.py | 35 +++++++++++++++++++---------------- sklearn/tests/test_impute.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 17 deletions(-) diff --git a/sklearn/impute.py b/sklearn/impute.py index 349af27eeb91e..9b6ced1074e9e 100644 --- a/sklearn/impute.py +++ b/sklearn/impute.py @@ -533,6 +533,23 @@ def _get_missing_features_info(self, X): return imputer_mask, features_with_missing + def _validate_input(self, X): + if not is_scalar_nan(self.missing_values): + force_all_finite = True + else: + force_all_finite = "allow-nan" + X = check_array(X, accept_sparse=('csc', 'csr'), dtype=None, + force_all_finite=force_all_finite) + _check_inputs_dtype(X, self.missing_values) + if X.dtype.kind not in ("i", "u", "f", "O"): + raise ValueError("Missing indicator does not support data with " + "dtype {0}. Please provide either a numeric array" + " (with a floating point or integer dtype) or " + "categorical data represented either as an array " + "with integer dtype or an array of string values " + "with an object dtype.".format(X.dtype)) + return X + def fit(self, X, y=None): """Fit the transformer on X. @@ -547,14 +564,7 @@ def fit(self, X, y=None): self : object Returns self. """ - if not is_scalar_nan(self.missing_values): - force_all_finite = True - else: - force_all_finite = "allow-nan" - X = check_array(X, accept_sparse=('csc', 'csr'), - force_all_finite=force_all_finite) - _check_inputs_dtype(X, self.missing_values) - + X = self._validate_input(X) self._n_features = X.shape[1] if self.features not in ('missing-only', 'all'): @@ -588,14 +598,7 @@ def transform(self, X): """ check_is_fitted(self, "features_") - - if not is_scalar_nan(self.missing_values): - force_all_finite = True - else: - force_all_finite = "allow-nan" - X = check_array(X, accept_sparse=('csc', 'csr'), - force_all_finite=force_all_finite) - _check_inputs_dtype(X, self.missing_values) + X = self._validate_input(X) if X.shape[1] != self._n_features: raise ValueError("X has a different number of features " diff --git a/sklearn/tests/test_impute.py b/sklearn/tests/test_impute.py index 7131ac3ed0f5f..ba1903bd9e310 100644 --- a/sklearn/tests/test_impute.py +++ b/sklearn/tests/test_impute.py @@ -13,6 +13,7 @@ from sklearn.impute import MissingIndicator from sklearn.impute import SimpleImputer from sklearn.pipeline import Pipeline +from sklearn.pipeline import make_union from sklearn.model_selection import GridSearchCV from sklearn import tree from sklearn.random_projection import sparse_random_matrix @@ -509,7 +510,10 @@ def test_imputation_copy(): "'features' has to be either 'missing-only' or 'all'"), (np.array([[-1, 1], [1, 2]]), np.array([[-1, 1], [1, 2]]), {'features': 'all', 'sparse': 'random'}, - "'sparse' has to be a boolean or 'auto'")] + "'sparse' has to be a boolean or 'auto'"), + (np.array([['a', 'b'], ['c', 'a']], dtype=str), + np.array([['a', 'b'], ['c', 'a']], dtype=str), + {}, "Missing indicator does not support data with dtype")] ) def test_missing_indicator_error(X_fit, X_trans, params, msg_err): indicator = MissingIndicator(missing_values=-1) @@ -614,6 +618,31 @@ def test_missing_indicator_sparse_param(arr_type, missing_values, assert isinstance(X_trans_mask, np.ndarray) +def test_missing_indicator_string(): + X = np.array([['a', 'b', 'c'], ['b', 'c', 'a']], 
dtype=object) + indicator = MissingIndicator(missing_values='a', features='all') + X_trans = indicator.fit_transform(X) + assert_array_equal(X_trans, np.array([[True, False, False], + [False, False, True]])) + + +@pytest.mark.parametrize( + "X, missing_values, X_trans_exp", + [(np.array([['a', 'b'], ['b', 'a']], dtype=object), 'a', + np.array([['b', 'b', True, False], ['b', 'b', False, True]], + dtype=object)), + (np.array([[np.nan, 1.], [1., np.nan]]), np.nan, + np.array([[1., 1., True, False], [1., 1., False, True]]))] +) +def test_missing_indicator_with_imputer(X, missing_values, X_trans_exp): + trans = make_union( + SimpleImputer(missing_values=missing_values, strategy='most_frequent'), + MissingIndicator(missing_values=missing_values) + ) + X_trans = trans.fit_transform(X) + assert_array_equal(X_trans, X_trans_exp) + + @pytest.mark.parametrize("imputer_constructor", [SimpleImputer]) @pytest.mark.parametrize( From 00d7cb520b64d80e08d81fe3f97c3752fd7cc45a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 26 Jan 2019 22:34:58 +0100 Subject: [PATCH 02/23] DOC: add whats new entry --- doc/whats_new/v0.20.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index e9a1f831a7c32..21a466e433b64 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -36,6 +36,15 @@ Changelog threaded when `n_jobs > 1` or `n_jobs = -1`. :issue:`13005` by :user:`Prabakaran Kumaresshan `. +:mod:`sklearn.impute` +..................... + +- |Fix| add support for non-numeric data in + :class:`sklearn.impute.MissingIndicator` which was not supported while + :class:`sklearn.impute.SimpleImputer` was supporting these inputs for some + imputation strategies. + :issue:`13046` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.linear_model` ........................... 
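As a quick illustration of what the two patches above enable, the snippet below exercises MissingIndicator on string data held in an object-dtype array. This is a minimal sketch that assumes scikit-learn is built from this branch (released versions before this change raise a ValueError here); the expected output is the one asserted in the new test_missing_indicator_string test:

    import numpy as np
    from sklearn.impute import MissingIndicator

    # Object-dtype array of strings; 'a' plays the role of the missing value.
    X = np.array([['a', 'b', 'c'], ['b', 'c', 'a']], dtype=object)
    indicator = MissingIndicator(missing_values='a', features='all')
    print(indicator.fit_transform(X))
    # [[ True False False]
    #  [False False  True]]

As the new test_missing_indicator_with_imputer shows, the indicator can also be combined with SimpleImputer through make_union so that the imputed values and the missingness mask are returned side by side.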
From 21eff01b8106c1e80a34dd7afa865c4bfa292e2a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 26 Jan 2019 23:08:34 +0100 Subject: [PATCH 03/23] FIX/TST: estimator supporting string should not raise error with dtype object --- sklearn/utils/estimator_checks.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 69850ecc5f796..eb0214b492ec4 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -74,10 +74,10 @@ 'OrthogonalMatchingPursuit', 'PLSCanonical', 'PLSRegression', 'RANSACRegressor', 'RadiusNeighborsRegressor', 'RandomForestRegressor', 'Ridge', 'RidgeCV'] - ALLOW_NAN = ['Imputer', 'SimpleImputer', 'MissingIndicator', 'MaxAbsScaler', 'MinMaxScaler', 'RobustScaler', 'StandardScaler', 'PowerTransformer', 'QuantileTransformer'] +SUPPORT_STRING = ['SimpleImputer', 'MissingIndicator'] def _yield_non_meta_checks(name, estimator): @@ -625,9 +625,14 @@ def check_dtype_object(name, estimator_orig): if "Unknown label type" not in str(e): raise - X[0, 0] = {'foo': 'bar'} - msg = "argument must be a string or a number" - assert_raises_regex(TypeError, msg, estimator.fit, X, y) + if name not in SUPPORT_STRING: + X[0, 0] = {'foo': 'bar'} + msg = "argument must be a string or a number" + assert_raises_regex(TypeError, msg, estimator.fit, X, y) + else: + # If the estimator support strings passed as object dtype, + # it should not raise an error + estimator.fit(X, y) def check_complex_data(name, estimator_orig): From 42d5b36db2bbccee5583f18dc9de156225ad612d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 28 Jan 2019 11:50:29 +0100 Subject: [PATCH 04/23] DOC: fix the comment on the reason of not raising an error --- sklearn/utils/estimator_checks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index eb0214b492ec4..77c557685aa13 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -630,8 +630,10 @@ def check_dtype_object(name, estimator_orig): msg = "argument must be a string or a number" assert_raises_regex(TypeError, msg, estimator.fit, X, y) else: - # If the estimator support strings passed as object dtype, - # it should not raise an error + # Estimators supporting string will not call np.asarray to convert the + # data to numeric and therefore, the error will not be raised. + # Checking for each element dtype in the input array will be costly. + # Refer to #11401 for full discussion. estimator.fit(X, y) From 507ed3aa51deb963eb44bf74d223f0d4b96275be Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 29 Jan 2019 11:07:22 +0100 Subject: [PATCH 05/23] FIX: change error message --- sklearn/impute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/impute.py b/sklearn/impute.py index 9b6ced1074e9e..e6a9278b47800 100644 --- a/sklearn/impute.py +++ b/sklearn/impute.py @@ -542,7 +542,7 @@ def _validate_input(self, X): force_all_finite=force_all_finite) _check_inputs_dtype(X, self.missing_values) if X.dtype.kind not in ("i", "u", "f", "O"): - raise ValueError("Missing indicator does not support data with " + raise ValueError("MissingIndicator does not support data with " "dtype {0}. 
Please provide either a numeric array" " (with a floating point or integer dtype) or " "categorical data represented either as an array " From eb4765c20916b031a4fc63ab0c8740337a5cd3ff Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 29 Jan 2019 11:28:18 +0100 Subject: [PATCH 06/23] TST: fix the error message match --- sklearn/tests/test_impute.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/tests/test_impute.py b/sklearn/tests/test_impute.py index ba1903bd9e310..f68dbb8e28b30 100644 --- a/sklearn/tests/test_impute.py +++ b/sklearn/tests/test_impute.py @@ -56,7 +56,6 @@ def _check_statistics(X, X_true, err_msg=err_msg.format(True)) assert_ae(X_trans, X_true, err_msg=err_msg.format(True)) - def test_imputation_shape(): # Verify the shapes of the imputed matrix for different strategies. X = np.random.randn(10, 2) @@ -513,7 +512,7 @@ def test_imputation_copy(): "'sparse' has to be a boolean or 'auto'"), (np.array([['a', 'b'], ['c', 'a']], dtype=str), np.array([['a', 'b'], ['c', 'a']], dtype=str), - {}, "Missing indicator does not support data with dtype")] + {}, "MissingIndicator does not support data with dtype")] ) def test_missing_indicator_error(X_fit, X_trans, params, msg_err): indicator = MissingIndicator(missing_values=-1) From d4e777833c69acb19d52b8afa5321c76eff89224 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 30 Jan 2019 17:19:23 +0100 Subject: [PATCH 07/23] Update doc/whats_new/v0.20.rst Co-Authored-By: glemaitre --- doc/whats_new/v0.20.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 21a466e433b64..e8ad73b476a6a 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -41,7 +41,7 @@ Changelog - |Fix| add support for non-numeric data in :class:`sklearn.impute.MissingIndicator` which was not supported while - :class:`sklearn.impute.SimpleImputer` was supporting these inputs for some + :class:`sklearn.impute.SimpleImputer` was supporting this for some imputation strategies. :issue:`13046` by :user:`Guillaume Lemaitre `. From 7fb7599414a6cbd5f8da2555c2c6e1814b63e12d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 30 Jan 2019 17:26:20 +0100 Subject: [PATCH 08/23] TST: additional test with mixed type string/nan/None --- sklearn/tests/test_impute.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/tests/test_impute.py b/sklearn/tests/test_impute.py index f68dbb8e28b30..2f6a4aa4ec6fe 100644 --- a/sklearn/tests/test_impute.py +++ b/sklearn/tests/test_impute.py @@ -56,6 +56,7 @@ def _check_statistics(X, X_true, err_msg=err_msg.format(True)) assert_ae(X_trans, X_true, err_msg=err_msg.format(True)) + def test_imputation_shape(): # Verify the shapes of the imputed matrix for different strategies. 
X = np.random.randn(10, 2) @@ -631,7 +632,13 @@ def test_missing_indicator_string(): np.array([['b', 'b', True, False], ['b', 'b', False, True]], dtype=object)), (np.array([[np.nan, 1.], [1., np.nan]]), np.nan, - np.array([[1., 1., True, False], [1., 1., False, True]]))] + np.array([[1., 1., True, False], [1., 1., False, True]])), + (np.array([[np.nan, 'b'], ['b', np.nan]], dtype=object), np.nan, + np.array([['b', 'b', True, False], ['b', 'b', False, True]], + dtype=object)), + (np.array([[None, 'b'], ['b', None]], dtype=object), None, + np.array([['b', 'b', True, False], ['b', 'b', False, True]], + dtype=object))] ) def test_missing_indicator_with_imputer(X, missing_values, X_trans_exp): trans = make_union( From ec0ab844db75aaeae42115df55d6567d2c645322 Mon Sep 17 00:00:00 2001 From: Thomas Fan Date: Sat, 26 Jan 2019 05:07:04 -0500 Subject: [PATCH 09/23] [MRG] Configure lgtm.yml for CPP (#13044) --- lgtm.yml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 lgtm.yml diff --git a/lgtm.yml b/lgtm.yml new file mode 100644 index 0000000000000..fba6c64f197be --- /dev/null +++ b/lgtm.yml @@ -0,0 +1,7 @@ +extraction: + cpp: + before_index: + - pip3 install numpy scipy Cython + index: + build_command: + - python3 setup.py build_ext -i From 3721a618a1663c1a9f827f09ccb31ca8dea2e2e5 Mon Sep 17 00:00:00 2001 From: Raf Baluyot <7478783+baluyotraf@users.noreply.github.com> Date: Sun, 27 Jan 2019 01:15:18 +0800 Subject: [PATCH 10/23] FIX float16 overflow on accumulator operations in StandardScaler (#13010) --- doc/whats_new/v0.21.rst | 4 +++ sklearn/preprocessing/tests/test_data.py | 25 ++++++++++++++ sklearn/utils/extmath.py | 42 ++++++++++++++++++++---- sklearn/utils/validation.py | 8 +++-- 4 files changed, 70 insertions(+), 9 deletions(-) diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index f7a7b825a3280..3990203c93b4f 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -222,6 +222,10 @@ Support for Python 3.4 and below has been officially dropped. in the dense case. Also added a new parameter ``order`` which controls output order for further speed performances. :issue:`12251` by `Tom Dupre la Tour`_. +- |Fix| Fixed the calculation overflow when using a float16 dtype with + :class:`preprocessing.StandardScaler`. :issue:`13007` by + :user:`Raffaello Baluyot ` + :mod:`sklearn.tree` ................... - |Feature| Decision Trees can now be plotted with matplotlib using diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 1a5ad20d32ef4..b387379be2cee 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -450,6 +450,31 @@ def test_scaler_2d_arrays(): assert X_scaled is not X +def test_scaler_float16_overflow(): + # Test if the scaler will not overflow on float16 numpy arrays + rng = np.random.RandomState(0) + # float16 has a maximum of 65500.0. On the worst case 5 * 200000 is 100000 + # which is enough to overflow the data type + X = rng.uniform(5, 10, [200000, 1]).astype(np.float16) + + with np.errstate(over='raise'): + scaler = StandardScaler().fit(X) + X_scaled = scaler.transform(X) + + # Calculate the float64 equivalent to verify result + X_scaled_f64 = StandardScaler().fit_transform(X.astype(np.float64)) + + # Overflow calculations may cause -inf, inf, or nan. Since there is no nan + # input, all of the outputs should be finite. This may be redundant since a + # FloatingPointError exception will be thrown on overflow above. 
+ assert np.all(np.isfinite(X_scaled)) + + # The normal distribution is very unlikely to go above 4. At 4.0-8.0 the + # float16 precision is 2^-8 which is around 0.004. Thus only 2 decimals are + # checked to account for precision differences. + assert_array_almost_equal(X_scaled, X_scaled_f64, decimal=2) + + def test_handle_zeros_in_scale(): s1 = np.array([0, 1, 2, 3]) s2 = _handle_zeros_in_scale(s1, copy=True) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index fef2c7aff7971..44c6a392d9c6c 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -658,6 +658,38 @@ def make_nonnegative(X, min_value=0): return X +# Use at least float64 for the accumulating functions to avoid precision issue +# see https://github.com/numpy/numpy/issues/9393. The float64 is also retained +# as it is in case the float overflows +def _safe_accumulator_op(op, x, *args, **kwargs): + """ + This function provides numpy accumulator functions with a float64 dtype + when used on a floating point input. This prevents accumulator overflow on + smaller floating point dtypes. + + Parameters + ---------- + op : function + A numpy accumulator function such as np.mean or np.sum + x : numpy array + A numpy array to apply the accumulator function + *args : positional arguments + Positional arguments passed to the accumulator function after the + input x + **kwargs : keyword arguments + Keyword arguments passed to the accumulator function + + Returns + ------- + result : The output of the accumulator function passed to this function + """ + if np.issubdtype(x.dtype, np.floating) and x.dtype.itemsize < 8: + result = op(x, *args, **kwargs, dtype=np.float64) + else: + result = op(x, *args, **kwargs) + return result + + def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count): """Calculate mean update and a Youngs and Cramer variance update. 
@@ -708,12 +740,7 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count): # new = the current increment # updated = the aggregated stats last_sum = last_mean * last_sample_count - if np.issubdtype(X.dtype, np.floating) and X.dtype.itemsize < 8: - # Use at least float64 for the accumulator to avoid precision issues; - # see https://github.com/numpy/numpy/issues/9393 - new_sum = np.nansum(X, axis=0, dtype=np.float64).astype(X.dtype) - else: - new_sum = np.nansum(X, axis=0) + new_sum = _safe_accumulator_op(np.nansum, X, axis=0) new_sample_count = np.sum(~np.isnan(X), axis=0) updated_sample_count = last_sample_count + new_sample_count @@ -723,7 +750,8 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count): if last_variance is None: updated_variance = None else: - new_unnormalized_variance = np.nanvar(X, axis=0) * new_sample_count + new_unnormalized_variance = ( + _safe_accumulator_op(np.nanvar, X, axis=0) * new_sample_count) last_unnormalized_variance = last_variance * last_sample_count with np.errstate(divide='ignore', invalid='ignore'): diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index fc882e0719a8d..379aa738f7124 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -34,14 +34,18 @@ def _assert_all_finite(X, allow_nan=False): """Like assert_all_finite, but only for ndarray.""" + # validation is also imported in extmath + from .extmath import _safe_accumulator_op + if _get_config()['assume_finite']: return X = np.asanyarray(X) # First try an O(n) time, O(1) space solution for the common case that # everything is finite; fall back to O(n) space np.isfinite to prevent - # false positives from overflow in sum method. + # false positives from overflow in sum method. The sum is also calculated + # safely to reduce dtype induced overflows. is_float = X.dtype.kind in 'fc' - if is_float and np.isfinite(X.sum()): + if is_float and (np.isfinite(_safe_accumulator_op(np.sum, X))): pass elif is_float: msg_err = "Input contains {} or a value too large for {!r}." From f38c8d1eba07335bb59c52e62570c7ca726d2d0a Mon Sep 17 00:00:00 2001 From: xhan <44006219+xhan7279@users.noreply.github.com> Date: Sat, 26 Jan 2019 19:17:58 -0500 Subject: [PATCH 11/23] TST Use random state to initialize MLPClassifier. 
(#12892) --- sklearn/neural_network/tests/test_mlp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/neural_network/tests/test_mlp.py b/sklearn/neural_network/tests/test_mlp.py index 130ec6554440b..36d2bc5db3077 100644 --- a/sklearn/neural_network/tests/test_mlp.py +++ b/sklearn/neural_network/tests/test_mlp.py @@ -432,7 +432,8 @@ def test_predict_proba_binary(): X = X_digits_binary[:50] y = y_digits_binary[:50] - clf = MLPClassifier(hidden_layer_sizes=5) + clf = MLPClassifier(hidden_layer_sizes=5, activation='logistic', + random_state=1) with ignore_warnings(category=ConvergenceWarning): clf.fit(X, y) y_proba = clf.predict_proba(X) From 4078baa05eb00e5e71eb12f15b2af5cbf972ea3c Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Sun, 27 Jan 2019 16:16:42 +0800 Subject: [PATCH 12/23] API Deprecate externals.six (#12916) --- benchmarks/bench_plot_fastkmeans.py | 2 -- benchmarks/bench_plot_omp_lars.py | 2 -- benchmarks/bench_plot_svd.py | 2 -- doc/whats_new/v0.21.rst | 6 ++++++ sklearn/externals/six.py | 6 ++++++ 5 files changed, 12 insertions(+), 6 deletions(-) diff --git a/benchmarks/bench_plot_fastkmeans.py b/benchmarks/bench_plot_fastkmeans.py index d40d211dd1846..8f6b69b618b4b 100644 --- a/benchmarks/bench_plot_fastkmeans.py +++ b/benchmarks/bench_plot_fastkmeans.py @@ -3,8 +3,6 @@ from collections import defaultdict from time import time -import six - import numpy as np from numpy import random as nr diff --git a/benchmarks/bench_plot_omp_lars.py b/benchmarks/bench_plot_omp_lars.py index a9cc87e9d22f8..7c0d6dbfb4758 100644 --- a/benchmarks/bench_plot_omp_lars.py +++ b/benchmarks/bench_plot_omp_lars.py @@ -9,8 +9,6 @@ import sys from time import time -import six - import numpy as np from sklearn.linear_model import lars_path, orthogonal_mp diff --git a/benchmarks/bench_plot_svd.py b/benchmarks/bench_plot_svd.py index 7f96696a33c51..746c0df989e90 100644 --- a/benchmarks/bench_plot_svd.py +++ b/benchmarks/bench_plot_svd.py @@ -7,8 +7,6 @@ import numpy as np from collections import defaultdict -import six - from scipy.linalg import svd from sklearn.utils.extmath import randomized_svd from sklearn.datasets.samples_generator import make_low_rank_matrix diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index 3990203c93b4f..91e910e541b06 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -91,6 +91,12 @@ Support for Python 3.4 and below has been officially dropped. with the document and the caller functions. :issue:`6463` by :user:`movelikeriver `. +:mod:`sklearn.externals` +........................ + +- |API| Deprecated :mod:`externals.six` since we have dropped support for + Python 2.7. :issue:`12916` by :user:`Hanmin Qin `. + :mod:`sklearn.linear_model` ........................... diff --git a/sklearn/externals/six.py b/sklearn/externals/six.py index 85898ec71275f..cb5a46751f446 100644 --- a/sklearn/externals/six.py +++ b/sklearn/externals/six.py @@ -24,6 +24,12 @@ import sys import types +import warnings +warnings.warn("The module is deprecated in version 0.21 and will be removed " + "in version 0.23 since we've dropped support for Python 2.7. 
" + "Please rely on the official version of six " + "(https://pypi.org/project/six/).", DeprecationWarning) + __author__ = "Benjamin Peterson " __version__ = "1.4.1" From a0a99d8ea4f1f4d33eeb7cb7bd64140abfd5e379 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Sun, 27 Jan 2019 17:27:20 +0800 Subject: [PATCH 13/23] DOC Remove outdated doc in KBinsDiscretizer (#13047) --- sklearn/preprocessing/_discretization.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/sklearn/preprocessing/_discretization.py b/sklearn/preprocessing/_discretization.py index b57e03230f4f1..cba2ebb9bd5a2 100644 --- a/sklearn/preprocessing/_discretization.py +++ b/sklearn/preprocessing/_discretization.py @@ -27,12 +27,7 @@ class KBinsDiscretizer(BaseEstimator, TransformerMixin): Parameters ---------- n_bins : int or array-like, shape (n_features,) (default=5) - The number of bins to produce. The intervals for the bins are - determined by the minimum and maximum of the input data. - Raises ValueError if ``n_bins < 2``. - - If ``n_bins`` is an array, and there is an ignored feature at - index ``i``, ``n_bins[i]`` will be ignored. + The number of bins to produce. Raises ValueError if ``n_bins < 2``. encode : {'onehot', 'onehot-dense', 'ordinal'}, (default='onehot') Method used to encode the transformed result. @@ -62,8 +57,7 @@ class KBinsDiscretizer(BaseEstimator, TransformerMixin): Attributes ---------- n_bins_ : int array, shape (n_features,) - Number of bins per feature. An ignored feature at index ``i`` - will have ``n_bins_[i] == 0``. + Number of bins per feature. bin_edges_ : array of arrays, shape (n_features, ) The edges of each bin. Contain arrays of varying shapes ``(n_bins_, )`` From e92eb1f9325230e166200e0744e340484d00fa79 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Sun, 27 Jan 2019 18:05:10 +0800 Subject: [PATCH 14/23] DOC Remove outdated doc in KBinsDiscretizer See #13074 --- sklearn/preprocessing/_discretization.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/preprocessing/_discretization.py b/sklearn/preprocessing/_discretization.py index cba2ebb9bd5a2..8c1047308b107 100644 --- a/sklearn/preprocessing/_discretization.py +++ b/sklearn/preprocessing/_discretization.py @@ -191,8 +191,6 @@ def fit(self, X, y=None): def _validate_n_bins(self, n_features): """Returns n_bins_, the number of bins per feature. - - Also ensures that ignored bins are zero. """ orig_bins = self.n_bins if isinstance(orig_bins, numbers.Number): From dd73bab06319f33b49d0f94a23b6ec3d34813a8b Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Mon, 28 Jan 2019 09:57:38 +0800 Subject: [PATCH 15/23] EXA Improve example plot_svm_anova.py (#11731) --- examples/svm/plot_svm_anova.py | 36 +++++++++++++++------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/examples/svm/plot_svm_anova.py b/examples/svm/plot_svm_anova.py index 08f9fddf71db5..84b6056991a84 100644 --- a/examples/svm/plot_svm_anova.py +++ b/examples/svm/plot_svm_anova.py @@ -4,37 +4,35 @@ ================================================= This example shows how to perform univariate feature selection before running a -SVC (support vector classifier) to improve the classification scores. +SVC (support vector classifier) to improve the classification scores. We use +the iris dataset (4 features) and add 36 non-informative features. We can find +that our model achieves best performance when we select around 10% of features. 
""" print(__doc__) import numpy as np import matplotlib.pyplot as plt -from sklearn.datasets import load_digits +from sklearn.datasets import load_iris from sklearn.feature_selection import SelectPercentile, chi2 from sklearn.model_selection import cross_val_score from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC # ############################################################################# # Import some data to play with -X, y = load_digits(return_X_y=True) -# Throw away data, to be in the curse of dimension settings -X = X[:200] -y = y[:200] -n_samples = len(y) -X = X.reshape((n_samples, -1)) -# add 200 non-informative features -X = np.hstack((X, 2 * np.random.random((n_samples, 200)))) +X, y = load_iris(return_X_y=True) +# Add non-informative features +np.random.seed(0) +X = np.hstack((X, 2 * np.random.random((X.shape[0], 36)))) # ############################################################################# -# Create a feature-selection transform and an instance of SVM that we +# Create a feature-selection transform, a scaler and an instance of SVM that we # combine together to have an full-blown estimator - -transform = SelectPercentile(chi2) - -clf = Pipeline([('anova', transform), ('svc', SVC(gamma="auto"))]) +clf = Pipeline([('anova', SelectPercentile(chi2)), + ('scaler', StandardScaler()), + ('svc', SVC(gamma="auto"))]) # ############################################################################# # Plot the cross-validation score as a function of percentile of features @@ -44,17 +42,15 @@ for percentile in percentiles: clf.set_params(anova__percentile=percentile) - # Compute cross-validation score using 1 CPU - this_scores = cross_val_score(clf, X, y, cv=5, n_jobs=1) + this_scores = cross_val_score(clf, X, y, cv=5) score_means.append(this_scores.mean()) score_stds.append(this_scores.std()) plt.errorbar(percentiles, score_means, np.array(score_stds)) - plt.title( 'Performance of the SVM-Anova varying the percentile of features selected') +plt.xticks(np.linspace(0, 100, 11, endpoint=True)) plt.xlabel('Percentile') -plt.ylabel('Prediction rate') - +plt.ylabel('Accuracy Score') plt.axis('tight') plt.show() From de9630fdb572867615d1307b7b2e5b2bb7388b98 Mon Sep 17 00:00:00 2001 From: Vishaal Kapoor <40836875+vishaalkapoor@users.noreply.github.com> Date: Mon, 28 Jan 2019 22:50:41 -0800 Subject: [PATCH 16/23] DOC Correct TF-IDF formula in TfidfTransformer comments. (#13054) --- doc/modules/feature_extraction.rst | 29 +++++++++++++++-------------- sklearn/feature_extraction/text.py | 23 ++++++++++++----------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/doc/modules/feature_extraction.rst b/doc/modules/feature_extraction.rst index 2f506dcf7be07..9dfcaa4549f08 100644 --- a/doc/modules/feature_extraction.rst +++ b/doc/modules/feature_extraction.rst @@ -436,11 +436,12 @@ Using the ``TfidfTransformer``'s default settings, the term frequency, the number of times a term occurs in a given document, is multiplied with idf component, which is computed as -:math:`\text{idf}(t) = log{\frac{1 + n_d}{1+\text{df}(d,t)}} + 1`, +:math:`\text{idf}(t) = \log{\frac{1 + n}{1+\text{df}(t)}} + 1`, -where :math:`n_d` is the total number of documents, and :math:`\text{df}(d,t)` -is the number of documents that contain term :math:`t`. 
The resulting tf-idf -vectors are then normalized by the Euclidean norm: +where :math:`n` is the total number of documents in the document set, and +:math:`\text{df}(t)` is the number of documents in the document set that +contain term :math:`t`. The resulting tf-idf vectors are then normalized by the +Euclidean norm: :math:`v_{norm} = \frac{v}{||v||_2} = \frac{v}{\sqrt{v{_1}^2 + v{_2}^2 + \dots + v{_n}^2}}`. @@ -455,14 +456,14 @@ computed in scikit-learn's :class:`TfidfTransformer` and :class:`TfidfVectorizer` differ slightly from the standard textbook notation that defines the idf as -:math:`\text{idf}(t) = log{\frac{n_d}{1+\text{df}(d,t)}}.` +:math:`\text{idf}(t) = \log{\frac{n}{1+\text{df}(t)}}.` In the :class:`TfidfTransformer` and :class:`TfidfVectorizer` with ``smooth_idf=False``, the "1" count is added to the idf instead of the idf's denominator: -:math:`\text{idf}(t) = log{\frac{n_d}{\text{df}(d,t)}} + 1` +:math:`\text{idf}(t) = \log{\frac{n}{\text{df}(t)}} + 1` This normalization is implemented by the :class:`TfidfTransformer` class:: @@ -509,21 +510,21 @@ v{_2}^2 + \dots + v{_n}^2}}` For example, we can compute the tf-idf of the first term in the first document in the `counts` array as follows: -:math:`n_{d} = 6` +:math:`n = 6` -:math:`\text{df}(d, t)_{\text{term1}} = 6` +:math:`\text{df}(t)_{\text{term1}} = 6` -:math:`\text{idf}(d, t)_{\text{term1}} = -log \frac{n_d}{\text{df}(d, t)} + 1 = log(1)+1 = 1` +:math:`\text{idf}(t)_{\text{term1}} = +\log \frac{n}{\text{df}(t)} + 1 = \log(1)+1 = 1` :math:`\text{tf-idf}_{\text{term1}} = \text{tf} \times \text{idf} = 3 \times 1 = 3` Now, if we repeat this computation for the remaining 2 terms in the document, we get -:math:`\text{tf-idf}_{\text{term2}} = 0 \times (log(6/1)+1) = 0` +:math:`\text{tf-idf}_{\text{term2}} = 0 \times (\log(6/1)+1) = 0` -:math:`\text{tf-idf}_{\text{term3}} = 1 \times (log(6/2)+1) \approx 2.0986` +:math:`\text{tf-idf}_{\text{term3}} = 1 \times (\log(6/2)+1) \approx 2.0986` and the vector of raw tf-idfs: @@ -540,12 +541,12 @@ Furthermore, the default parameter ``smooth_idf=True`` adds "1" to the numerator and denominator as if an extra document was seen containing every term in the collection exactly once, which prevents zero divisions: -:math:`\text{idf}(t) = log{\frac{1 + n_d}{1+\text{df}(d,t)}} + 1` +:math:`\text{idf}(t) = \log{\frac{1 + n}{1+\text{df}(t)}} + 1` Using this modification, the tf-idf of the third term in document 1 changes to 1.8473: -:math:`\text{tf-idf}_{\text{term3}} = 1 \times log(7/3)+1 \approx 1.8473` +:math:`\text{tf-idf}_{\text{term3}} = 1 \times \log(7/3)+1 \approx 1.8473` And the L2-normalized tf-idf changes to diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index d705a060e7588..d06e4c7fd483e 100644 --- a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py @@ -1146,17 +1146,18 @@ class TfidfTransformer(BaseEstimator, TransformerMixin): informative than features that occur in a small fraction of the training corpus. - The formula that is used to compute the tf-idf of term t is - tf-idf(d, t) = tf(t) * idf(d, t), and the idf is computed as - idf(d, t) = log [ n / df(d, t) ] + 1 (if ``smooth_idf=False``), - where n is the total number of documents and df(d, t) is the - document frequency; the document frequency is the number of documents d - that contain term t. 
The effect of adding "1" to the idf in the equation - above is that terms with zero idf, i.e., terms that occur in all documents - in a training set, will not be entirely ignored. - (Note that the idf formula above differs from the standard - textbook notation that defines the idf as - idf(d, t) = log [ n / (df(d, t) + 1) ]). + The formula that is used to compute the tf-idf for a term t of a document d + in a document set is tf-idf(t, d) = tf(t, d) * idf(t), and the idf is + computed as idf(t) = log [ n / df(t) ] + 1 (if ``smooth_idf=False``), where + n is the total number of documents in the document set and df(t) is the + document frequency of t; the document frequency is the number of documents + in the document set that contain the term t. The effect of adding "1" to + the idf in the equation above is that terms with zero idf, i.e., terms + that occur in all documents in a training set, will not be entirely + ignored. + (Note that the idf formula above differs from the standard textbook + notation that defines the idf as + idf(t) = log [ n / (df(t) + 1) ]). If ``smooth_idf=True`` (the default), the constant "1" is added to the numerator and denominator of the idf as if an extra document was seen From 1d8c5345a35f4c26ad2c46b6024873a104825b70 Mon Sep 17 00:00:00 2001 From: Gabriel Vacaliuc Date: Tue, 29 Jan 2019 18:06:05 -0600 Subject: [PATCH 17/23] FIX an issue w/ large sparse matrix indices in CountVectorizer (#11295) --- doc/whats_new/v0.20.rst | 8 +++++ sklearn/feature_extraction/tests/test_text.py | 32 ++++++++++++++++++- sklearn/feature_extraction/text.py | 13 ++++---- 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index e8ad73b476a6a..3483b173dcb16 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -60,6 +60,14 @@ Changelog combination with ``handle_unknown='ignore'``. :issue:`12881` by `Joris Van den Bossche`_. +:mod:`sklearn.feature_extraction.text` +...................................... + +- |Fix| Fixed a bug in :class:`feature_extraction.text.CountVectorizer` which + would result in the sparse feature matrix having conflicting `indptr` and + `indices` precisions under very large vocabularies. :issue:`11295` by + :user:`Gabriel Vacaliuc `. + .. _changes_0_20_2: Version 0.20.2 diff --git a/sklearn/feature_extraction/tests/test_text.py b/sklearn/feature_extraction/tests/test_text.py index 004f771126724..738ef67c60897 100644 --- a/sklearn/feature_extraction/tests/test_text.py +++ b/sklearn/feature_extraction/tests/test_text.py @@ -36,7 +36,8 @@ assert_warns_message, assert_raise_message, clean_warning_registry, ignore_warnings, SkipTest, assert_raises, assert_no_warnings, - fails_if_pypy, assert_allclose_dense_sparse) + fails_if_pypy, assert_allclose_dense_sparse, + skip_if_32bit) from collections import defaultdict from functools import partial import pickle @@ -1144,6 +1145,35 @@ def test_vectorizer_stop_words_inconsistent(): ['hello world']) +@skip_if_32bit +def test_countvectorizer_sort_features_64bit_sparse_indices(): + """ + Check that CountVectorizer._sort_features preserves the dtype of its sparse + feature matrix. + + This test is skipped on 32bit platforms, see: + https://github.com/scikit-learn/scikit-learn/pull/11295 + for more details. + """ + + X = sparse.csr_matrix((5, 5), dtype=np.int64) + + # force indices and indptr to int64. 
+ INDICES_DTYPE = np.int64 + X.indices = X.indices.astype(INDICES_DTYPE) + X.indptr = X.indptr.astype(INDICES_DTYPE) + + vocabulary = { + "scikit-learn": 0, + "is": 1, + "great!": 2 + } + + Xs = CountVectorizer()._sort_features(X, vocabulary) + + assert INDICES_DTYPE == Xs.indices.dtype + + @fails_if_pypy @pytest.mark.parametrize('Estimator', [CountVectorizer, TfidfVectorizer, HashingVectorizer]) diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index d06e4c7fd483e..788ba0596e8cc 100644 --- a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py @@ -31,6 +31,7 @@ from .stop_words import ENGLISH_STOP_WORDS from ..utils.validation import check_is_fitted, check_array, FLOAT_DTYPES from ..utils.fixes import sp_version +from ..utils import _IS_32BIT __all__ = ['HashingVectorizer', @@ -871,7 +872,7 @@ def _sort_features(self, X, vocabulary): Returns a reordered matrix and modifies the vocabulary in place """ sorted_features = sorted(vocabulary.items()) - map_index = np.empty(len(sorted_features), dtype=np.int32) + map_index = np.empty(len(sorted_features), dtype=X.indices.dtype) for new_val, (term, old_val) in enumerate(sorted_features): vocabulary[term] = new_val map_index[old_val] = new_val @@ -961,14 +962,12 @@ def _count_vocab(self, raw_documents, fixed_vocab): " contain stop words") if indptr[-1] > 2147483648: # = 2**31 - 1 - if sp_version >= (0, 14): - indices_dtype = np.int64 - else: + if _IS_32BIT: raise ValueError(('sparse CSR array has {} non-zero ' 'elements and requires 64 bit indexing, ' - ' which is unsupported with scipy {}. ' - 'Please upgrade to scipy >=0.14') - .format(indptr[-1], '.'.join(sp_version))) + 'which is unsupported with 32 bit Python.') + .format(indptr[-1])) + indices_dtype = np.int64 else: indices_dtype = np.int32 From f4832474c32fb920f026c4af58adf3fc6f62e5cf Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Wed, 30 Jan 2019 08:08:41 +0800 Subject: [PATCH 18/23] DOC More details about the attributes in MinMaxScaler (#13029) --- sklearn/preprocessing/data.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 88a72946ab6b2..d8a7d58b26e3b 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -210,6 +210,11 @@ class MinMaxScaler(BaseEstimator, TransformerMixin): where min, max = feature_range. + The transformation is calculated as:: + + X_scaled = scale * X + min - X.min(axis=0) * scale + where scale = (max - min) / (X.max(axis=0) - X.min(axis=0)) + This transformation is often used as an alternative to zero mean, unit variance scaling. @@ -227,10 +232,12 @@ class MinMaxScaler(BaseEstimator, TransformerMixin): Attributes ---------- min_ : ndarray, shape (n_features,) - Per feature adjustment for minimum. + Per feature adjustment for minimum. Equivalent to + ``min - X.min(axis=0) * self.scale_`` scale_ : ndarray, shape (n_features,) - Per feature relative scaling of the data. + Per feature relative scaling of the data. Equivalent to + ``(max - min) / (X.max(axis=0) - X.min(axis=0))`` .. versionadded:: 0.17 *scale_* attribute. @@ -409,12 +416,17 @@ def minmax_scale(X, feature_range=(0, 1), axis=0, copy=True): that it is in the given range on the training set, i.e. between zero and one. 
- The transformation is given by:: + The transformation is given by (when ``axis=0``):: X_std = (X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0)) X_scaled = X_std * (max - min) + min where min, max = feature_range. + + The transformation is calculated as (when ``axis=0``):: + + X_scaled = scale * X + min - X.min(axis=0) * scale + where scale = (max - min) / (X.max(axis=0) - X.min(axis=0)) This transformation is often used as an alternative to zero mean, unit variance scaling. From 42ed1421cf1b3d2b8190ad2fb7e2c55334a75929 Mon Sep 17 00:00:00 2001 From: jeremiedbb <34657725+jeremiedbb@users.noreply.github.com> Date: Wed, 30 Jan 2019 01:10:49 +0100 Subject: [PATCH 19/23] DOC Clean up the advanced installation doc to remove python < 3.5 parts (#13064) --- doc/developers/advanced_installation.rst | 70 +++--------------------- doc/modules/computing.rst | 18 ------ 2 files changed, 7 insertions(+), 81 deletions(-) diff --git a/doc/developers/advanced_installation.rst b/doc/developers/advanced_installation.rst index b3647f83b92af..8f6f8496c3606 100644 --- a/doc/developers/advanced_installation.rst +++ b/doc/developers/advanced_installation.rst @@ -35,8 +35,8 @@ Building from source Scikit-learn requires: - Python (>= 3.5), -- NumPy (>= 1.8.2), -- SciPy (>= 0.13.3). +- NumPy (>= 1.11), +- SciPy (>= 0.17). .. note:: @@ -46,7 +46,7 @@ Scikit-learn requires: Building Scikit-learn also requires -- Cython >=0.23 +- Cython >=0.28.5 Running tests requires @@ -165,25 +165,16 @@ Windows To build scikit-learn on Windows you need a working C/C++ compiler in addition to numpy, scipy and setuptools. -Picking the right compiler depends on the version of Python (2 or 3) -and the architecture of the Python interpreter, 32-bit or 64-bit. -You can check the Python version by running the following in ``cmd`` or -``powershell`` console:: - - python --version - -and the architecture with:: +The building command depends on the architecture of the Python interpreter, +32-bit or 64-bit. You can check the architecture by running the following in +``cmd`` or ``powershell`` console:: python -c "import struct; print(struct.calcsize('P') * 8)" The above commands assume that you have the Python installation folder in your PATH environment variable. - -Python >= 3.5 -------------- - -For Python versions as of 3.5, you need `Build Tools for Visual Studio 2017 +You will need `Build Tools for Visual Studio 2017 `_. For 64-bit Python, configure the build environment with:: @@ -198,53 +189,6 @@ And build scikit-learn from this environment:: Replace ``x64`` by ``x86`` to build for 32-bit Python. -32-bit Python (<= 3.4) ----------------------- - -For 32-bit Python versions up to 3.4 use Microsoft Visual C++ Express 2010. - -Once installed you should be able to build scikit-learn without any -particular configuration by running the following command in the scikit-learn -folder:: - - python setup.py install - - -64-bit Python (<= 3.4) ----------------------- - -For 64-bit Python versions up to 3.4, you either need the full Visual Studio or -the free Windows SDKs that can be downloaded from the links below. - -The Windows SDKs include the MSVC compilers both for 32 and 64-bit -architectures. They come as a ``GRMSDKX_EN_DVD.iso`` file that can be mounted -as a new drive with a ``setup.exe`` installer in it. - -- For Python you need SDK **v7.1**: `MS Windows SDK for Windows 7 and .NET - Framework 4 - `_ - -Both SDKs can be installed in parallel on the same host. 
To use the Windows -SDKs, you need to setup the environment of a ``cmd`` console launched with the -following flags :: - - cmd /E:ON /V:ON /K - -Then configure the build environment with:: - - SET DISTUTILS_USE_SDK=1 - SET MSSdk=1 - "C:\Program Files\Microsoft SDKs\Windows\v7.1\Setup\WindowsSdkVer.exe" -q -version:v7.1 - "C:\Program Files\Microsoft SDKs\Windows\v7.1\Bin\SetEnv.cmd" /x64 /release - -Finally you can build scikit-learn in the same ``cmd`` console:: - - python setup.py install - -Replace ``/x64`` by ``/x86`` to build for 32-bit Python instead of 64-bit -Python. - - Building binary packages and installers --------------------------------------- diff --git a/doc/modules/computing.rst b/doc/modules/computing.rst index 8f6b32850bde2..d25a339c5b77a 100644 --- a/doc/modules/computing.rst +++ b/doc/modules/computing.rst @@ -430,24 +430,6 @@ and in this from Daniel Nouri which has some nice step by step install instructions for Debian / Ubuntu. -.. warning:: - - Multithreaded BLAS libraries sometimes conflict with Python's - ``multiprocessing`` module, which is used by e.g. ``GridSearchCV`` and - most other estimators that take an ``n_jobs`` argument (with the exception - of ``SGDClassifier``, ``SGDRegressor``, ``Perceptron``, - ``PassiveAggressiveClassifier`` and tree-based methods such as random - forests). This is true of Apple's Accelerate and OpenBLAS when built with - OpenMP support. - - Besides scikit-learn, NumPy and SciPy also use BLAS internally, as - explained earlier. - - If you experience hanging subprocesses with ``n_jobs>1`` or ``n_jobs=-1``, - make sure you have a single-threaded BLAS library, or set ``n_jobs=1``, - or upgrade to Python 3.4 which has a new version of ``multiprocessing`` - that should be immune to this problem. - .. _working_memory: Limiting Working Memory From 9445eb11acbc8794368f9bf812d534797409660a Mon Sep 17 00:00:00 2001 From: "Zijie (ZJ) Poh" <8103276+zjpoh@users.noreply.github.com> Date: Tue, 29 Jan 2019 16:12:21 -0800 Subject: [PATCH 20/23] API NMF and non_negative_factorization have inconsistent default init (#12989) --- doc/whats_new/v0.21.rst | 10 +++++++++ sklearn/decomposition/nmf.py | 27 ++++++++++++++++++++----- sklearn/decomposition/tests/test_nmf.py | 14 ++++++++----- 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index 91e910e541b06..e1e3d32a67fd4 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -58,6 +58,16 @@ Support for Python 3.4 and below has been officially dropped. :class:`datasets.svmlight_format` :issue:`10727` by :user:`Bryan K Woods `. +:mod:`sklearn.decomposition` +............................ + +- |API| The default value of the :code:`init` argument in + :func:`decomposition.non_negative_factorization` will change from + :code:`random` to :code:`None` in version 0.23 to make it consistent with + :class:`decomposition.NMF`. A FutureWarning is raised when + the default value is used. + :issue:`12988` by :user:`Zijie (ZJ) Poh `. + :mod:`sklearn.discriminant_analysis` .................................... diff --git a/sklearn/decomposition/nmf.py b/sklearn/decomposition/nmf.py index 0617a1797fcdc..b6586c493ed85 100644 --- a/sklearn/decomposition/nmf.py +++ b/sklearn/decomposition/nmf.py @@ -261,9 +261,11 @@ def _initialize_nmf(X, n_components, init=None, eps=1e-6, init : None | 'random' | 'nndsvd' | 'nndsvda' | 'nndsvdar' Method used to initialize the procedure. - Default: 'nndsvd' if n_components < n_features, otherwise 'random'. 
+ Default: None. Valid options: + - None: 'nndsvd' if n_components < n_features, otherwise 'random'. + - 'random': non-negative random matrices, scaled with: sqrt(X.mean() / n_components) @@ -831,7 +833,7 @@ def _fit_multiplicative_update(X, W, H, beta_loss='frobenius', def non_negative_factorization(X, W=None, H=None, n_components=None, - init='random', update_H=True, solver='cd', + init='warn', update_H=True, solver='cd', beta_loss='frobenius', tol=1e-4, max_iter=200, alpha=0., l1_ratio=0., regularization=None, random_state=None, @@ -878,11 +880,17 @@ def non_negative_factorization(X, W=None, H=None, n_components=None, Number of components, if n_components is not set all features are kept. - init : None | 'random' | 'nndsvd' | 'nndsvda' | 'nndsvdar' | 'custom' + init : None | 'random' | 'nndsvd' | 'nndsvda' | 'nndsvdar' | 'custom' Method used to initialize the procedure. Default: 'random'. + + The default value will change from 'random' to None in version 0.23 + to make it consistent with decomposition.NMF. + Valid options: + - None: 'nndsvd' if n_components < n_features, otherwise 'random'. + - 'random': non-negative random matrices, scaled with: sqrt(X.mean() / n_components) @@ -1009,6 +1017,13 @@ def non_negative_factorization(X, W=None, H=None, n_components=None, raise ValueError("Tolerance for stopping criteria must be " "positive; got (tol=%r)" % tol) + if init == "warn": + if n_components < n_features: + warnings.warn("The default value of init will change from " + "random to None in 0.23 to make it consistent " + "with decomposition.NMF.", FutureWarning) + init = "random" + # check W and H, or initialize them if init == 'custom' and update_H: _check_init(H, (n_components, n_features), "NMF (input H)") @@ -1087,11 +1102,13 @@ class NMF(BaseEstimator, TransformerMixin): Number of components, if n_components is not set all features are kept. - init : 'random' | 'nndsvd' | 'nndsvda' | 'nndsvdar' | 'custom' + init : None | 'random' | 'nndsvd' | 'nndsvda' | 'nndsvdar' | 'custom' Method used to initialize the procedure. - Default: 'nndsvd' if n_components < n_features, otherwise random. + Default: None. Valid options: + - None: 'nndsvd' if n_components < n_features, otherwise random. 
+ - 'random': non-negative random matrices, scaled with: sqrt(X.mean() / n_components) diff --git a/sklearn/decomposition/tests/test_nmf.py b/sklearn/decomposition/tests/test_nmf.py index 49e8b46676aec..cc1f44296e03f 100644 --- a/sklearn/decomposition/tests/test_nmf.py +++ b/sklearn/decomposition/tests/test_nmf.py @@ -10,6 +10,7 @@ import pytest from sklearn.utils.testing import assert_raise_message, assert_no_warnings +from sklearn.utils.testing import assert_warns_message from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_almost_equal @@ -213,13 +214,16 @@ def test_non_negative_factorization_checking(): A = np.ones((2, 2)) # Test parameters checking is public function nnmf = non_negative_factorization - assert_no_warnings(nnmf, A, A, A, np.int64(1)) + msg = ("The default value of init will change from " + "random to None in 0.23 to make it consistent " + "with decomposition.NMF.") + assert_warns_message(FutureWarning, msg, nnmf, A, A, A, np.int64(1)) msg = ("Number of components must be a positive integer; " "got (n_components=1.5)") - assert_raise_message(ValueError, msg, nnmf, A, A, A, 1.5) + assert_raise_message(ValueError, msg, nnmf, A, A, A, 1.5, 'random') msg = ("Number of components must be a positive integer; " "got (n_components='2')") - assert_raise_message(ValueError, msg, nnmf, A, A, A, '2') + assert_raise_message(ValueError, msg, nnmf, A, A, A, '2', 'random') msg = "Negative values in data passed to NMF (input H)" assert_raise_message(ValueError, msg, nnmf, A, A, -A, 2, 'custom') msg = "Negative values in data passed to NMF (input W)" @@ -380,8 +384,8 @@ def test_nmf_negative_beta_loss(): def _assert_nmf_no_nan(X, beta_loss): W, H, _ = non_negative_factorization( - X, n_components=n_components, solver='mu', beta_loss=beta_loss, - random_state=0, max_iter=1000) + X, init='random', n_components=n_components, solver='mu', + beta_loss=beta_loss, random_state=0, max_iter=1000) assert not np.any(np.isnan(W)) assert not np.any(np.isnan(H)) From a95b08e3562a960e53da85957cfa3cae82cebc81 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 30 Jan 2019 12:43:31 +0100 Subject: [PATCH 21/23] MAINT: pin flake8 to stable version (#13066) --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 89cc103ec6301..c875abc0035fd 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -67,7 +67,7 @@ jobs: - run: ./build_tools/circle/checkout_merge_commit.sh - run: name: dependencies - command: sudo pip install flake8 + command: sudo pip install flake8==3.6.0 - run: name: flake8 command: ./build_tools/circle/flake8_diff.sh From 193ed4ebb4c4de741f8c40edf3d89bf2c3319dc9 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Wed, 30 Jan 2019 22:41:37 +0800 Subject: [PATCH 22/23] EXA: fix xlabel and ylabel in plot_cv_digits.py (#13067) --- examples/exercises/plot_cv_digits.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/exercises/plot_cv_digits.py b/examples/exercises/plot_cv_digits.py index f51bcc7e0256e..27b26f13de54f 100644 --- a/examples/exercises/plot_cv_digits.py +++ b/examples/exercises/plot_cv_digits.py @@ -32,8 +32,7 @@ # Do the plotting import matplotlib.pyplot as plt -plt.figure(1, figsize=(4, 3)) -plt.clf() +plt.figure() plt.semilogx(C_s, scores) plt.semilogx(C_s, np.array(scores) + np.array(scores_std), 'b--') plt.semilogx(C_s, np.array(scores) - 
np.array(scores_std), 'b--') From 85fe5f6b89fc99e814bd018d5e9a451b1a389c51 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 30 Jan 2019 22:17:34 +0100 Subject: [PATCH 23/23] MAINT: remove flake8 pinning in circle ci (#13071) --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index c875abc0035fd..89cc103ec6301 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -67,7 +67,7 @@ jobs: - run: ./build_tools/circle/checkout_merge_commit.sh - run: name: dependencies - command: sudo pip install flake8==3.6.0 + command: sudo pip install flake8 - run: name: flake8 command: ./build_tools/circle/flake8_diff.sh
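For reference, a short sketch of the deprecation introduced in PATCH 20/23 above: relying on the old default init='random' in non_negative_factorization now emits a FutureWarning whenever n_components < n_features, while passing init explicitly keeps the call silent. This is an illustrative snippet against this branch, not part of the patches; the input array and the choice of n_components=2 are arbitrary.

    import warnings
    import numpy as np
    from sklearn.decomposition import non_negative_factorization

    X = np.abs(np.random.RandomState(0).randn(6, 4))

    # Using the implicit default triggers the new FutureWarning because
    # n_components (2) is smaller than n_features (4).
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        W, H, n_iter = non_negative_factorization(X, n_components=2)
    assert any(issubclass(w.category, FutureWarning) for w in caught)

    # Passing init explicitly preserves the old behaviour without a warning.
    W, H, n_iter = non_negative_factorization(X, n_components=2, init='random')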