Skip to content

Commit a44573c

Browse files
committed
Add ReLU and LeakyReLU activation functions in ml module
1 parent f8ad289 commit a44573c

File tree

3 files changed

+220
-74
lines changed

3 files changed

+220
-74
lines changed

modules/ml/include/opencv2/ml.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1503,14 +1503,18 @@ class CV_EXPORTS_W ANN_MLP : public StatModel
15031503
enum ActivationFunctions {
15041504
/** Identity function: \f$f(x)=x\f$ */
15051505
IDENTITY = 0,
1506-
/** Symmetrical sigmoid: \f$f(x)=\beta*(1-e^{-\alpha x})/(1+e^{-\alpha x}\f$
1506+
/** Symmetrical sigmoid: \f$f(x)=\beta*(1-e^{-\alpha x})/(1+e^{-\alpha x})\f$
15071507
@note
15081508
If you are using the default sigmoid activation function with the default parameter values
15091509
fparam1=0 and fparam2=0 then the function used is y = 1.7159\*tanh(2/3 \* x), so the output
15101510
will range from [-1.7159, 1.7159], instead of [0,1].*/
15111511
SIGMOID_SYM = 1,
15121512
/** Gaussian function: \f$f(x)=\beta e^{-\alpha x*x}\f$ */
1513-
GAUSSIAN = 2
1513+
GAUSSIAN = 2,
1514+
/** ReLU function: \f$f(x)=max(0,x)\f$ */
1515+
RELU = 3,
1516+
/** Leaky ReLU function: \f$f(x)=x\f$ for x>0 and \f$f(x)=\alpha x\f$ for x<=0 */
1517+
LEAKYRELU= 4
15141518
};
15151519

15161520
/** Train options */

modules/ml/src/ann_mlp.cpp

Lines changed: 142 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class ANN_MLPImpl : public ANN_MLP
135135

136136
void setActivationFunction(int _activ_func, double _f_param1, double _f_param2 )
137137
{
138-
if( _activ_func < 0 || _activ_func > GAUSSIAN )
138+
if( _activ_func < 0 || _activ_func > LEAKYRELU)
139139
CV_Error( CV_StsOutOfRange, "Unknown activation function" );
140140

141141
activ_func = _activ_func;
@@ -153,11 +153,23 @@ class ANN_MLPImpl : public ANN_MLP
153153
case GAUSSIAN:
154154
max_val = 1.; min_val = 0.05;
155155
max_val1 = 1.; min_val1 = 0.02;
156-
if( fabs(_f_param1) < FLT_EPSILON )
156+
if (fabs(_f_param1) < FLT_EPSILON)
157157
_f_param1 = 1.;
158-
if( fabs(_f_param2) < FLT_EPSILON )
158+
if (fabs(_f_param2) < FLT_EPSILON)
159159
_f_param2 = 1.;
160160
break;
161+
case RELU:
162+
if (fabs(_f_param1) < FLT_EPSILON)
163+
_f_param1 = 1;
164+
min_val = max_val = min_val1 = max_val1 = 0.;
165+
_f_param2 = 0.;
166+
break;
167+
case LEAKYRELU:
168+
if (fabs(_f_param1) < FLT_EPSILON)
169+
_f_param1 = 0.01;
170+
min_val = max_val = min_val1 = max_val1 = 0.;
171+
_f_param2 = 0.;
172+
break;
161173
default:
162174
min_val = max_val = min_val1 = max_val1 = 0.;
163175
_f_param1 = 1.;
@@ -368,140 +380,194 @@ class ANN_MLPImpl : public ANN_MLP
368380
}
369381
}
370382

