Skip to content

Commit 70c5d69

Browse files
committed
Merge pull request opencv#9310 from r2d3:svm_parallel
2 parents e012ccd + 7c334f4 commit 70c5d69

File tree

1 file changed

+115
-64
lines changed

1 file changed

+115
-64
lines changed

modules/ml/src/svm.cpp

Lines changed: 115 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -1636,6 +1636,99 @@ class SVMImpl : public SVM
16361636
return true;
16371637
}
16381638

1639+
class TrainAutoBody : public ParallelLoopBody
1640+
{
1641+
public:
1642+
TrainAutoBody(const vector<SvmParams>& _parameters,
1643+
const cv::Mat& _samples,
1644+
const cv::Mat& _responses,
1645+
const cv::Mat& _labels,
1646+
const vector<int>& _sidx,
1647+
bool _is_classification,
1648+
int _k_fold,
1649+
std::vector<double>& _result) :
1650+
parameters(_parameters), samples(_samples), responses(_responses), labels(_labels),
1651+
sidx(_sidx), is_classification(_is_classification), k_fold(_k_fold), result(_result)
1652+
{}
1653+
1654+
void operator()( const cv::Range& range ) const
1655+
{
1656+
int sample_count = samples.rows;
1657+
int var_count_ = samples.cols;
1658+
size_t sample_size = var_count_*samples.elemSize();
1659+
1660+
int test_sample_count = (sample_count + k_fold/2)/k_fold;
1661+
int train_sample_count = sample_count - test_sample_count;
1662+
1663+
// Use a local instance
1664+
cv::Ptr<SVMImpl> svm = makePtr<SVMImpl>();
1665+
svm->class_labels = labels;
1666+
1667+
int rtype = responses.type();
1668+
1669+
Mat temp_train_samples(train_sample_count, var_count_, CV_32F);
1670+
Mat temp_test_samples(test_sample_count, var_count_, CV_32F);
1671+
Mat temp_train_responses(train_sample_count, 1, rtype);
1672+
Mat temp_test_responses;
1673+
1674+
for( int p = range.start; p < range.end; p++ )
1675+
{
1676+
svm->setParams(parameters[p]);
1677+
1678+
double error = 0;
1679+
for( int k = 0; k < k_fold; k++ )
1680+
{
1681+
int start = (k*sample_count + k_fold/2)/k_fold;
1682+
for( int i = 0; i < train_sample_count; i++ )
1683+
{
1684+
int j = sidx[(i+start)%sample_count];
1685+
memcpy(temp_train_samples.ptr(i), samples.ptr(j), sample_size);
1686+
if( is_classification )
1687+
temp_train_responses.at<int>(i) = responses.at<int>(j);
1688+
else if( !responses.empty() )
1689+
temp_train_responses.at<float>(i) = responses.at<float>(j);
1690+
}
1691+
1692+
// Train SVM on <train_size> samples
1693+
if( !svm->do_train( temp_train_samples, temp_train_responses ))
1694+
continue;
1695+
1696+
for( int i = 0; i < test_sample_count; i++ )
1697+
{
1698+
int j = sidx[(i+start+train_sample_count) % sample_count];
1699+
memcpy(temp_test_samples.ptr(i), samples.ptr(j), sample_size);
1700+
}
1701+
1702+
svm->predict(temp_test_samples, temp_test_responses, 0);
1703+
for( int i = 0; i < test_sample_count; i++ )
1704+
{
1705+
float val = temp_test_responses.at<float>(i);
1706+
int j = sidx[(i+start+train_sample_count) % sample_count];
1707+
if( is_classification )
1708+
error += (float)(val != responses.at<int>(j));
1709+
else
1710+
{
1711+
val -= responses.at<float>(j);
1712+
error += val*val;
1713+
}
1714+
}
1715+
}
1716+
1717+
result[p] = error;
1718+
}
1719+
}
1720+
1721+
private:
1722+
const vector<SvmParams>& parameters;
1723+
const cv::Mat& samples;
1724+
const cv::Mat& responses;
1725+
const cv::Mat& labels;
1726+
const vector<int>& sidx;
1727+
bool is_classification;
1728+
int k_fold;
1729+
std::vector<double>& result;
1730+
};
1731+
16391732
bool trainAuto( const Ptr<TrainData>& data, int k_fold,
16401733
ParamGrid C_grid, ParamGrid gamma_grid, ParamGrid p_grid,
16411734
ParamGrid nu_grid, ParamGrid coef_grid, ParamGrid degree_grid,
@@ -1713,15 +1806,12 @@ class SVMImpl : public SVM
17131806

17141807
int sample_count = samples.rows;
17151808
var_count = samples.cols;
1716-
size_t sample_size = var_count*samples.elemSize();
17171809

17181810
vector<int> sidx;
17191811
setRangeVector(sidx, sample_count);
17201812

1721-
int i, j, k;
1722-
17231813
// randomly permute training samples
1724-
for( i = 0; i < sample_count; i++ )
1814+
for( int i = 0; i < sample_count; i++ )
17251815
{
17261816
int i1 = rng.uniform(0, sample_count);
17271817
int i2 = rng.uniform(0, sample_count);
@@ -1735,7 +1825,7 @@ class SVMImpl : public SVM
17351825
// between the k_fold parts.
17361826
vector<int> sidx0, sidx1;
17371827

1738-
for( i = 0; i < sample_count; i++ )
1828+
for( int i = 0; i < sample_count; i++ )
17391829
{
17401830
if( responses.at<int>(sidx[i]) == 0 )
17411831
sidx0.push_back(sidx[i]);
@@ -1746,15 +1836,15 @@ class SVMImpl : public SVM
17461836
int n0 = (int)sidx0.size(), n1 = (int)sidx1.size();
17471837
int a0 = 0, a1 = 0;
17481838
sidx.clear();
1749-
for( k = 0; k < k_fold; k++ )
1839+
for( int k = 0; k < k_fold; k++ )
17501840
{
17511841
int b0 = ((k+1)*n0 + k_fold/2)/k_fold, b1 = ((k+1)*n1 + k_fold/2)/k_fold;
17521842
int a = (int)sidx.size(), b = a + (b0 - a0) + (b1 - a1);
1753-
for( i = a0; i < b0; i++ )
1843+
for( int i = a0; i < b0; i++ )
17541844
sidx.push_back(sidx0[i]);
1755-
for( i = a1; i < b1; i++ )
1845+
for( int i = a1; i < b1; i++ )
17561846
sidx.push_back(sidx1[i]);
1757-
for( i = 0; i < (b - a); i++ )
1847+
for( int i = 0; i < (b - a); i++ )
17581848
{
17591849
int i1 = rng.uniform(a, b);
17601850
int i2 = rng.uniform(a, b);
@@ -1764,75 +1854,36 @@ class SVMImpl : public SVM
17641854
}
17651855
}
17661856

1767-
int test_sample_count = (sample_count + k_fold/2)/k_fold;
1768-
int train_sample_count = sample_count - test_sample_count;
1769-
1770-
SvmParams best_params = params;
1771-
double min_error = FLT_MAX;
1772-
1773-
int rtype = responses.type();
1774-
1775-
Mat temp_train_samples(train_sample_count, var_count, CV_32F);
1776-
Mat temp_test_samples(test_sample_count, var_count, CV_32F);
1777-
Mat temp_train_responses(train_sample_count, 1, rtype);
1778-
Mat temp_test_responses;
1779-
17801857
// If grid.minVal == grid.maxVal, this will allow one and only one pass through the loop with params.var = grid.minVal.
17811858
#define FOR_IN_GRID(var, grid) \
17821859
for( params.var = grid.minVal; params.var == grid.minVal || params.var < grid.maxVal; params.var = (grid.minVal == grid.maxVal) ? grid.maxVal + 1 : params.var * grid.logStep )
17831860

1861+
// Create the list of parameters to test
1862+
std::vector<SvmParams> parameters;
17841863
FOR_IN_GRID(C, C_grid)
17851864
FOR_IN_GRID(gamma, gamma_grid)
17861865
FOR_IN_GRID(p, p_grid)
17871866
FOR_IN_GRID(nu, nu_grid)
17881867
FOR_IN_GRID(coef0, coef_grid)
17891868
FOR_IN_GRID(degree, degree_grid)
17901869
{
1791-
// make sure we updated the kernel and other parameters
1792-
setParams(params);
1793-
1794-
double error = 0;
1795-
for( k = 0; k < k_fold; k++ )
1796-
{
1797-
int start = (k*sample_count + k_fold/2)/k_fold;
1798-
for( i = 0; i < train_sample_count; i++ )
1799-
{
1800-
j = sidx[(i+start)%sample_count];
1801-
memcpy(temp_train_samples.ptr(i), samples.ptr(j), sample_size);
1802-
if( is_classification )
1803-
temp_train_responses.at<int>(i) = responses.at<int>(j);
1804-
else if( !responses.empty() )
1805-
temp_train_responses.at<float>(i) = responses.at<float>(j);
1806-
}
1870+
parameters.push_back(params);
1871+
}
18071872

1808-
// Train SVM on <train_size> samples
1809-
if( !do_train( temp_train_samples, temp_train_responses ))
1810-
continue;
1873+
std::vector<double> result(parameters.size());
1874+
TrainAutoBody invoker(parameters, samples, responses, class_labels, sidx,
1875+
is_classification, k_fold, result);
1876+
parallel_for_(cv::Range(0,(int)parameters.size()), invoker);
18111877

1812-
for( i = 0; i < test_sample_count; i++ )
1813-
{
1814-
j = sidx[(i+start+train_sample_count) % sample_count];
1815-
memcpy(temp_test_samples.ptr(i), samples.ptr(j), sample_size);
1816-
}
1817-
1818-
predict(temp_test_samples, temp_test_responses, 0);
1819-
for( i = 0; i < test_sample_count; i++ )
1820-
{
1821-
float val = temp_test_responses.at<float>(i);
1822-
j = sidx[(i+start+train_sample_count) % sample_count];
1823-
if( is_classification )
1824-
error += (float)(val != responses.at<int>(j));
1825-
else
1826-
{
1827-
val -= responses.at<float>(j);
1828-
error += val*val;
1829-
}
1830-
}
1831-
}
1832-
if( min_error > error )
1878+
// Extract the best parameters
1879+
SvmParams best_params = params;
1880+
double min_error = FLT_MAX;
1881+
for( int i = 0; i < (int)result.size(); i++ )
1882+
{
1883+
if( result[i] < min_error )
18331884
{
1834-
min_error = error;
1835-
best_params = params;
1885+
min_error = result[i];
1886+
best_params = parameters[i];
18361887
}
18371888
}
18381889

0 commit comments

Comments
 (0)