@@ -185,12 +185,13 @@ class ConcatLayerImpl : public ConcatLayer
185
185
outs.getUMatVector (outputs);
186
186
187
187
int cAxis = clamp (axis, inputs[0 ].dims );
188
- if (!(cAxis == 1 && outputs[ 0 ]. dims == 4 && ! padding) )
188
+ if (padding)
189
189
return false ;
190
190
191
191
int bottom_concat_axis;
192
- int concat_size = inputs[0 ].size [2 ] * inputs[0 ].size [3 ];
193
- int top_concat_axis = outputs[0 ].size [1 ];
192
+ int concat_size = total (shape (inputs[0 ]), cAxis + 1 );
193
+ int top_concat_axis = outputs[0 ].size [cAxis];
194
+ int num_concats = total (shape (inputs[0 ]), 0 , cAxis);
194
195
int offset_concat_axis = 0 ;
195
196
UMat& outMat = outputs[0 ];
196
197
String buildopt = String (" -DDtype=" ) + ocl::typeToStr (inputs[0 ].type ()) + String (" " );
@@ -202,12 +203,12 @@ class ConcatLayerImpl : public ConcatLayer
202
203
return false ;
203
204
204
205
UMat& inpMat = inputs[i];
205
- bottom_concat_axis = inputs[i].size [1 ];
206
+ bottom_concat_axis = inputs[i].size [cAxis ];
206
207
size_t nthreads = inputs[i].total ();
207
208
208
209
kernel.set (0 , (int )nthreads);
209
210
kernel.set (1 , ocl::KernelArg::PtrReadOnly (inpMat));
210
- kernel.set (2 , (int )inputs[i]. size [ 0 ] );
211
+ kernel.set (2 , (int )num_concats );
211
212
kernel.set (3 , (int )concat_size);
212
213
kernel.set (4 , (int )top_concat_axis);
213
214
kernel.set (5 , (int )bottom_concat_axis);
0 commit comments