371-
void calc_activ_func( Mat& sums, const Mat& w ) const
383+
void calc_activ_func(Mat& sums, const Mat& w) const
372384
{
373-
const double* bias = w.ptr<double>(w.rows-1);
385+
const double* bias = w.ptr<double>(w.rows - 1);
374386
int i, j, n = sums.rows, cols = sums.cols;
375387
double scale = 0, scale2 = f_param2;
376388

377-
switch( activ_func )
389+
switch (activ_func)
378390
{
379-
case IDENTITY:
380-
scale = 1.;
381-
break;
382-
case SIGMOID_SYM:
383-
scale = -f_param1;
384-
break;
385-
case GAUSSIAN:
386-
scale = -f_param1*f_param1;
387-
break;
388-
default:
389-
;
391+
case IDENTITY:
392+
scale = 1.;
393+
break;
394+
case SIGMOID_SYM:
395+
scale = -f_param1;
396+
break;
397+
case GAUSSIAN:
398+
scale = -f_param1*f_param1;
399+
break;
400+
case RELU:
401+
scale = 1;
402+
break;
403+
case LEAKYRELU:
404+
scale = 1;
405+
break;
406+
default:
407+
;
390408
}
391409

392-
CV_Assert( sums.isContinuous() );
410+
CV_Assert(sums.isContinuous());
393411

394-
if( activ_func != GAUSSIAN )
412+
if (activ_func != GAUSSIAN)
395413
{
396-
for( i = 0; i < n; i++ )
414+
for (i = 0; i < n; i++)
397415
{
398416
double* data = sums.ptr<double>(i);
399-
for( j = 0; j < cols; j++ )
417+
for (j = 0; j < cols; j++)
418+
{
400419
data[j] = (data[j] + bias[j])*scale;
420+
if (activ_func == RELU)
421+
if (data[j] < 0)
422+
data[j] = 0;
423+
if (activ_func == LEAKYRELU)
424+
if (data[j] < 0)
425+
data[j] *= f_param1;
426+
}
401427
}
402428

403-
if( activ_func == IDENTITY )
429+
if (activ_func == IDENTITY || activ_func == RELU || activ_func == LEAKYRELU)
404430
return;
405431
}
406432
else
407433
{
408-
for( i = 0; i < n; i++ )
434+
for (i = 0; i < n; i++)
409435
{
410436
double* data = sums.ptr<double>(i);
411-
for( j = 0; j < cols; j++ )
437+
for (j = 0; j < cols; j++)
412438
{
413439
double t = data[j] + bias[j];
414440
data[j] = t*t*scale;
415441
}
416442
}
417443
}
418444

419-
exp( sums, sums );
445+
exp(sums, sums);
420446

421-
if( sums.isContinuous() )
447+
if (sums.isContinuous())
422448
{
423449
cols *= n;
424450
n = 1;
425451
}
426452

427-
switch( activ_func )
453+
switch (activ_func)
428454
{
429-
case SIGMOID_SYM:
430-
for( i = 0; i < n; i++ )
455+
case SIGMOID_SYM:
456+
for (i = 0; i < n; i++)
457+
{
458+
double* data = sums.ptr<double>(i);
459+
for (j = 0; j < cols; j++)
431460
{
432-
double* data = sums.ptr<double>(i);
433-
for( j = 0; j < cols; j++ )
461+
if (!cvIsInf(data[j]))
434462
{
435-
if(!cvIsInf(data[j]))
436-
{
437-
double t = scale2*(1. - data[j])/(1. + data[j]);
438-
data[j] = t;
439-
}
440-
else
441-
{
442-
data[j] = -scale2;
443-
}
463+
double t = scale2*(1. - data[j]) / (1. + data[j]);
464+
data[j] = t;
465+
}
466+
else
467+
{
468+
data[j] = -scale2;
444469
}
445470
}
446-
break;
471+
}
472+
break;
447473

448-
case GAUSSIAN:
449-
for( i = 0; i < n; i++ )
450-
{
451-
double* data = sums.ptr<double>(i);
452-
for( j = 0; j < cols; j++ )
453-
data[j] = scale2*data[j];
454-
}
455-
break;
474+
case GAUSSIAN:
475+
for (i = 0; i < n; i++)
476+
{
477+
double* data = sums.ptr<double>(i);
478+
for (j = 0; j < cols; j++)
479+
data[j] = scale2*data[j];
480+
}
481+
break;
456482

457-
default:
458-
;
483+
default:
484+
;
459485
}
460486
}
461487

