Skip to content

Commit e0ee2f7

Browse files
committed
Merge pull request opencv#8116 from mrquorr:master
2 parents f46fa6e + d8425d8 commit e0ee2f7

File tree

3 files changed

+123
-0
lines changed

3 files changed

+123
-0
lines changed

modules/ml/include/opencv2/ml.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,17 @@ class CV_EXPORTS_W RTrees : public DTrees
12061206
*/
12071207
CV_WRAP virtual Mat getVarImportance() const = 0;
12081208

1209+
/** Returns the result of each individual tree in the forest.
1210+
In case the model is a regression problem, the method will return each of the trees'
1211+
results for each of the sample cases. If the model is a classifier, it will return
1212+
a Mat with samples + 1 rows, where the first row gives the class number and the
1213+
following rows return the votes each class had for each sample.
1214+
@param samples Array containg the samples for which votes will be calculated.
1215+
@param results Array where the result of the calculation will be written.
1216+
@param flags Flags for defining the type of RTrees.
1217+
*/
1218+
CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const;
1219+
12091220
/** Creates the empty model.
12101221
Use StatModel::train to train the model, StatModel::train to create and train the model,
12111222
Algorithm::load to load the pre-trained model.

modules/ml/src/rtrees.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,60 @@ class DTreesImplForRTrees : public DTreesImpl
349349
}
350350
}
351351

352+
void getVotes( InputArray input, OutputArray output, int flags ) const
353+
{
354+
CV_Assert( !roots.empty() );
355+
int nclasses = (int)classLabels.size(), ntrees = (int)roots.size();
356+
Mat samples = input.getMat(), results;
357+
int i, j, nsamples = samples.rows;
358+
359+
int predictType = flags & PREDICT_MASK;
360+
if( predictType == PREDICT_AUTO )
361+
{
362+
predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
363+
PREDICT_SUM : PREDICT_MAX_VOTE;
364+
}
365+
366+
if( predictType == PREDICT_SUM )
367+
{
368+
output.create(nsamples, ntrees, CV_32F);
369+
results = output.getMat();
370+
for( i = 0; i < nsamples; i++ )
371+
{
372+
for( j = 0; j < ntrees; j++ )
373+
{
374+
float val = predictTrees( Range(j, j+1), samples.row(i), flags);
375+
results.at<float> (i, j) = val;
376+
}
377+
}
378+
} else
379+
{
380+
vector<int> votes;
381+
output.create(nsamples+1, nclasses, CV_32S);
382+
results = output.getMat();
383+
384+
for ( j = 0; j < nclasses; j++)
385+
{
386+
results.at<int> (0, j) = classLabels[j];
387+
}
388+
389+
for( i = 0; i < nsamples; i++ )
390+
{
391+
votes.clear();
392+
for( j = 0; j < ntrees; j++ )
393+
{
394+
int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags);
395+
votes.push_back(val);
396+
}
397+
398+
for ( j = 0; j < nclasses; j++)
399+
{
400+
results.at<int> (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]);
401+
}
402+
}
403+
}
404+
}
405+
352406
RTreeParams rparams;
353407
double oobError;
354408
vector<float> varImportance;
@@ -401,6 +455,11 @@ class RTreesImpl : public RTrees
401455
impl.read(fn);
402456
}
403457

458+
void getVotes_( InputArray samples, OutputArray results, int flags ) const
459+
{
460+
impl.getVotes(samples, results, flags);
461+
}
462+
404463
Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
405464
int getVarCount() const { return impl.getVarCount(); }
406465

@@ -427,6 +486,14 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
427486
return Algorithm::load<RTrees>(filepath, nodeName);
428487
}
429488

489+
void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
490+
{
491+
const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
492+
if(!this_)
493+
CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
494+
return this_->getVotes_(input, output, flags);
495+
}
496+
430497
}}
431498

432499
// End of file.

modules/ml/test/test_mltests.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,49 @@ TEST(ML_NBAYES, regression_5911)
172172
EXPECT_EQ(sum(P1 == P3)[0], 255 * P3.total());
173173
}
174174

175+
TEST(ML_RTrees, getVotes)
176+
{
177+
int n = 12;
178+
int count, i;
179+
int label_size = 3;
180+
int predicted_class = 0;
181+
int max_votes = -1;
182+
int val;
183+
// RTrees for classification
184+
Ptr<ml::RTrees> rt = cv::ml::RTrees::create();
185+
186+
//data
187+
Mat data(n, 4, CV_32F);
188+
randu(data, 0, 10);
189+
190+
//labels
191+
Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);
192+
193+
rt->train(data, ml::ROW_SAMPLE, labels);
194+
195+
//run function
196+
Mat test(1, 4, CV_32F);
197+
Mat result;
198+
randu(test, 0, 10);
199+
rt->getVotes(test, result, 0);
200+
201+
//count vote amount and find highest vote
202+
count = 0;
203+
const int* result_row = result.ptr<int>(1);
204+
for( i = 0; i < label_size; i++ )
205+
{
206+
val = result_row[i];
207+
//predicted_class = max_votes < val? i;
208+
if( max_votes < val )
209+
{
210+
max_votes = val;
211+
predicted_class = i;
212+
}
213+
count += val;
214+
}
215+
216+
EXPECT_EQ(count, (int)rt->getRoots().size());
217+
EXPECT_EQ(result.at<float>(0, predicted_class), rt->predict(test));
218+
}
219+
175220
/* End of file. */

0 commit comments

Comments
 (0)