diff --git a/sklearn/utils/tests/test_class_weight.py b/sklearn/utils/tests/test_class_weight.py index a073eeafcfdc3..a225c92aa2029 100644 --- a/sklearn/utils/tests/test_class_weight.py +++ b/sklearn/utils/tests/test_class_weight.py @@ -117,6 +117,14 @@ def test_compute_class_weight_balanced_unordered(): assert_almost_equal(np.dot(cw, class_counts), y.shape[0]) assert_array_almost_equal(cw, [2., 1., 2. / 3]) +def test_class_weight_with_string_label(): + y = np.asarray(["A","A","A","B","B","C"]) + classes = np.unique(y) + class_weights = {c: 1.0 for c in classes} + class_weights['D'] = 1.0 # This should get a proper ValueError + cw = assert_raises(ValueError, compute_class_weight, class_weights, + classes, y) + return def test_compute_sample_weight(): # Test (and demo) compute_sample_weight.