462-
void calc_activ_func_deriv( Mat& _xf, Mat& _df, const Mat& w ) const
488+
void calc_activ_func_deriv(Mat& _xf, Mat& _df, const Mat& w) const
463489
{
464-
const double* bias = w.ptr<double>(w.rows-1);
490+
const double* bias = w.ptr<double>(w.rows - 1);
465491
int i, j, n = _xf.rows, cols = _xf.cols;
466492

467-
if( activ_func == IDENTITY )
493+
if (activ_func == IDENTITY)
468494
{
469-
for( i = 0; i < n; i++ )
495+
for (i = 0; i < n; i++)
470496
{
471497
double* xf = _xf.ptr<double>(i);
472498
double* df = _df.ptr<double>(i);
473499

474-
for( j = 0; j < cols; j++ )
500+
for (j = 0; j < cols; j++)
475501
{
476502
xf[j] += bias[j];
477503
df[j] = 1;
478504
}
479505
}
480506
}
481-
else if( activ_func == GAUSSIAN )
507+
else if (activ_func == RELU)
508+
{
509+
for (i = 0; i < n; i++)
510+
{
511+
double* xf = _xf.ptr<double>(i);
512+
double* df = _df.ptr<double>(i);
513+
514+
for (j = 0; j < cols; j++)
515+
{
516+
xf[j] += bias[j];
517+
if (xf[j] < 0)
518+
{
519+
xf[j] = 0;
520+
df[j] = 0;
521+
}
522+
else
523+
df[j] = 1;
524+
}
525+
}
526+
}
527+
else if (activ_func == LEAKYRELU)
528+
{
529+
for (i = 0; i < n; i++)
530+
{
531+
double* xf = _xf.ptr<double>(i);
532+
double* df = _df.ptr<double>(i);
533+
534+
for (j = 0; j < cols; j++)
535+
{
536+
xf[j] += bias[j];
537+
if (xf[j] < 0)
538+
{
539+
xf[j] = f_param1*xf[j];
540+
df[j] = f_param1;
541+
}
542+
else
543+
df[j] = 1;
544+
}
545+
}
546+
}
547+
else if (activ_func == GAUSSIAN)
482548
{
483549
double scale = -f_param1*f_param1;
484550
double scale2 = scale*f_param2;
485-
for( i = 0; i < n; i++ )
551+
for (i = 0; i < n; i++)
486552
{
487553
double* xf = _xf.ptr<double>(i);
488554
double* df = _df.ptr<double>(i);
489555

490-
for( j = 0; j < cols; j++ )
556+
for (j = 0; j < cols; j++)
491557
{
492558
double t = xf[j] + bias[j];
493-
df[j] = t*2*scale2;
559+
df[j] = t * 2 * scale2;
494560
xf[j] = t*t*scale;
495561
}
496562
}
497-
exp( _xf, _xf );
563+
exp(_xf, _xf);
498564

499-
for( i = 0; i < n; i++ )
565+
for (i = 0; i < n; i++)
500566
{
501567
double* xf = _xf.ptr<double>(i);
502568
double* df = _df.ptr<double>(i);
503569

504-
for( j = 0; j < cols; j++ )
570+
for (j = 0; j < cols; j++)
505571
df[j] *= xf[j];
506572
}
507573
}
@@ -510,34 +576,34 @@ class ANN_MLPImpl : public ANN_MLP
510576
double scale = f_param1;
511577
double scale2 = f_param2;
512578

513-
for( i = 0; i < n; i++ )
579+
for (i = 0; i < n; i++)
514580
{
515581
double* xf = _xf.ptr<double>(i);
516582
double* df = _df.ptr<double>(i);
517583

518-
for( j = 0; j < cols; j++ )
584+
for (j = 0; j < cols; j++)
519585
{
520586
xf[j] = (xf[j] + bias[j])*scale;
521587
df[j] = -fabs(xf[j]);
522588
}
523589
}
524590

525-
exp( _df, _df );
591+
exp(_df, _df);
526592

527593
// ((1+exp(-ax))^-1)'=a*((1+exp(-ax))^-2)*exp(-ax);
528594
// ((1-exp(-ax))/(1+exp(-ax)))'=(a*exp(-ax)*(1+exp(-ax)) + a*exp(-ax)*(1-exp(-ax)))/(1+exp(-ax))^2=
529595
// 2*a*exp(-ax)/(1+exp(-ax))^2
530-
scale *= 2*f_param2;
531-
for( i = 0; i < n; i++ )
596+
scale *= 2 * f_param2;
597+
for (i = 0; i < n; i++)
532598
{
533599
double* xf = _xf.ptr<double>(i);
534600
double* df = _df.ptr<double>(i);
535601

536-
for( j = 0; j < cols; j++ )
602+
for (j = 0; j < cols; j++)
537603
{
538604
int s0 = xf[j] > 0 ? 1 : -1;
539-
double t0 = 1./(1. + df[j]);
540-
double t1 = scale*df[j]*t0*t0;
605+
double t0 = 1. / (1. + df[j]);
606+
double t1 = scale*df[j] * t0*t0;
541607
t0 *= scale2*(1. - df[j])*s0;
542608
df[j] = t1;
543609
xf[j] = t0;
@@ -1110,7 +1176,9 @@ class ANN_MLPImpl : public ANN_MLP
11101176
{
11111177
const char* activ_func_name = activ_func == IDENTITY ? "IDENTITY" :
11121178
activ_func == SIGMOID_SYM ? "SIGMOID_SYM" :
1113-
activ_func == GAUSSIAN ? "GAUSSIAN" : 0;
1179+
activ_func == GAUSSIAN ? "GAUSSIAN" :
1180+
activ_func == RELU ? "RELU" :
1181+
activ_func == LEAKYRELU ? "LEAKYRELU" : 0;
11141182

11151183
if( activ_func_name )
11161184
fs << "activation_function" << activ_func_name;
@@ -1191,6 +1259,8 @@ class ANN_MLPImpl : public ANN_MLP
11911259
{
11921260
activ_func = activ_func_name == "SIGMOID_SYM" ? SIGMOID_SYM :
11931261
activ_func_name == "IDENTITY" ? IDENTITY :
1262+
activ_func_name == "RELU" ? RELU :
1263+
activ_func_name == "LEAKYRELU" ? LEAKYRELU :
11941264
activ_func_name == "GAUSSIAN" ? GAUSSIAN : -1;
11951265
CV_Assert( activ_func >= 0 );
11961266
}

0 commit comments

Comments
 (0)