Skip to content

Commit 6e4f943

Browse files
committed
Merge pull request opencv#9998 from alalek:ocl_fix_dnn_softmax_9991
2 parents c624967 + bacc96f commit 6e4f943

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

modules/dnn/src/layers/softmax_layer.cpp

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -141,26 +141,34 @@ class SoftMaxLayerImpl : public SoftmaxLayer
141 141
size_t bufSize = internals[0].total();
142 142
size_t totalSize = src.total();
143 143

144+
// adjust local/global size
145+
size_t internal_localSize[1] = { (bufSize == 1) ? 1 : wgSize };
146+
size_t internal_globalSize[1] = { divUp(bufSize, (unsigned int)internal_localSize[0]) * internal_localSize[0] };
147+
148+
// adjust local/global size (total)
149+
size_t total_localSize[1] = { (totalSize == 1) ? 1 : wgSize };
150+
size_t total_globalSize[1] = { divUp(totalSize, (unsigned int)total_localSize[0]) * total_localSize[0] };
151+
144 152
kmax.args((int)outerSize, (int)channels, (int)innerSize,
145 153
ocl::KernelArg::PtrReadOnly(dstMat), ocl::KernelArg::PtrReadWrite(bufMat));
146-
if (!kmax.run(1, &bufSize, &wgSize, false))
154+
if (!kmax.run(1, internal_globalSize, internal_localSize, false))
147 155
return false;
148 156

149 157
ksub.args((int)totalSize, (int)outerSize, (int)channels, (int)innerSize,
150 158
ocl::KernelArg::PtrReadOnly(bufMat), ocl::KernelArg::PtrReadWrite(dstMat));
151-
if (!ksub.run(1, &totalSize, &wgSize, false))
159+
if (!ksub.run(1, total_globalSize, total_localSize, false))
152 160
return false;
153 161

154 162
cv::exp(dstMat, dstMat);
155 163

156 164
ksum.args((int)outerSize, (int)channels, (int)innerSize,
157 165
ocl::KernelArg::PtrReadOnly(dstMat), ocl::KernelArg::PtrReadWrite(bufMat));
158-
if (!ksum.run(1, &bufSize, &wgSize, false))
166+
if (!ksum.run(1, internal_globalSize, internal_localSize, false))
159 167
return false;
160 168

161 169
kdiv.args((int)totalSize, (int)outerSize, (int)channels, (int)innerSize,
162 170
ocl::KernelArg::PtrReadOnly(bufMat), ocl::KernelArg::PtrReadWrite(dstMat));
163-
if (!kdiv.run(1, &totalSize, &wgSize, false))
171+
if (!kdiv.run(1, total_globalSize, total_localSize, false))
164 172
return false;
165 173

166 174
return true;

0 commit comments

Comments
 (0)