Skip to content

Commit 0742e12

Browse files
committed
Merge pull request opencv#10265 from dkurt:nms_for_region_layer
2 parents 9b65973 + c67e75b commit 0742e12

File tree

2 files changed

+31
-112
lines changed

2 files changed

+31
-112
lines changed

modules/dnn/src/darknet/darknet_io.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ namespace cv {
482482
}
483483
else if (layer_type == "region")
484484
{
485-
float thresh = 0.001; // in the original Darknet is equal to the detection threshold set by the user
485+
float thresh = getParam<float>(layer_params, "thresh", 0.001);
486486
int coords = getParam<int>(layer_params, "coords", 4);
487487
int classes = getParam<int>(layer_params, "classes", -1);
488488
int num_of_anchors = getParam<int>(layer_params, "num", -1);

modules/dnn/src/layers/region_layer.cpp

Lines changed: 30 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
#include "../precomp.hpp"
4444
#include <opencv2/dnn/shape_utils.hpp>
4545
#include <opencv2/dnn/all_layers.hpp>
46-
#include <iostream>
46+
#include "nms.inl.hpp"
4747
#include "opencl_kernels_dnn.hpp"
4848

4949
namespace cv
@@ -173,8 +173,7 @@ class RegionLayerImpl : public RegionLayer
173173
if (nmsThreshold > 0) {
174174
Mat mat = outBlob.getMat(ACCESS_WRITE);
175175
float *dstData = mat.ptr<float>();
176-
do_nms_sort(dstData, rows*cols*anchors, nmsThreshold);
177-
//do_nms(dstData, rows*cols*anchors, nmsThreshold);
176+
do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold);
178177
}
179178

180179
}
@@ -263,128 +262,48 @@ class RegionLayerImpl : public RegionLayer
263262
}
264263

265264
if (nmsThreshold > 0) {
266-
do_nms_sort(dstData, rows*cols*anchors, nmsThreshold);
267-
//do_nms(dstData, rows*cols*anchors, nmsThreshold);
265+
do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold);
268266
}
269267

270268
}
271269
}
272270

273-
274-
struct box {
275-
float x, y, w, h;
276-
float *probs;
277-
};
278-
279-
float overlap(float x1, float w1, float x2, float w2)
280-
{
281-
float l1 = x1 - w1 / 2;
282-
float l2 = x2 - w2 / 2;
283-
float left = l1 > l2 ? l1 : l2;
284-
float r1 = x1 + w1 / 2;
285-
float r2 = x2 + w2 / 2;
286-
float right = r1 < r2 ? r1 : r2;
287-
return right - left;
288-
}
289-
290-
float box_intersection(box a, box b)
291-
{
292-
float w = overlap(a.x, a.w, b.x, b.w);
293-
float h = overlap(a.y, a.h, b.y, b.h);
294-
if (w < 0 || h < 0) return 0;
295-
float area = w*h;
296-
return area;
297-
}
298-
299-
float box_union(box a, box b)
271+
static inline float rectOverlap(const Rect2f& a, const Rect2f& b)
300272
{
301-
float i = box_intersection(a, b);
302-
float u = a.w*a.h + b.w*b.h - i;
303-
return u;
273+
return 1.0f - jaccardDistance(a, b);
304274
}
305275

306-
float box_iou(box a, box b)
276+
void do_nms_sort(float *detections, int total, float score_thresh, float nms_thresh)
307277
{
308-
return box_intersection(a, b) / box_union(a, b);
309-
}
310-
311-
struct sortable_bbox {
312-
int index;
313-
float *probs;
314-
};
315-
316-
struct nms_comparator {
317-
int k;
318-
nms_comparator(int _k) : k(_k) {}
319-
bool operator ()(sortable_bbox v1, sortable_bbox v2) {
320-
return v2.probs[k] < v1.probs[k];
321-
}
322-
};
323-
324-
void do_nms_sort(float *detections, int total, float nms_thresh)
325-
{
326-
std::vector<box> boxes(total);
327-
for (int i = 0; i < total; ++i) {
328-
box &b = boxes[i];
329-
int box_index = i * (classes + coords + 1);
330-
b.x = detections[box_index + 0];
331-
b.y = detections[box_index + 1];
332-
b.w = detections[box_index + 2];
333-
b.h = detections[box_index + 3];
334-
int class_index = i * (classes + 5) + 5;
335-
b.probs = (detections + class_index);
336-
}
337-
338-
std::vector<sortable_bbox> s(total);
339-
340-
for (int i = 0; i < total; ++i) {
341-
s[i].index = i;
342-
int class_index = i * (classes + 5) + 5;
343-
s[i].probs = (detections + class_index);
344-
}
278+
std::vector<Rect2f> boxes(total);
279+
std::vector<float> scores(total);
345280

346-
for (int k = 0; k < classes; ++k) {
347-
std::stable_sort(s.begin(), s.end(), nms_comparator(k));
348-
for (int i = 0; i < total; ++i) {
349-
if (boxes[s[i].index].probs[k] == 0) continue;
350-
box a = boxes[s[i].index];
351-
for (int j = i + 1; j < total; ++j) {
352-
box b = boxes[s[j].index];
353-
if (box_iou(a, b) > nms_thresh) {
354-
boxes[s[j].index].probs[k] = 0;
355-
}
356-
}
357-
}
358-
}
359-
}
360-
361-
void do_nms(float *detections, int total, float nms_thresh)
362-
{
363-
std::vector<box> boxes(total);
364-
for (int i = 0; i < total; ++i) {
365-
box &b = boxes[i];
281+
for (int i = 0; i < total; ++i)
282+
{
283+
Rect2f &b = boxes[i];
366284
int box_index = i * (classes + coords + 1);
367-
b.x = detections[box_index + 0];
368-
b.y = detections[box_index + 1];
369-
b.w = detections[box_index + 2];
370-
b.h = detections[box_index + 3];
371-
int class_index = i * (classes + 5) + 5;
372-
b.probs = (detections + class_index);
285+
b.width = detections[box_index + 2];
286+
b.height = detections[box_index + 3];
287+
b.x = detections[box_index + 0] - b.width / 2;
288+
b.y = detections[box_index + 1] - b.height / 2;
373289
}
374290

375-
for (int i = 0; i < total; ++i) {
376-
bool any = false;
377-
for (int k = 0; k < classes; ++k) any = any || (boxes[i].probs[k] > 0);
378-
if (!any) {
379-
continue;
291+
std::vector<int> indices;
292+
for (int k = 0; k < classes; ++k)
293+
{
294+
for (int i = 0; i < total; ++i)
295+
{
296+
int box_index = i * (classes + coords + 1);
297+
int class_index = box_index + 5;
298+
scores[i] = detections[class_index + k];
299+
detections[class_index + k] = 0;
380300
}
381-
for (int j = i + 1; j < total; ++j) {
382-
if (box_iou(boxes[i], boxes[j]) > nms_thresh) {
383-
for (int k = 0; k < classes; ++k) {
384-
if (boxes[i].probs[k] < boxes[j].probs[k]) boxes[i].probs[k] = 0;
385-
else boxes[j].probs[k] = 0;
386-
}
387-
}
301+
NMSFast_(boxes, scores, score_thresh, nms_thresh, 1, 0, indices, rectOverlap);
302+
for (int i = 0, n = indices.size(); i < n; ++i)
303+
{
304+
int box_index = indices[i] * (classes + coords + 1);
305+
int class_index = box_index + 5;
306+
detections[class_index + k] = scores[indices[i]];
388307
}
389308
}
390309
}

0 commit comments

Comments
 (0)