@@ -89,17 +89,49 @@ def test_ovr_partial_fit():
89
89
assert_equal (len (ovr .estimators_ ), len (np .unique (y )))
90
90
assert_greater (np .mean (y == pred ), 0.65 )
91
91
92
- # Test when mini batches doesn't have all classes
92
+ # Test when classes are more than 2 in each pass
93
+ X = np .random .rand (14 , 2 )
94
+ y = [0 , 0 , 1 , 1 , 2 , 2 , 0 , 0 , 1 , 2 , 2 , 3 , 3 , 3 ]
93
95
ovr = OneVsRestClassifier (MultinomialNB ())
94
- ovr .partial_fit (iris .data [:60 ], iris .target [:60 ], np .unique (iris .target ))
95
- ovr .partial_fit (iris .data [60 :], iris .target [60 :])
96
- pred = ovr .predict (iris .data )
96
+ ovr .partial_fit (X [:7 ], y [:7 ], np .unique (y ))
97
+ ovr .partial_fit (X [7 :], y [7 :])
98
+ pred = ovr .predict (X )
99
+
100
+ ovr1 = OneVsRestClassifier (MultinomialNB ())
101
+ ovr1 .fit (X , y )
102
+ pred1 = ovr1 .predict (X )
103
+ assert_almost_equal (np .mean (y == pred ), np .mean (pred1 == y ))
104
+
105
+ # Test when mini batches have 2 classes in each
106
+ # pass.
107
+ temp = datasets .load_iris ()
108
+ X , y = temp .data , temp .target
109
+ ovr = OneVsRestClassifier (MultinomialNB ())
110
+ ovr .partial_fit (X [:60 ], y [:60 ], np .unique (y ))
111
+ ovr .partial_fit (X [60 :], y [60 :])
112
+ pred = ovr .predict (X )
97
113
ovr2 = OneVsRestClassifier (MultinomialNB ())
98
- pred2 = ovr2 .fit (iris . data , iris . target ).predict (iris . data )
114
+ pred2 = ovr2 .fit (X , y ).predict (X )
99
115
100
116
assert_almost_equal (pred , pred2 )
101
- assert_equal (len (ovr .estimators_ ), len (np .unique (iris .target )))
102
- assert_greater (np .mean (iris .target == pred ), 0.65 )
117
+ assert_equal (len (ovr .estimators_ ), len (np .unique (y )))
118
+ assert_greater (np .mean (y == pred ), 0.65 )
119
+
120
+ # Check when mini batch classes doesn't conain classes from all_classes
121
+ rnd = np .random .rand (10 , 2 )
122
+ ovr = OneVsRestClassifier (MultinomialNB ())
123
+ assert_raises (ValueError , ovr .partial_fit , rnd [:5 ], [0 , 1 , 2 , 3 , 4 ],
124
+ [0 , 1 , 2 , 3 ])
125
+
126
+ # Test when mini-batches have one class target
127
+ ovr = OneVsRestClassifier (MultinomialNB ())
128
+ ovr .partial_fit (X [:125 ], y [:125 ], np .unique (y ))
129
+ ovr .partial_fit (X [125 :], y [125 :])
130
+ pred = ovr .predict (X )
131
+
132
+ assert_almost_equal (pred , pred2 )
133
+ assert_equal (len (ovr .estimators_ ), len (np .unique (y )))
134
+ assert_greater (np .mean (y == pred ), 0.65 )
103
135
104
136
105
137
def test_ovr_ovo_regressor ():
0 commit comments