Skip to content

Commit 85b1c40

Browse files
committed
support axis in concat layer ocl path
Signed-off-by: Li Peng <peng.li@intel.com>
1 parent 07bec6b commit 85b1c40

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

modules/dnn/src/layers/concat_layer.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,13 @@ class ConcatLayerImpl : public ConcatLayer
185185
outs.getUMatVector(outputs);
186186

187187
int cAxis = clamp(axis, inputs[0].dims);
188-
if (!(cAxis == 1 && outputs[0].dims == 4 && !padding))
188+
if (padding)
189189
return false;
190190

191191
int bottom_concat_axis;
192-
int concat_size = inputs[0].size[2] * inputs[0].size[3];
193-
int top_concat_axis = outputs[0].size[1];
192+
int concat_size = total(shape(inputs[0]), cAxis + 1);
193+
int top_concat_axis = outputs[0].size[cAxis];
194+
int num_concats = total(shape(inputs[0]), 0, cAxis);
194195
int offset_concat_axis = 0;
195196
UMat& outMat = outputs[0];
196197
String buildopt = String("-DDtype=") + ocl::typeToStr(inputs[0].type()) + String(" ");
@@ -202,12 +203,12 @@ class ConcatLayerImpl : public ConcatLayer
202203
return false;
203204

204205
UMat& inpMat = inputs[i];
205-
bottom_concat_axis = inputs[i].size[1];
206+
bottom_concat_axis = inputs[i].size[cAxis];
206207
size_t nthreads = inputs[i].total();
207208

208209
kernel.set(0, (int)nthreads);
209210
kernel.set(1, ocl::KernelArg::PtrReadOnly(inpMat));
210-
kernel.set(2, (int)inputs[i].size[0]);
211+
kernel.set(2, (int)num_concats);
211212
kernel.set(3, (int)concat_size);
212213
kernel.set(4, (int)top_concat_axis);
213214
kernel.set(5, (int)bottom_concat_axis);

0 commit comments

Comments
 (0)