Skip to content

Commit d8425d8

Browse files
committed
finished for one sample
Finished with several samples support, need regression testing Gave a more relevant name to function (getVotes) Finished implicit implementation Removed printf, finished regresion testing Fixed conversion warning Finished test for Rtrees Fixed documentation Initialized variable Added doxygen documentation Added parameter name
1 parent ec47a0a commit d8425d8

File tree

3 files changed

+123
-0
lines changed

3 files changed

+123
-0
lines changed

modules/ml/include/opencv2/ml.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,17 @@ class CV_EXPORTS_W RTrees : public DTrees
11641164
*/
11651165
CV_WRAP virtual Mat getVarImportance() const = 0;
11661166

1167+
/** Returns the result of each individual tree in the forest.
1168+
In case the model is a regression problem, the method will return each of the trees'
1169+
results for each of the sample cases. If the model is a classifier, it will return
1170+
a Mat with samples + 1 rows, where the first row gives the class number and the
1171+
following rows return the votes each class had for each sample.
1172+
@param samples Array containg the samples for which votes will be calculated.
1173+
@param results Array where the result of the calculation will be written.
1174+
@param flags Flags for defining the type of RTrees.
1175+
*/
1176+
CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const;
1177+
11671178
/** Creates the empty model.
11681179
Use StatModel::train to train the model, StatModel::train to create and train the model,
11691180
Algorithm::load to load the pre-trained model.

modules/ml/src/rtrees.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,60 @@ class DTreesImplForRTrees : public DTreesImpl
349349
}
350350
}
351351

352+
void getVotes( InputArray input, OutputArray output, int flags ) const
353+
{
354+
CV_Assert( !roots.empty() );
355+
int nclasses = (int)classLabels.size(), ntrees = (int)roots.size();
356+
Mat samples = input.getMat(), results;
357+
int i, j, nsamples = samples.rows;
358+
359+
int predictType = flags & PREDICT_MASK;
360+
if( predictType == PREDICT_AUTO )
361+
{
362+
predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
363+
PREDICT_SUM : PREDICT_MAX_VOTE;
364+
}
365+
366+
if( predictType == PREDICT_SUM )
367+
{
368+
output.create(nsamples, ntrees, CV_32F);
369+
results = output.getMat();
370+
for( i = 0; i < nsamples; i++ )
371+
{
372+
for( j = 0; j < ntrees; j++ )
373+
{
374+
float val = predictTrees( Range(j, j+1), samples.row(i), flags);
375+
results.at<float> (i, j) = val;
376+
}
377+
}
378+
} else
379+
{
380+
vector<int> votes;
381+
output.create(nsamples+1, nclasses, CV_32S);
382+
results = output.getMat();
383+
384+
for ( j = 0; j < nclasses; j++)
385+
{
386+
results.at<int> (0, j) = classLabels[j];
387+
}
388+
389+
for( i = 0; i < nsamples; i++ )
390+
{
391+
votes.clear();
392+
for( j = 0; j < ntrees; j++ )
393+
{
394+
int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags);
395+
votes.push_back(val);
396+
}
397+
398+
for ( j = 0; j < nclasses; j++)
399+
{
400+
results.at<int> (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]);
401+
}
402+
}
403+
}
404+
}
405+
352406
RTreeParams rparams;
353407
double oobError;
354408
vector<float> varImportance;
@@ -401,6 +455,11 @@ class RTreesImpl : public RTrees
401455
impl.read(fn);
402456
}
403457

458+
void getVotes_( InputArray samples, OutputArray results, int flags ) const
459+
{
460+
impl.getVotes(samples, results, flags);
461+
}
462+
404463
Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
405464
int getVarCount() const { return impl.getVarCount(); }
406465

@@ -427,6 +486,14 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
427486
return Algorithm::load<RTrees>(filepath, nodeName);
428487
}
429488

489+
void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
490+
{
491+
const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
492+
if(!this_)
493+
CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
494+
return this_->getVotes_(input, output, flags);
495+
}
496+
430497
}}
431498

432499
// End of file.

modules/ml/test/test_mltests.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,49 @@ TEST(ML_NBAYES, regression_5911)
172172
EXPECT_EQ(sum(P1 == P3)[0], 255 * P3.total());
173173
}
174174

175+
TEST(ML_RTrees, getVotes)
176+
{
177+
int n = 12;
178+
int count, i;
179+
int label_size = 3;
180+
int predicted_class = 0;
181+
int max_votes = -1;
182+
int val;
183+
// RTrees for classification
184+
Ptr<ml::RTrees> rt = cv::ml::RTrees::create();
185+
186+
//data
187+
Mat data(n, 4, CV_32F);
188+
randu(data, 0, 10);
189+
190+
//labels
191+
Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);
192+
193+
rt->train(data, ml::ROW_SAMPLE, labels);
194+
195+
//run function
196+
Mat test(1, 4, CV_32F);
197+
Mat result;
198+
randu(test, 0, 10);
199+
rt->getVotes(test, result, 0);
200+
201+
//count vote amount and find highest vote
202+
count = 0;
203+
const int* result_row = result.ptr<int>(1);
204+
for( i = 0; i < label_size; i++ )
205+
{
206+
val = result_row[i];
207+
//predicted_class = max_votes < val? i;
208+
if( max_votes < val )
209+
{
210+
max_votes = val;
211+
predicted_class = i;
212+
}
213+
count += val;
214+
}
215+
216+
EXPECT_EQ(count, (int)rt->getRoots().size());
217+
EXPECT_EQ(result.at<float>(0, predicted_class), rt->predict(test));
218+
}
219+
175220
/* End of file. */

0 commit comments

Comments
 (0)