diff --git a/sklearn/datasets/samples_generator.py b/sklearn/datasets/samples_generator.py index 00f15c96446c1..7bab8f720c46e 100644 --- a/sklearn/datasets/samples_generator.py +++ b/sklearn/datasets/samples_generator.py @@ -161,7 +161,8 @@ def make_classification(n_samples=100, n_features=20, n_informative=2, raise ValueError("Number of informative, redundant and repeated " "features must sum to less than the number of total" " features") - if 2 ** n_informative < n_classes * n_clusters_per_class: + # Use log2 to avoid overflow errors + if n_informative < np.log2(n_classes * n_clusters_per_class): raise ValueError("n_classes * n_clusters_per_class must" " be smaller or equal 2 ** n_informative") if weights and len(weights) not in [n_classes, n_classes - 1]: diff --git a/sklearn/datasets/tests/test_samples_generator.py b/sklearn/datasets/tests/test_samples_generator.py index c5a0c48b16ed0..1e1f110d9c41b 100644 --- a/sklearn/datasets/tests/test_samples_generator.py +++ b/sklearn/datasets/tests/test_samples_generator.py @@ -84,7 +84,8 @@ def test_make_classification_informative_features(): (2, [1/4] * 4, 1), (2, [1/2] * 2, 2), (2, [3/4, 1/4], 2), - (10, [1/3] * 3, 10) + (10, [1/3] * 3, 10), + (np.int(64), [1], 1) ]: n_classes = len(weights) n_clusters = n_classes * n_clusters_per_class @@ -128,19 +129,19 @@ def test_make_classification_informative_features(): for cluster in range(len(unique_signs)): centroid = X[cluster_index == cluster].mean(axis=0) if hypercube: - assert_array_almost_equal(np.abs(centroid), - [class_sep] * n_informative, - decimal=0, + assert_array_almost_equal(np.abs(centroid) / class_sep, + np.ones(n_informative), + decimal=5, err_msg="Clusters are not " "centered on hypercube " "vertices") else: assert_raises(AssertionError, assert_array_almost_equal, - np.abs(centroid), - [class_sep] * n_informative, - decimal=0, - err_msg="Clusters should not be cenetered " + np.abs(centroid) / class_sep, + np.ones(n_informative), + decimal=5, + err_msg="Clusters should not be centered " "on hypercube vertices") assert_raises(ValueError, make, n_features=2, n_informative=2, n_classes=5,