Skip to content

Commit f70cc29

Browse files
jutankevpisarev
authored and committed
export SVM::trainAuto to python opencv#7224 (opencv#8373)
* export SVM::trainAuto to python opencv#7224
* workaround for ABI compatibility of SVM::trainAuto
* add parameter comments to new SVM::trainAuto function
* export ParamGrid member variables
1 parent 1857aa2 commit f70cc29

File tree

3 files changed

+103
-4
lines changed

3 files changed

+103
-4
lines changed

modules/ml/include/opencv2/ml.hpp

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,16 +104,16 @@ enum SampleTypes
104104
It is used for optimizing statmodel accuracy by varying model parameters, the accuracy estimate
105105
being computed by cross-validation.
106106
*/
107-
class CV_EXPORTS ParamGrid
107+
class CV_EXPORTS_W ParamGrid
108108
{
109109
public:
110110
/** @brief Default constructor */
111111
ParamGrid();
112112
/** @brief Constructor with parameters */
113113
ParamGrid(double _minVal, double _maxVal, double _logStep);
114114

115-
double minVal; //!< Minimum value of the statmodel parameter. Default value is 0.
116-
double maxVal; //!< Maximum value of the statmodel parameter. Default value is 0.
115+
CV_PROP_RW double minVal; //!< Minimum value of the statmodel parameter. Default value is 0.
116+
CV_PROP_RW double maxVal; //!< Maximum value of the statmodel parameter. Default value is 0.
117117
/** @brief Logarithmic step for iterating the statmodel parameter.
118118
119119
The grid determines the following iteration sequence of the statmodel parameter values:
@@ -122,7 +122,15 @@ class CV_EXPORTS ParamGrid
122122
\f[\texttt{minVal} * \texttt{logStep} ^n < \texttt{maxVal}\f]
123123
The grid is logarithmic, so logStep must always be greater than 1. Default value is 1.
124124
*/
125-
double logStep;
125+
CV_PROP_RW double logStep;
126+
127+
/** @brief Creates a ParamGrid Ptr that can be given to the %SVM::trainAuto method
128+
129+
@param minVal minimum value of the parameter grid
130+
@param maxVal maximum value of the parameter grid
131+
@param logstep Logarithmic step for iterating the statmodel parameter
132+
*/
133+
CV_WRAP static Ptr<ParamGrid> create(double minVal=0., double maxVal=0., double logstep=1.);
126134
};
127135

128136
/** @brief Class encapsulating training data.
@@ -691,6 +699,46 @@ class CV_EXPORTS_W SVM : public StatModel
691699
ParamGrid degreeGrid = getDefaultGrid(DEGREE),
692700
bool balanced=false) = 0;
693701

702+
/** @brief Trains an %SVM with optimal parameters
703+
704+
@param samples training samples
705+
@param layout See ml::SampleTypes.
706+
@param responses vector of responses associated with the training samples.
707+
@param kFold Cross-validation parameter. The training set is divided into kFold subsets. One
708+
subset is used to test the model, the others form the train set. So, the %SVM algorithm is executed kFold times.
709+
@param Cgrid grid for C
710+
@param gammaGrid grid for gamma
711+
@param pGrid grid for p
712+
@param nuGrid grid for nu
713+
@param coeffGrid grid for coeff
714+
@param degreeGrid grid for degree
715+
@param balanced If true and the problem is 2-class classification then the method creates more
716+
balanced cross-validation subsets that is proportions between classes in subsets are close
717+
to such proportion in the whole train dataset.
718+
719+
The method trains the %SVM model automatically by choosing the optimal parameters C, gamma, p,
720+
nu, coef0, degree. Parameters are considered optimal when the cross-validation
721+
estimate of the test set error is minimal.
722+
723+
This function only makes use of SVM::getDefaultGrid for parameter optimization and thus only
724+
offers rudimentary parameter options.
725+
726+
This function works for the classification (SVM::C_SVC or SVM::NU_SVC) as well as for the
727+
regression (SVM::EPS_SVR or SVM::NU_SVR). If it is SVM::ONE_CLASS, no optimization is made and
728+
the usual %SVM with parameters specified in params is executed.
729+
*/
730+
CV_WRAP bool trainAuto(InputArray samples,
731+
int layout,
732+
InputArray responses,
733+
int kFold = 10,
734+
Ptr<ParamGrid> Cgrid = SVM::getDefaultGridPtr(SVM::C),
735+
Ptr<ParamGrid> gammaGrid = SVM::getDefaultGridPtr(SVM::GAMMA),
736+
Ptr<ParamGrid> pGrid = SVM::getDefaultGridPtr(SVM::P),
737+
Ptr<ParamGrid> nuGrid = SVM::getDefaultGridPtr(SVM::NU),
738+
Ptr<ParamGrid> coeffGrid = SVM::getDefaultGridPtr(SVM::COEF),
739+
Ptr<ParamGrid> degreeGrid = SVM::getDefaultGridPtr(SVM::DEGREE),
740+
bool balanced=false);
741+
694742
/** @brief Retrieves all the support vectors
695743
696744
The method returns all the support vectors as a floating-point matrix, where support vectors are
@@ -733,6 +781,16 @@ class CV_EXPORTS_W SVM : public StatModel
733781
*/
734782
static ParamGrid getDefaultGrid( int param_id );
735783

