@@ -104,16 +104,16 @@ enum SampleTypes
104
104
It is used for optimizing statmodel accuracy by varying model parameters, the accuracy estimate
105
105
being computed by cross-validation.
106
106
*/
107
- class CV_EXPORTS ParamGrid
107
+ class CV_EXPORTS_W ParamGrid
108
108
{
109
109
public:
110
110
/* * @brief Default constructor */
111
111
ParamGrid ();
112
112
/* * @brief Constructor with parameters */
113
113
ParamGrid (double _minVal, double _maxVal, double _logStep);
114
114
115
- double minVal; // !< Minimum value of the statmodel parameter. Default value is 0.
116
- double maxVal; // !< Maximum value of the statmodel parameter. Default value is 0.
115
+ CV_PROP_RW double minVal; // !< Minimum value of the statmodel parameter. Default value is 0.
116
+ CV_PROP_RW double maxVal; // !< Maximum value of the statmodel parameter. Default value is 0.
117
117
/* * @brief Logarithmic step for iterating the statmodel parameter.
118
118
119
119
The grid determines the following iteration sequence of the statmodel parameter values:
@@ -122,7 +122,15 @@ class CV_EXPORTS ParamGrid
122
122
\f[\texttt{minVal} * \texttt{logStep} ^n < \texttt{maxVal}\f]
123
123
The grid is logarithmic, so logStep must always be greater then 1. Default value is 1.
124
124
*/
125
- double logStep;
125
+ CV_PROP_RW double logStep;
126
+
127
+ /* * @brief Creates a ParamGrid Ptr that can be given to the %SVM::trainAuto method
128
+
129
+ @param minVal minimum value of the parameter grid
130
+ @param maxVal maximum value of the parameter grid
131
+ @param logstep Logarithmic step for iterating the statmodel parameter
132
+ */
133
+ CV_WRAP static Ptr<ParamGrid> create (double minVal=0 ., double maxVal=0 ., double logstep=1 .);
126
134
};
127
135
128
136
/* * @brief Class encapsulating training data.
@@ -691,6 +699,46 @@ class CV_EXPORTS_W SVM : public StatModel
691
699
ParamGrid degreeGrid = getDefaultGrid(DEGREE),
692
700
bool balanced=false) = 0;
693
701
702
+ /* * @brief Trains an %SVM with optimal parameters
703
+
704
+ @param samples training samples
705
+ @param layout See ml::SampleTypes.
706
+ @param responses vector of responses associated with the training samples.
707
+ @param kFold Cross-validation parameter. The training set is divided into kFold subsets. One
708
+ subset is used to test the model, the others form the train set. So, the %SVM algorithm is
709
+ @param Cgrid grid for C
710
+ @param gammaGrid grid for gamma
711
+ @param pGrid grid for p
712
+ @param nuGrid grid for nu
713
+ @param coeffGrid grid for coeff
714
+ @param degreeGrid grid for degree
715
+ @param balanced If true and the problem is 2-class classification then the method creates more
716
+ balanced cross-validation subsets that is proportions between classes in subsets are close
717
+ to such proportion in the whole train dataset.
718
+
719
+ The method trains the %SVM model automatically by choosing the optimal parameters C, gamma, p,
720
+ nu, coef0, degree. Parameters are considered optimal when the cross-validation
721
+ estimate of the test set error is minimal.
722
+
723
+ This function only makes use of SVM::getDefaultGrid for parameter optimization and thus only
724
+ offers rudimentary parameter options.
725
+
726
+ This function works for the classification (SVM::C_SVC or SVM::NU_SVC) as well as for the
727
+ regression (SVM::EPS_SVR or SVM::NU_SVR). If it is SVM::ONE_CLASS, no optimization is made and
728
+ the usual %SVM with parameters specified in params is executed.
729
+ */
730
+ CV_WRAP bool trainAuto (InputArray samples,
731
+ int layout,
732
+ InputArray responses,
733
+ int kFold = 10 ,
734
+ Ptr<ParamGrid> Cgrid = SVM::getDefaultGridPtr(SVM::C),
735
+ Ptr<ParamGrid> gammaGrid = SVM::getDefaultGridPtr(SVM::GAMMA),
736
+ Ptr<ParamGrid> pGrid = SVM::getDefaultGridPtr(SVM::P),
737
+ Ptr<ParamGrid> nuGrid = SVM::getDefaultGridPtr(SVM::NU),
738
+ Ptr<ParamGrid> coeffGrid = SVM::getDefaultGridPtr(SVM::COEF),
739
+ Ptr<ParamGrid> degreeGrid = SVM::getDefaultGridPtr(SVM::DEGREE),
740
+ bool balanced=false);
741
+
694
742
/* * @brief Retrieves all the support vectors
695
743
696
744
The method returns all the support vectors as a floating-point matrix, where support vectors are
@@ -733,6 +781,16 @@ class CV_EXPORTS_W SVM : public StatModel
733
781
*/
734
782
static ParamGrid getDefaultGrid ( int param_id );
735
783
784
+ /* * @brief Generates a grid for %SVM parameters.
785
+
786
+ @param param_id %SVM parameters IDs that must be one of the SVM::ParamTypes. The grid is
787
+ generated for the parameter with this ID.
788
+
789
+ The function generates a grid pointer for the specified parameter of the %SVM algorithm.
790
+ The grid may be passed to the function SVM::trainAuto.
791
+ */
792
+ CV_WRAP static Ptr<ParamGrid> getDefaultGridPtr ( int param_id );
793
+
736
794
/* * Creates empty model.
737
795
Use StatModel::train to train the model. Since %SVM has several parameters, you may want to
738
796
find the best parameters for your problem, it can be done with SVM::trainAuto. */
0 commit comments