@@ -56,6 +56,7 @@ class ConcatLayerImpl : public ConcatLayer
56
56
{
57
57
setParamsFrom (params);
58
58
axis = params.get <int >(" axis" , 1 );
59
+ padding = params.get <bool >(" padding" , false );
59
60
}
60
61
61
62
virtual bool getMemoryShapes (const std::vector<MatShape> &inputs,
@@ -64,34 +65,41 @@ class ConcatLayerImpl : public ConcatLayer
64
65
std::vector<MatShape> &internals) const
65
66
{
66
67
CV_Assert (inputs.size () > 0 );
67
- outputs.clear ();
68
- outputs.push_back (inputs[0 ]);
68
+ outputs.resize (1 , inputs[0 ]);
69
69
int cAxis = clamp (axis, inputs[0 ]);
70
70
71
71
int axisSum = 0 ;
72
72
for (size_t i = 0 ; i < inputs.size (); i++)
73
73
{
74
74
MatShape curShape = inputs[i];
75
75
76
- CV_Assert (curShape.size () == outputs.back ().size ());
77
- for (int curAxis = 0 ; curAxis < outputs.back ().size (); curAxis++)
76
+ if (padding)
78
77
{
79
- if (curAxis != cAxis && outputs.back ()[curAxis] != curShape[curAxis])
80
- CV_Error (Error::StsBadSize, " Inconsitent shape for ConcatLayer" );
78
+ for (int curAxis = 0 ; curAxis < outputs[0 ].size (); curAxis++)
79
+ {
80
+ outputs[0 ][curAxis] = std::max (outputs[0 ][curAxis], curShape[curAxis]);
81
+ }
82
+ }
83
+ else
84
+ {
85
+ CV_Assert (curShape.size () == outputs[0 ].size ());
86
+ for (int curAxis = 0 ; curAxis < outputs[0 ].size (); curAxis++)
87
+ {
88
+ if (curAxis != cAxis && outputs[0 ][curAxis] != curShape[curAxis])
89
+ CV_Error (Error::StsBadSize, " Inconsitent shape for ConcatLayer" );
90
+ }
81
91
}
82
92
83
93
axisSum += curShape[cAxis];
84
94
}
85
-
86
- outputs.back ()[cAxis] = axisSum;
87
-
95
+ outputs[0 ][cAxis] = axisSum;
88
96
return false ;
89
97
}
90
98
91
99
virtual bool supportBackend (int backendId)
92
100
{
93
101
return backendId == DNN_BACKEND_DEFAULT ||
94
- backendId == DNN_BACKEND_HALIDE && haveHalide () && axis == 1 ; // By channels
102
+ backendId == DNN_BACKEND_HALIDE && haveHalide () && axis == 1 && !padding ; // By channels
95
103
}
96
104
97
105
class ChannelConcatInvoker : public ParallelLoopBody
@@ -174,7 +182,10 @@ class ConcatLayerImpl : public ConcatLayer
174
182
int cAxis = clamp (axis, inputs[0 ]->dims );
175
183
Mat& outMat = outputs[0 ];
176
184
177
- if ( cAxis == 1 && outMat.dims == 4 )
185
+ if (padding)
186
+ outMat.setTo (0 );
187
+
188
+ if ( cAxis == 1 && outMat.dims == 4 && !padding)
178
189
{
179
190
int nstripes = getNumThreads ();
180
191
ChannelConcatInvoker::run (inputs, outMat, nstripes);
@@ -187,6 +198,12 @@ class ConcatLayerImpl : public ConcatLayer
187
198
for (size_t i = 0 ; i < inputs.size (); i++)
188
199
{
189
200
ranges[cAxis].end = ranges[cAxis].start + inputs[i]->size [cAxis];
201
+ for (int j = 0 ; j < outMat.dims ; ++j)
202
+ {
203
+ if (j == cAxis) continue ;
204
+ ranges[j].start = (outMat.size [j] - inputs[i]->size [j]) / 2 ;
205
+ ranges[j].end = ranges[j].start + inputs[i]->size [j];
206
+ }
190
207
inputs[i]->copyTo (outMat (&ranges[0 ]));
191
208
ranges[cAxis].start = ranges[cAxis].end ;
192
209
}
0 commit comments