Skip to content

Commit b9cf65e

Browse files
committed
Parallel version of calcError in statmodel
1 parent 51cef26 commit b9cf65e

File tree

1 file changed

+69
-27
lines changed

1 file changed

+69
-27
lines changed

modules/ml/src/inner_functions.cpp

Lines changed: 69 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -72,49 +72,91 @@ bool StatModel::train( InputArray samples, int layout, InputArray responses )
7272
return train(TrainData::create(samples, layout, responses));
7373
}
7474

75-
float StatModel::calcError( const Ptr<TrainData>& data, bool testerr, OutputArray _resp ) const
75+
class ParallelCalcError : public ParallelLoopBody
76+
{
77+
private:
78+
const Ptr<TrainData>& data;
79+
bool &testerr;
80+
Mat &resp;
81+
const StatModel &s;
82+
vector<double> &errStrip;
83+
public:
84+
ParallelCalcError(const Ptr<TrainData>& d, bool &t, Mat &_r,const StatModel &w, vector<double> &e) :
85+
data(d),
86+
testerr(t),
87+
resp(_r),
88+
s(w),
89+
errStrip(e)
90+
{
91+
}
92+
virtual void operator()(const Range& range) const
93+
{
94+
int idxErr = range.start;
95+
CV_TRACE_FUNCTION_SKIP_NESTED();
96+
Mat samples = data->getSamples();
97+
int layout = data->getLayout();
98+
Mat sidx = testerr ? data->getTestSampleIdx() : data->getTrainSampleIdx();
99+
const int* sidx_ptr = sidx.ptr<int>();
100+
bool isclassifier = s.isClassifier();
101+
Mat responses = data->getResponses();
102+
int responses_type = responses.type();
103+
104+
double err = 0;
105+
106+
for (int i = range.start; i < range.end; i++)
107+
{
108+
int si = sidx_ptr ? sidx_ptr[i] : i;
109+
Mat sample = layout == ROW_SAMPLE ? samples.row(si) : samples.col(si);
110+
float val = s.predict(sample);
111+
float val0 = (responses_type == CV_32S) ? (float)responses.at<int>(si) : responses.at<float>(si);
112+
113+
if (isclassifier)
114+
err += fabs(val - val0) > FLT_EPSILON;
115+
else
116+
err += (val - val0)*(val - val0);
117+
if (!resp.empty())
118+
resp.at<float>(i) = val;
119+
}
120+
121+
122+
errStrip[idxErr]=err ;
123+
124+
};
125+
ParallelCalcError& operator=(const ParallelCalcError &) {
126+
return *this;
127+
};
128+
};
129+
130+
131+
float StatModel::calcError(const Ptr<TrainData>& data, bool testerr, OutputArray _resp) const
76132
{
77133
CV_TRACE_FUNCTION_SKIP_NESTED();
78134
Mat samples = data->getSamples();
79-
int layout = data->getLayout();
80135
Mat sidx = testerr ? data->getTestSampleIdx() : data->getTrainSampleIdx();
81-
const int* sidx_ptr = sidx.ptr<int>();
82-
int i, n = (int)sidx.total();
136+
int n = (int)sidx.total();
83137
bool isclassifier = isClassifier();
84138
Mat responses = data->getResponses();
85-
int responses_type = responses.type();
86139

87-
if( n == 0 )
140+
if (n == 0)
88141
n = data->getNSamples();
89142

90-
if( n == 0 )
143+
if (n == 0)
91144
return -FLT_MAX;
92145

93146
Mat resp;
94-
if( _resp.needed() )
147+
if (_resp.needed())
95148
resp.create(n, 1, CV_32F);
96149

97150
double err = 0;
98-
for( i = 0; i < n; i++ )
99-
{
100-
int si = sidx_ptr ? sidx_ptr[i] : i;
101-
Mat sample = layout == ROW_SAMPLE ? samples.row(si) : samples.col(si);
102-
float val = predict(sample);
103-
float val0 = (responses_type == CV_32S) ? (float)responses.at<int>(si) : responses.at<float>(si);
104-
105-
if( isclassifier )
106-
err += fabs(val - val0) > FLT_EPSILON;
107-
else
108-
err += (val - val0)*(val - val0);
109-
if( !resp.empty() )
110-
resp.at<float>(i) = val;
111-
/*if( i < 100 )
112-
{
113-
printf("%d. ref %.1f vs pred %.1f\n", i, val0, val);
114-
}*/
115-
}
151+
vector<double> errStrip(n,0.0);
152+
ParallelCalcError x(data, testerr, resp, *this,errStrip);
153+
154+
parallel_for_(Range(0,n),x);
155+
156+
for (size_t i = 0; i < errStrip.size(); i++)
157+
err += errStrip[i];
116158

117-
if( _resp.needed() )
159+
if (_resp.needed())
118160
resp.copyTo(_resp);
119161

120162
return (float)(err / n * (isclassifier ? 100 : 1));

0 commit comments

Comments
 (0)