784+
/** @brief Generates a grid for %SVM parameters.
785+
786+
@param param_id %SVM parameters IDs that must be one of the SVM::ParamTypes. The grid is
787+
generated for the parameter with this ID.
788+
789+
The function generates a grid pointer for the specified parameter of the %SVM algorithm.
790+
The grid may be passed to the function SVM::trainAuto.
791+
*/
792+
CV_WRAP static Ptr<ParamGrid> getDefaultGridPtr( int param_id );
793+
736794
/** Creates empty model.
737795
Use StatModel::train to train the model. Since %SVM has several parameters, you may want to
738796
find the best parameters for your problem, it can be done with SVM::trainAuto. */

modules/ml/src/inner_functions.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ ParamGrid::ParamGrid(double _minVal, double _maxVal, double _logStep)
5050
logStep = std::max(_logStep, 1.);
5151
}
5252

53+
// Factory used by the Python/Java wrappers: builds a heap-allocated,
// reference-counted ParamGrid from the three scalar settings.
Ptr<ParamGrid> ParamGrid::create(double minval, double maxval, double logstep)
{
    return makePtr<ParamGrid>(minval, maxval, logstep);
}
56+
5357
bool StatModel::empty() const { return !isTrained(); }
5458

5559
int StatModel::getVarCount() const { return 0; }

modules/ml/src/svm.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,12 @@ static void sortSamplesByClasses( const Mat& _samples, const Mat& _responses,
362362

363363
//////////////////////// SVM implementation //////////////////////////////
364364

365+
Ptr<ParamGrid> SVM::getDefaultGridPtr( int param_id )
{
    // Build the value-type grid first, then wrap a heap-allocated copy of it
    // so the wrappers get a Ptr they can hold. (Kept as a copy rather than a
    // direct makePtr to preserve ABI compatibility with getDefaultGrid.)
    ParamGrid g = getDefaultGrid(param_id);
    return makePtr<ParamGrid>(g.minVal, g.maxVal, g.logStep);
}
370+
365371
ParamGrid SVM::getDefaultGrid( int param_id )
366372
{
367373
ParamGrid grid;
@@ -1920,6 +1926,24 @@ class SVMImpl : public SVM
19201926
bool returnDFVal;
19211927
};
19221928

1929+
/** Convenience overload of trainAuto used by the wrappers: packs the raw
 *  sample/response arrays into a TrainData object and unwraps the Ptr-held
 *  parameter grids before delegating to the TrainData-based trainAuto.
 *
 *  @param samples    training samples
 *  @param layout     see ml::SampleTypes
 *  @param responses  responses associated with the training samples
 *  @param kfold      cross-validation fold count
 *  @param Cgrid..degreeGrid  parameter grids (must be non-empty Ptrs)
 *  @param balanced   build class-balanced CV subsets for 2-class problems
 */
bool trainAuto_(InputArray samples, int layout,
       InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
       Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
       Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
{
    // Dereferencing an empty Ptr below is undefined behavior; fail loudly
    // instead of crashing if a wrapper hands us a null grid.
    CV_Assert( !Cgrid.empty() && !gammaGrid.empty() && !pGrid.empty() &&
               !nuGrid.empty() && !coeffGrid.empty() && !degreeGrid.empty() );
    Ptr<TrainData> data = TrainData::create(samples, layout, responses);
    return this->trainAuto(
            data, kfold,
            *Cgrid,
            *gammaGrid,
            *pGrid,
            *nuGrid,
            *coeffGrid,
            *degreeGrid,
            balanced);
}
1945+
1946+
19231947
float predict( InputArray _samples, OutputArray _results, int flags ) const
19241948
{
19251949
float result = 0;
@@ -2281,6 +2305,19 @@ Mat SVM::getUncompressedSupportVectors() const
22812305
return this_->getUncompressedSupportVectors_();
22822306
}
22832307

2308+
// ABI-compatibility shim: the public SVM interface forwards the wrapper-facing
// trainAuto overload to the concrete implementation class, since adding a new
// virtual to the released interface would break the ABI.
bool SVM::trainAuto(InputArray samples, int layout,
        InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
        Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
        Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
{
    SVMImpl* impl = dynamic_cast<SVMImpl*>(this);
    if (!impl)
        CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
    return impl->trainAuto_(samples, layout, responses, kfold,
                            Cgrid, gammaGrid, pGrid, nuGrid,
                            coeffGrid, degreeGrid, balanced);
}
2320+
22842321
}
22852322
}
22862323

0 commit comments

Comments
 (0)