Skip to content

Commit 55e8810

Browse files
pprettGaelVaroquaux
authored andcommitted
add test for proper loss instantiation
1 parent ee838a2 commit 55e8810

File tree

1 file changed

+5
-20
lines changed

1 file changed

+5
-20
lines changed

sklearn/linear_model/tests/test_sgd.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,6 @@ def test_sgd_bad_penalty(self):
185185
"""Check whether expected ValueError on bad penalty"""
186186
self.factory(penalty='foobar', rho=0.85)
187187

188-
def test_sgd_losses(self):
189-
"""Check whether losses and hyperparameters are set properly"""
190-
clf = self.factory(loss='hinge')
191-
assert_true(isinstance(clf.loss_function, linear_model.Hinge))
192-
193-
clf = self.factory(loss='log')
194-
assert_true(isinstance(clf.loss_function, linear_model.Log))
195-
196-
clf = self.factory(loss='modified_huber')
197-
assert_true(isinstance(clf.loss_function, linear_model.ModifiedHuber))
198-
199188
@raises(ValueError)
200189
def test_sgd_bad_loss(self):
201190
"""Check whether expected ValueError on bad loss"""
@@ -562,15 +551,6 @@ def test_sgd_bad_penalty(self):
562551
"""Check whether expected ValueError on bad penalty"""
563552
self.factory(penalty='foobar', rho=0.85)
564553

565-
def test_sgd_losses(self):
566-
"""Check whether losses and hyperparameters are set properly"""
567-
clf = self.factory(loss='squared_loss')
568-
assert_true(isinstance(clf.loss_function, linear_model.SquaredLoss))
569-
570-
clf = self.factory(loss='huber', epsilon=0.5)
571-
assert_true(isinstance(clf.loss_function, linear_model.Huber))
572-
assert_equal(clf.epsilon, 0.5)
573-
574554
@raises(ValueError)
575555
def test_sgd_bad_loss(self):
576556
"""Check whether expected ValueError on bad loss"""
@@ -715,6 +695,11 @@ def test_partial_fit_equal_fit_optimal(self):
715695
def test_partial_fit_equal_fit_invscaling(self):
716696
self._test_partial_fit_equal_fit("invscaling")
717697

698+
def test_loss_function_epsilon(self):
699+
clf = self.factory(epsilon=0.9)
700+
clf.set_params(epsilon=0.1)
701+
assert clf.loss_functions['huber'][1] == 0.1
702+
718703

719704
class SparseSGDRegressorTestCase(DenseSGDRegressorTestCase):
720705
"""Run exactly the same tests using the sparse representation variant"""

0 commit comments

Comments
 (0)