From d179e2ec699746981ac9caada7b91f5835f228f9 Mon Sep 17 00:00:00 2001 From: Gaurav Dhingra Date: Thu, 26 Oct 2017 00:26:09 +0530 Subject: [PATCH 1/2] fixes #9633 correct comparison in GaussianNB for 'priors' --- sklearn/naive_bayes.py | 2 +- sklearn/tests/test_naive_bayes.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py index 6aec725bd9802..ae01ccb62f238 100644 --- a/sklearn/naive_bayes.py +++ b/sklearn/naive_bayes.py @@ -374,7 +374,7 @@ def _partial_fit(self, X, y, classes=None, _refit=False, raise ValueError('Number of priors must match number of' ' classes.') # Check that the sum is 1 - if priors.sum() != 1.0: + if not np.isclose(priors.sum(), 1.0): raise ValueError('The sum of the priors should be 1.') # Check that the prior are non-negative if (priors < 0).any(): diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 97a119dca6ba1..3726c51c7fa52 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -114,6 +114,16 @@ def test_gnb_priors(): assert_array_almost_equal(clf.class_prior_, np.array([0.3, 0.7])) +def test_gnb_priors_sum_isclose(): + """Test whether the class prior sum is properly tested""" + X = np.array([[-1, -1], [-2, -1], [-3, -2], [-4, -5], [-5, -4], + [1, 1], [2, 1], [3, 2], [4, 4], [5, 5]]) + priors = np.array([0.08, 0.14, 0.03, 0.16, 0.11, 0.16, 0.07, 0.14, 0.11, 0.0]) + Y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + clf = GaussianNB(priors) + clf.fit(X, Y) + + def test_gnb_wrong_nb_priors(): """ Test whether an error is raised if the number of prior is different from the number of class""" From 183ae55f69721eef8fbfd4c7f850549564d67e18 Mon Sep 17 00:00:00 2001 From: Gaurav Dhingra Date: Thu, 26 Oct 2017 10:47:57 +0530 Subject: [PATCH 2/2] fix flake8 errors and address changes regarding comments --- sklearn/tests/test_naive_bayes.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py index 3726c51c7fa52..b2b1b63c98b19 100644 --- a/sklearn/tests/test_naive_bayes.py +++ b/sklearn/tests/test_naive_bayes.py @@ -115,12 +115,14 @@ def test_gnb_priors(): def test_gnb_priors_sum_isclose(): - """Test whether the class prior sum is properly tested""" + # test whether the class prior sum is properly tested""" X = np.array([[-1, -1], [-2, -1], [-3, -2], [-4, -5], [-5, -4], - [1, 1], [2, 1], [3, 2], [4, 4], [5, 5]]) - priors = np.array([0.08, 0.14, 0.03, 0.16, 0.11, 0.16, 0.07, 0.14, 0.11, 0.0]) + [1, 1], [2, 1], [3, 2], [4, 4], [5, 5]]) + priors = np.array([0.08, 0.14, 0.03, 0.16, 0.11, 0.16, 0.07, 0.14, + 0.11, 0.0]) Y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) clf = GaussianNB(priors) + # smoke test for issue #9633 clf.fit(X, Y)