
Commit e43997d

LaurentBerger authored and vpisarev committed
calcError now uses weighted samples (opencv#10346)

* calcError now uses sample weights
* address catree's comment in opencv#10319
1 parent: b8a24b3 · commit: e43997d

2 files changed (+14, -21 lines)

modules/ml/src/ann_mlp.cpp

Lines changed: 1 addition & 16 deletions
@@ -155,7 +155,7 @@ int SimulatedAnnealingSolver::run()
             if (newEnergy < previousEnergy)
             {
                 previousEnergy = newEnergy;
-                //??? exchange++;
+                exchange++;
             }
             else
             {
@@ -405,21 +405,6 @@ class ANN_MLPImpl : public ANN_MLP_ANNEAL
                 param2 = 0.1;
             params.bpMomentScale = std::min(param2, 1.);
         }
-        /* else if (method == ANN_MLP::ANNEAL)
-        {
-            if (param1 <= 0)
-                param1 = 10;
-            if (param2 <= 0 || param2>param1)
-                param2 = 0.1;
-            if (param3 <= 0 || param3 >=1)
-                param3 = 0.95;
-            if (param4 <= 0)
-                param4 = 10;
-            params.initialT = param1;
-            params.finalT = param2;
-            params.coolingRatio = param3;
-            params.itePerStep = param4;
-        }*/
     }
 
     int getTrainMethod() const
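
The first hunk re-enables the counter of accepted transitions in the annealing loop; the second only removes a dead, commented-out parameter block. For context, here is a minimal, self-contained sketch of the accept/reject bookkeeping such a counter belongs to. The struct and all names are illustrative, not the OpenCV implementation.

```cpp
// Sketch (not OpenCV code) of simulated-annealing bookkeeping: downhill moves
// are always accepted, uphill moves with probability exp(-dE/T), and
// "exchange" counts how many proposals were accepted.
#include <cmath>
#include <random>

struct AnnealState
{
    double previousEnergy;
    int exchange = 0;   // number of accepted transitions so far

    // Decide whether to accept a proposed state with energy newEnergy.
    bool consider(double newEnergy, double temperature, std::mt19937& rng)
    {
        std::uniform_real_distribution<double> uni(0.0, 1.0);
        bool accept = newEnergy < previousEnergy ||
                      uni(rng) < std::exp((previousEnergy - newEnergy) / temperature);
        if (accept)
        {
            previousEnergy = newEnergy;
            exchange++;   // same role as the counter in the hunk above
        }
        return accept;    // the caller reverts its proposed change when false
    }
};
```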

modules/ml/src/inner_functions.cpp

Lines changed: 13 additions & 5 deletions
@@ -94,26 +94,29 @@ class ParallelCalcError : public ParallelLoopBody
         int idxErr = range.start;
         CV_TRACE_FUNCTION_SKIP_NESTED();
         Mat samples = data->getSamples();
+        Mat weights=testerr? data->getTestSampleWeights() : data->getTrainSampleWeights();
         int layout = data->getLayout();
         Mat sidx = testerr ? data->getTestSampleIdx() : data->getTrainSampleIdx();
         const int* sidx_ptr = sidx.ptr<int>();
         bool isclassifier = s.isClassifier();
         Mat responses = data->getResponses();
         int responses_type = responses.type();
-
         double err = 0;
 
+
+        const float* sw = weights.empty() ? 0 : weights.ptr<float>();
         for (int i = range.start; i < range.end; i++)
         {
             int si = sidx_ptr ? sidx_ptr[i] : i;
+            double sweight = sw ? static_cast<double>(sw[i]) : 1.;
             Mat sample = layout == ROW_SAMPLE ? samples.row(si) : samples.col(si);
             float val = s.predict(sample);
             float val0 = (responses_type == CV_32S) ? (float)responses.at<int>(si) : responses.at<float>(si);
 
             if (isclassifier)
-                err += fabs(val - val0) > FLT_EPSILON;
+                err += sweight * fabs(val - val0) > FLT_EPSILON;
             else
-                err += sweight * (val - val0)*(val - val0);
+                err += sweight * (val - val0)*(val - val0);
             if (!resp.empty())
                 resp.at<float>(i) = val;
         }
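
The loop above is intended to weight each per-sample error term by its sample weight, falling back to a unit weight when the TrainData carries no weights. A standalone sketch of that accumulation (the vectors and names are placeholders, not the OpenCV API):

```cpp
// Standalone sketch of the weighted accumulation performed by the loop above.
#include <cfloat>
#include <cmath>
#include <vector>

double weightedError(const std::vector<float>& predicted,
                     const std::vector<float>& expected,
                     const std::vector<float>& weights,   // empty => unit weights
                     bool isClassifier)
{
    double err = 0.0;
    for (size_t i = 0; i < predicted.size(); i++)
    {
        double w = weights.empty() ? 1.0 : weights[i];
        double d = predicted[i] - expected[i];
        if (isClassifier)
            err += w * (std::fabs(d) > FLT_EPSILON ? 1.0 : 0.0);  // weighted misclassification
        else
            err += w * d * d;                                     // weighted squared error
    }
    return err;   // calcError() divides by the total weight afterwards
}
```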
@@ -133,12 +136,17 @@ float StatModel::calcError(const Ptr<TrainData>& data, bool testerr, OutputArray
     CV_TRACE_FUNCTION_SKIP_NESTED();
     Mat samples = data->getSamples();
     Mat sidx = testerr ? data->getTestSampleIdx() : data->getTrainSampleIdx();
+    Mat weights = testerr ? data->getTestSampleWeights() : data->getTrainSampleWeights();
     int n = (int)sidx.total();
     bool isclassifier = isClassifier();
     Mat responses = data->getResponses();
 
     if (n == 0)
+    {
         n = data->getNSamples();
+        weights = data->getTrainSampleWeights();
+        testerr =false;
+    }
 
     if (n == 0)
         return -FLT_MAX;
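
This hunk makes calcError() pick the sample weights that match the requested subset, and fall back to the full training set with the training weights when no test subset exists. A hedged usage sketch, assuming the public cv::ml API (RTrees and the 0.8 split ratio are arbitrary illustrative choices):

```cpp
// Usage sketch assuming OpenCV's cv::ml module; data shapes are illustrative.
#include <opencv2/ml.hpp>

float evaluateWithWeights(const cv::Mat& samples,        // CV_32F, one sample per row
                          const cv::Mat& responses,
                          const cv::Mat& sampleWeights)  // CV_32F, one weight per sample
{
    using namespace cv::ml;

    // Sample weights are attached when the TrainData is created.
    cv::Ptr<TrainData> data = TrainData::create(samples, ROW_SAMPLE, responses,
                                                cv::noArray(), cv::noArray(), sampleWeights);
    data->setTrainTestSplitRatio(0.8, /*shuffle=*/true);

    cv::Ptr<RTrees> model = RTrees::create();
    model->train(data);

    // With the split above, calcError() uses the test subset and its weights.
    // Without any split, the hunk above falls back to all training samples and
    // the training weights (and clears testerr).
    return model->calcError(data, /*testerr=*/true, cv::noArray());
}
```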
@@ -155,11 +163,11 @@ float StatModel::calcError(const Ptr<TrainData>& data, bool testerr, OutputArray
 
     for (size_t i = 0; i < errStrip.size(); i++)
         err += errStrip[i];
-
+    float weightSum= weights.empty() ? n: static_cast<float>(sum(weights)(0));
     if (_resp.needed())
         resp.copyTo(_resp);
 
-    return (float)(err / n * (isclassifier ? 100 : 1));
+    return (float)(err/ weightSum * (isclassifier ? 100 : 1));
 }
 
 /* Calculates upper triangular matrix S, where A is a symmetrical matrix A=S'*S */
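
Putting the changes together, the intended return value is a weight-normalized error; with unit or absent weights the denominator reduces to n and the previous behaviour is recovered. In illustrative notation (not taken from the source):

```latex
% w_i is the sample weight (w_i = 1 when none are set), e_i the per-sample error term.
\[
  E \;=\; \frac{\sum_i w_i\, e_i}{\sum_i w_i},
  \qquad
  e_i \;=\;
  \begin{cases}
    \mathbf{1}\!\left[\,\lvert f(x_i) - y_i \rvert > \varepsilon\,\right] & \text{classification (result scaled by 100)}\\[4pt]
    \left(f(x_i) - y_i\right)^2 & \text{regression}
  \end{cases}
\]
```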
