Commit bd6a669

committed
Merge pull request opencv#10150 from alalek:dnn_ocl_refactor_pooling
2 parents f071a48 + 13f3746

File tree

3 files changed: +107 -116 lines

modules/dnn/src/ocl4dnn/include/ocl4dnn.hpp

Lines changed: 0 additions & 3 deletions

@@ -351,8 +351,6 @@ class OCL4DNNPool
                  UMat& top_data,
                  UMat& top_mask);
     private:
-        UMat mask_idx_;
-
         // Pooling parameters
         std::vector<int32_t> pad_;
         std::vector<int32_t> stride_;
@@ -362,7 +360,6 @@ class OCL4DNNPool
 
         ocl4dnnPoolingMethod_t pool_method_;
         int32_t count_;
-        int32_t batch_size_;
         int32_t channels_;
         int32_t kernel_h_;
         int32_t kernel_w_;

modules/dnn/src/ocl4dnn/src/ocl4dnn_pool.cpp

Lines changed: 63 additions & 63 deletions

@@ -54,7 +54,6 @@ OCL4DNNPool<Dtype>::OCL4DNNPool(OCL4DNNPoolConfig config)
     int dims = config.in_shape.size();
     int spatial_dims = 2;
 
-    batch_size_ = config.in_shape[0];
     channels_ = config.channels;
     pool_method_ = config.pool_method;
 
@@ -88,7 +87,7 @@ OCL4DNNPool<Dtype>::OCL4DNNPool(OCL4DNNPoolConfig config)
 template<typename Dtype>
 OCL4DNNPool<Dtype>::~OCL4DNNPool()
 {
-    mask_idx_.release();
+    // nothing
 }
 
 template<typename Dtype>
@@ -99,99 +98,100 @@ bool OCL4DNNPool<Dtype>::Forward(const UMat& bottom,
     bool ret = true;
     size_t global[] = { 128 * 128 };
     size_t local[] = { 128 };
-    cl_uint argIdx = 0;
 
     // support 2D case
     switch (pool_method_)
     {
     case LIBDNN_POOLING_METHOD_MAX:
         {
-            if (top_mask.empty() && mask_idx_.empty())
-            {
-                mask_idx_.create(1, count_, CV_32FC1);
-            }
-            ocl::Kernel oclk_max_pool_forward(CL_KERNEL_SELECT("max_pool_forward"),
-                                              cv::ocl::dnn::ocl4dnn_pooling_oclsrc);
+            bool haveMask = !top_mask.empty();
+            ocl::Kernel oclk_max_pool_forward(
+                haveMask ? CL_KERNEL_SELECT("max_pool_forward_mask") : CL_KERNEL_SELECT("max_pool_forward"),
+                ocl::dnn::ocl4dnn_pooling_oclsrc,
+                format("-D KERNEL_MAX_POOL=1 -D KERNEL_W=%d -D KERNEL_H=%d"
+                       " -D STRIDE_W=%d -D STRIDE_H=%d"
+                       " -D PAD_W=%d -D PAD_H=%d%s",
+                       kernel_w_, kernel_h_,
+                       stride_w_, stride_h_,
+                       pad_w_, pad_h_,
+                       haveMask ? " -D HAVE_MASK=1" : ""
+                ));
 
             if (oclk_max_pool_forward.empty())
                 return false;
 
-            argIdx = 0;
-            oclk_max_pool_forward.set(argIdx++, count_);
-            oclk_max_pool_forward.set(argIdx++, ocl::KernelArg::PtrReadOnly(bottom));
-            oclk_max_pool_forward.set(argIdx++, batch_size_);
-            oclk_max_pool_forward.set(argIdx++, channels_);
-            oclk_max_pool_forward.set(argIdx++, height_);
-            oclk_max_pool_forward.set(argIdx++, width_);
-            oclk_max_pool_forward.set(argIdx++, pooled_height_);
-            oclk_max_pool_forward.set(argIdx++, pooled_width_);
-            oclk_max_pool_forward.set(argIdx++, kernel_h_);
-            oclk_max_pool_forward.set(argIdx++, kernel_w_);
-            oclk_max_pool_forward.set(argIdx++, stride_h_);
-            oclk_max_pool_forward.set(argIdx++, stride_w_);
-            oclk_max_pool_forward.set(argIdx++, pad_h_);
-            oclk_max_pool_forward.set(argIdx++, pad_w_);
-            oclk_max_pool_forward.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top));
-            oclk_max_pool_forward.set(argIdx++, mask_idx_.empty() ? 0 : 1);
-            if (mask_idx_.empty())
-                oclk_max_pool_forward.set(argIdx++, (void *)NULL);
-            else
-                oclk_max_pool_forward.set(argIdx++, ocl::KernelArg::PtrWriteOnly(mask_idx_));
-            oclk_max_pool_forward.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top_mask));
+            oclk_max_pool_forward.args(
+                count_,
+                ocl::KernelArg::PtrReadOnly(bottom),
+                channels_,
+                height_,
+                width_,
+                pooled_height_,
+                pooled_width_,
+                ocl::KernelArg::PtrWriteOnly(top),
+                ocl::KernelArg::PtrWriteOnly(top_mask)
+            );
 
             ret = oclk_max_pool_forward.run(1, global, local, false);
         }
         break;
     case LIBDNN_POOLING_METHOD_AVE:
         {
+            CV_Assert(top_mask.empty());
+
             ocl::Kernel oclk_ave_pool_forward(CL_KERNEL_SELECT("ave_pool_forward"),
-                                              cv::ocl::dnn::ocl4dnn_pooling_oclsrc);
+                ocl::dnn::ocl4dnn_pooling_oclsrc,
+                format("-D KERNEL_AVE_POOL=1 -D KERNEL_W=%d -D KERNEL_H=%d"
+                       " -D STRIDE_W=%d -D STRIDE_H=%d"
+                       " -D PAD_W=%d -D PAD_H=%d",
+                       kernel_w_, kernel_h_,
+                       stride_w_, stride_h_,
+                       pad_w_, pad_h_
+                ));
 
             if (oclk_ave_pool_forward.empty())
                 return false;
 
-            argIdx = 0;
-            oclk_ave_pool_forward.set(argIdx++, count_);
-            oclk_ave_pool_forward.set(argIdx++, ocl::KernelArg::PtrReadOnly(bottom));
-            oclk_ave_pool_forward.set(argIdx++, batch_size_);
-            oclk_ave_pool_forward.set(argIdx++, channels_);
-            oclk_ave_pool_forward.set(argIdx++, height_);
-            oclk_ave_pool_forward.set(argIdx++, width_);
-            oclk_ave_pool_forward.set(argIdx++, pooled_height_);
-            oclk_ave_pool_forward.set(argIdx++, pooled_width_);
-            oclk_ave_pool_forward.set(argIdx++, kernel_h_);
-            oclk_ave_pool_forward.set(argIdx++, kernel_w_);
-            oclk_ave_pool_forward.set(argIdx++, stride_h_);
-            oclk_ave_pool_forward.set(argIdx++, stride_w_);
-            oclk_ave_pool_forward.set(argIdx++, pad_h_);
-            oclk_ave_pool_forward.set(argIdx++, pad_w_);
-            oclk_ave_pool_forward.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top));
+            oclk_ave_pool_forward.args(
+                count_,
+                ocl::KernelArg::PtrReadOnly(bottom),
+                channels_,
+                height_,
+                width_,
+                pooled_height_,
+                pooled_width_,
+                ocl::KernelArg::PtrWriteOnly(top)
+            );
 
             ret = oclk_ave_pool_forward.run(1, global, local, false);
         }
         break;
     case LIBDNN_POOLING_METHOD_STO:
         {
+            CV_Assert(top_mask.empty());
+
             ocl::Kernel oclk_sto_pool_forward(CL_KERNEL_SELECT("sto_pool_forward_test"),
-                                              cv::ocl::dnn::ocl4dnn_pooling_oclsrc);
+                ocl::dnn::ocl4dnn_pooling_oclsrc,
+                format("-D KERNEL_STO_POOL=1 -D KERNEL_W=%d -D KERNEL_H=%d"
+                       " -D STRIDE_W=%d -D STRIDE_H=%d",
+                       kernel_w_, kernel_h_,
+                       stride_w_, stride_h_
+                ));
+
 
             if (oclk_sto_pool_forward.empty())
                 return false;
 
-            argIdx = 0;
-            oclk_sto_pool_forward.set(argIdx++, count_);
-            oclk_sto_pool_forward.set(argIdx++, ocl::KernelArg::PtrReadOnly(bottom));
-            oclk_sto_pool_forward.set(argIdx++, batch_size_);
-            oclk_sto_pool_forward.set(argIdx++, channels_);
-            oclk_sto_pool_forward.set(argIdx++, height_);
-            oclk_sto_pool_forward.set(argIdx++, width_);
-            oclk_sto_pool_forward.set(argIdx++, pooled_height_);
-            oclk_sto_pool_forward.set(argIdx++, pooled_width_);
-            oclk_sto_pool_forward.set(argIdx++, kernel_h_);
-            oclk_sto_pool_forward.set(argIdx++, kernel_w_);
-            oclk_sto_pool_forward.set(argIdx++, stride_h_);
-            oclk_sto_pool_forward.set(argIdx++, stride_w_);
-            oclk_sto_pool_forward.set(argIdx++, ocl::KernelArg::PtrWriteOnly(top));
+            oclk_sto_pool_forward.args(
+                count_,
+                ocl::KernelArg::PtrReadOnly(bottom),
+                channels_,
+                height_,
+                width_,
+                pooled_height_,
+                pooled_width_,
+                ocl::KernelArg::PtrWriteOnly(top)
+            );
 
             ret = oclk_sto_pool_forward.run(1, global, local, false);
         }
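For context, the host-side refactor above does two things: it moves the fixed pooling geometry into OpenCL build options (so the .cl code sees KERNEL_W, STRIDE_H, PAD_W, ... as compile-time macros) and it replaces the long chain of indexed set(argIdx++, ...) calls with a single variadic args(...) call. A minimal sketch of that same cv::ocl::Kernel pattern follows; the kernel source, function name and parameters in it are illustrative only and are not part of this commit, and it assumes a preallocated CV_32F dst of the same size as src.

    // Sketch only: ocl::Kernel with build options + args(), mirroring the pattern above.
    #include <opencv2/core.hpp>
    #include <opencv2/core/ocl.hpp>

    static bool runScaleKernel(const cv::UMat& src, cv::UMat& dst, float scale)
    {
        // Specialize the program at build time via -D defines,
        // as the pooling code does with KERNEL_*/STRIDE_*/PAD_*.
        cv::String buildOpts = cv::format("-D SCALE=%ff", scale);
        static const char* kernelSrc =
            "__kernel void scale_data(__global const float* src, __global float* dst, int n) {\n"
            "  int i = get_global_id(0);\n"
            "  if (i < n) dst[i] = src[i] * SCALE;\n"
            "}\n";
        cv::ocl::ProgramSource source(kernelSrc);
        cv::ocl::Kernel k("scale_data", source, buildOpts);
        if (k.empty())
            return false;  // no OpenCL device or build failure

        int n = (int)src.total();
        // args() replaces the error-prone sequence of set(argIdx++, ...) calls.
        k.args(cv::ocl::KernelArg::PtrReadOnly(src),
               cv::ocl::KernelArg::PtrWriteOnly(dst),
               n);
        size_t global[] = { (size_t)n };
        return k.run(1, global, NULL, false);  // non-blocking enqueue, like the pooling code
    }

Because the geometry is baked in at build time, a kernel compiled for one pooling configuration is reused only for that configuration; changing kernel size, stride or padding triggers a rebuild with different build options.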

modules/dnn/src/opencl/ocl4dnn_pooling.cl

Lines changed: 44 additions & 50 deletions
@@ -44,14 +44,23 @@
 #define TEMPLATE(name,type) CONCAT(name,type)
 #define Dtype float
 
-void TEMPLATE(max_pool_forward_impl, Dtype)(
-    const int nthreads, __global const Dtype* bottom_data, const int num,
+#if defined KERNEL_MAX_POOL
+
+__kernel void
+#ifdef HAVE_MASK
+TEMPLATE(max_pool_forward_mask, Dtype)
+#else
+TEMPLATE(max_pool_forward, Dtype)
+#endif
+(
+    const int nthreads, __global const Dtype* bottom_data,
     const int channels, const int height, const int width,
-    const int pooled_height, const int pooled_width, const int kernel_h,
-    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
-    const int pad_w,
-    __global Dtype* top_data,
-    const int use_mask, __global int* mask, __global Dtype* top_mask, bool no_mask)
+    const int pooled_height, const int pooled_width,
+    __global Dtype* top_data
+#ifdef HAVE_MASK
+    , __global Dtype* mask
+#endif
+)
 {
   for (int index = get_global_id(0); index < nthreads;
       index += get_global_size(0))
@@ -60,10 +69,10 @@ void TEMPLATE(max_pool_forward_impl, Dtype)(
     const int ph = (index / pooled_width) % pooled_height;
     const int c = (index / pooled_width / pooled_height) % channels;
     const int n = index / pooled_width / pooled_height / channels;
-    int hstart = ph * stride_h - pad_h;
-    int wstart = pw * stride_w - pad_w;
-    const int hend = min(hstart + kernel_h, height);
-    const int wend = min(wstart + kernel_w, width);
+    int hstart = ph * STRIDE_H - PAD_H;
+    int wstart = pw * STRIDE_W - PAD_W;
+    const int hend = min(hstart + KERNEL_H, height);
+    const int wend = min(wstart + KERNEL_W, width);
     hstart = max(hstart, (int)0);
     wstart = max(wstart, (int)0);
     Dtype maxval = -FLT_MAX;
@@ -79,38 +88,19 @@ void TEMPLATE(max_pool_forward_impl, Dtype)(
       }
     }
     top_data[index] = maxval;
-    if (!no_mask) {
-      if (use_mask == 1) {
-        mask[index] = maxidx;
-      } else {
-        top_mask[index] = maxidx;
-      }
-    }
+#ifdef HAVE_MASK
+    mask[index] = maxidx;
+#endif
   }
 }
 
-__kernel void TEMPLATE(max_pool_forward, Dtype)(
-    const int nthreads, __global const Dtype* bottom_data, const int num,
-    const int channels, const int height, const int width,
-    const int pooled_height, const int pooled_width, const int kernel_h,
-    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
-    const int pad_w,
-    __global Dtype* top_data,
-    const int use_mask, __global int* mask, __global Dtype* top_mask)
-{
-  TEMPLATE(max_pool_forward_impl, Dtype)(
-      nthreads, bottom_data, num, channels, height, width,
-      pooled_height, pooled_width, kernel_h,
-      kernel_w, stride_h, stride_w, pad_h, pad_w, top_data, use_mask, mask, top_mask, false
-  );
-}
+#elif defined KERNEL_AVE_POOL
 
 __kernel void TEMPLATE(ave_pool_forward, Dtype)(
-    const int nthreads, __global const Dtype* const bottom_data, const int num,
+    const int nthreads, __global const Dtype* const bottom_data,
     const int channels, const int height, const int width,
-    const int pooled_height, const int pooled_width, const int kernel_h,
-    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
-    const int pad_w, __global Dtype* top_data)
+    const int pooled_height, const int pooled_width,
+    __global Dtype* top_data)
 {
   for (int index = get_global_id(0); index < nthreads;
       index += get_global_size(0))
@@ -120,10 +110,10 @@ __kernel void TEMPLATE(ave_pool_forward, Dtype)(
     const int ph = (index / pooled_width) % pooled_height;
     const int c = (index / pooled_width / pooled_height) % channels;
     const int n = index / pooled_width / pooled_height / channels;
-    int hstart = ph * stride_h - pad_h;
-    int wstart = pw * stride_w - pad_w;
-    int hend = min(hstart + kernel_h, height + pad_h);
-    int wend = min(wstart + kernel_w, width + pad_w);
+    int hstart = ph * STRIDE_H - PAD_H;
+    int wstart = pw * STRIDE_W - PAD_W;
+    int hend = min(hstart + KERNEL_H, height + PAD_H);
+    int wend = min(wstart + KERNEL_W, width + PAD_W);
     const int pool_size = (hend - hstart) * (wend - wstart);
     hstart = max(hstart, (int)0);
     wstart = max(wstart, (int)0);
@@ -142,11 +132,12 @@ __kernel void TEMPLATE(ave_pool_forward, Dtype)(
   }
 }
 
+#elif defined KERNEL_STO_POOL
+
 __kernel void TEMPLATE(sto_pool_forward_test,Dtype)(
-    const int nthreads, __global const Dtype* const bottom_data, const int num,
+    const int nthreads, __global const Dtype* const bottom_data,
     const int channels, const int height, const int width,
-    const int pooled_height, const int pooled_width, const int kernel_h,
-    const int kernel_w, const int stride_h, const int stride_w,
+    const int pooled_height, const int pooled_width,
     __global Dtype* top_data)
 {
   for (int index = get_global_id(0); index < nthreads;
@@ -156,10 +147,10 @@ __kernel void TEMPLATE(sto_pool_forward_test,Dtype)(
     const int ph = (index / pooled_width) % pooled_height;
     const int c = (index / pooled_width / pooled_height) % channels;
     const int n = index / pooled_width / pooled_height / channels;
-    const int hstart = ph * stride_h;
-    const int hend = min(hstart + kernel_h, height);
-    const int wstart = pw * stride_w;
-    const int wend = min(wstart + kernel_w, width);
+    const int hstart = ph * STRIDE_H;
+    const int hend = min(hstart + KERNEL_H, height);
+    const int wstart = pw * STRIDE_W;
+    const int wend = min(wstart + KERNEL_W, width);
     // We set cumsum to be 0 to avoid divide-by-zero problems
     Dtype cumsum = FLT_MIN;
     Dtype cumvalues = 0.;
@@ -168,10 +159,13 @@ __kernel void TEMPLATE(sto_pool_forward_test,Dtype)(
     // First pass: get sum
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
-        cumsum += bottom_slice[h * width + w];
-        cumvalues += bottom_slice[h * width + w] * bottom_slice[h * width + w];
+        Dtype v = bottom_slice[h * width + w];
+        cumsum += v;
+        cumvalues += v * v;
       }
     }
     top_data[index] = cumvalues / cumsum;
   }
 }
+
+#endif // KERNEL_*
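With this change, a single .cl source yields three specialized programs, selected at build time via -D KERNEL_MAX_POOL, -D KERNEL_AVE_POOL or -D KERNEL_STO_POOL, and the window geometry comes from the KERNEL_*/STRIDE_*/PAD_* macros instead of kernel arguments. As a reference for the arithmetic each work item performs, here is a plain C++ sketch (not from this commit; all names are hypothetical) of the same flat-index decomposition and clamped-window max pooling:

    // Reference-only sketch of the per-work-item index math in the kernels above.
    #include <algorithm>
    #include <cfloat>
    #include <vector>

    void maxPoolForwardRef(const std::vector<float>& bottom, std::vector<float>& top,
                           int num, int channels, int height, int width,
                           int pooledH, int pooledW,
                           int kernelH, int kernelW, int strideH, int strideW,
                           int padH, int padW)
    {
        const int count = num * channels * pooledH * pooledW;
        top.assign(count, 0.f);
        for (int index = 0; index < count; ++index)
        {
            // Same decomposition as the kernel: flat index -> (n, c, ph, pw)
            const int pw = index % pooledW;
            const int ph = (index / pooledW) % pooledH;
            const int c  = (index / pooledW / pooledH) % channels;
            const int n  = index / pooledW / pooledH / channels;

            // Pooling window, clamped to the input exactly as in the kernel
            int hstart = ph * strideH - padH;
            int wstart = pw * strideW - padW;
            const int hend = std::min(hstart + kernelH, height);
            const int wend = std::min(wstart + kernelW, width);
            hstart = std::max(hstart, 0);
            wstart = std::max(wstart, 0);

            const float* slice = &bottom[(size_t)(n * channels + c) * height * width];
            float maxval = -FLT_MAX;
            for (int h = hstart; h < hend; ++h)
                for (int w = wstart; w < wend; ++w)
                    maxval = std::max(maxval, slice[h * width + w]);
            top[index] = maxval;
        }
    }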
