|
43 | 43 | #include "../precomp.hpp"
|
44 | 44 | #include <opencv2/dnn/shape_utils.hpp>
|
45 | 45 | #include <opencv2/dnn/all_layers.hpp>
|
46 |
| -#include <iostream> |
| 46 | +#include "nms.inl.hpp" |
47 | 47 | #include "opencl_kernels_dnn.hpp"
|
48 | 48 |
|
49 | 49 | namespace cv
|
@@ -173,8 +173,7 @@ class RegionLayerImpl : public RegionLayer
|
173 | 173 | if (nmsThreshold > 0) {
|
174 | 174 | Mat mat = outBlob.getMat(ACCESS_WRITE);
|
175 | 175 | float *dstData = mat.ptr<float>();
|
176 |
| - do_nms_sort(dstData, rows*cols*anchors, nmsThreshold); |
177 |
| - //do_nms(dstData, rows*cols*anchors, nmsThreshold); |
| 176 | + do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold); |
178 | 177 | }
|
179 | 178 |
|
180 | 179 | }
|
@@ -263,128 +262,48 @@ class RegionLayerImpl : public RegionLayer
|
263 | 262 | }
|
264 | 263 |
|
265 | 264 | if (nmsThreshold > 0) {
|
266 |
| - do_nms_sort(dstData, rows*cols*anchors, nmsThreshold); |
267 |
| - //do_nms(dstData, rows*cols*anchors, nmsThreshold); |
| 265 | + do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold); |
268 | 266 | }
|
269 | 267 |
|
270 | 268 | }
|
271 | 269 | }
|
272 | 270 |
|
273 |
| - |
274 |
| - struct box { |
275 |
| - float x, y, w, h; |
276 |
| - float *probs; |
277 |
| - }; |
278 |
| - |
279 |
| - float overlap(float x1, float w1, float x2, float w2) |
280 |
| - { |
281 |
| - float l1 = x1 - w1 / 2; |
282 |
| - float l2 = x2 - w2 / 2; |
283 |
| - float left = l1 > l2 ? l1 : l2; |
284 |
| - float r1 = x1 + w1 / 2; |
285 |
| - float r2 = x2 + w2 / 2; |
286 |
| - float right = r1 < r2 ? r1 : r2; |
287 |
| - return right - left; |
288 |
| - } |
289 |
| - |
290 |
| - float box_intersection(box a, box b) |
291 |
| - { |
292 |
| - float w = overlap(a.x, a.w, b.x, b.w); |
293 |
| - float h = overlap(a.y, a.h, b.y, b.h); |
294 |
| - if (w < 0 || h < 0) return 0; |
295 |
| - float area = w*h; |
296 |
| - return area; |
297 |
| - } |
298 |
| - |
299 |
| - float box_union(box a, box b) |
| 271 | + static inline float rectOverlap(const Rect2f& a, const Rect2f& b) |
300 | 272 | {
|
301 |
| - float i = box_intersection(a, b); |
302 |
| - float u = a.w*a.h + b.w*b.h - i; |
303 |
| - return u; |
| 273 | + return 1.0f - jaccardDistance(a, b); |
304 | 274 | }
|
305 | 275 |
|
306 |
| - float box_iou(box a, box b) |
| 276 | + void do_nms_sort(float *detections, int total, float score_thresh, float nms_thresh) |
307 | 277 | {
|
308 |
| - return box_intersection(a, b) / box_union(a, b); |
309 |
| - } |
310 |
| - |
311 |
| - struct sortable_bbox { |
312 |
| - int index; |
313 |
| - float *probs; |
314 |
| - }; |
315 |
| - |
316 |
| - struct nms_comparator { |
317 |
| - int k; |
318 |
| - nms_comparator(int _k) : k(_k) {} |
319 |
| - bool operator ()(sortable_bbox v1, sortable_bbox v2) { |
320 |
| - return v2.probs[k] < v1.probs[k]; |
321 |
| - } |
322 |
| - }; |
323 |
| - |
324 |
| - void do_nms_sort(float *detections, int total, float nms_thresh) |
325 |
| - { |
326 |
| - std::vector<box> boxes(total); |
327 |
| - for (int i = 0; i < total; ++i) { |
328 |
| - box &b = boxes[i]; |
329 |
| - int box_index = i * (classes + coords + 1); |
330 |
| - b.x = detections[box_index + 0]; |
331 |
| - b.y = detections[box_index + 1]; |
332 |
| - b.w = detections[box_index + 2]; |
333 |
| - b.h = detections[box_index + 3]; |
334 |
| - int class_index = i * (classes + 5) + 5; |
335 |
| - b.probs = (detections + class_index); |
336 |
| - } |
337 |
| - |
338 |
| - std::vector<sortable_bbox> s(total); |
339 |
| - |
340 |
| - for (int i = 0; i < total; ++i) { |
341 |
| - s[i].index = i; |
342 |
| - int class_index = i * (classes + 5) + 5; |
343 |
| - s[i].probs = (detections + class_index); |
344 |
| - } |
| 278 | + std::vector<Rect2f> boxes(total); |
| 279 | + std::vector<float> scores(total); |
345 | 280 |
|
346 |
| - for (int k = 0; k < classes; ++k) { |
347 |
| - std::stable_sort(s.begin(), s.end(), nms_comparator(k)); |
348 |
| - for (int i = 0; i < total; ++i) { |
349 |
| - if (boxes[s[i].index].probs[k] == 0) continue; |
350 |
| - box a = boxes[s[i].index]; |
351 |
| - for (int j = i + 1; j < total; ++j) { |
352 |
| - box b = boxes[s[j].index]; |
353 |
| - if (box_iou(a, b) > nms_thresh) { |
354 |
| - boxes[s[j].index].probs[k] = 0; |
355 |
| - } |
356 |
| - } |
357 |
| - } |
358 |
| - } |
359 |
| - } |
360 |
| - |
361 |
| - void do_nms(float *detections, int total, float nms_thresh) |
362 |
| - { |
363 |
| - std::vector<box> boxes(total); |
364 |
| - for (int i = 0; i < total; ++i) { |
365 |
| - box &b = boxes[i]; |
| 281 | + for (int i = 0; i < total; ++i) |
| 282 | + { |
| 283 | + Rect2f &b = boxes[i]; |
366 | 284 | int box_index = i * (classes + coords + 1);
|
367 |
| - b.x = detections[box_index + 0]; |
368 |
| - b.y = detections[box_index + 1]; |
369 |
| - b.w = detections[box_index + 2]; |
370 |
| - b.h = detections[box_index + 3]; |
371 |
| - int class_index = i * (classes + 5) + 5; |
372 |
| - b.probs = (detections + class_index); |
| 285 | + b.width = detections[box_index + 2]; |
| 286 | + b.height = detections[box_index + 3]; |
| 287 | + b.x = detections[box_index + 0] - b.width / 2; |
| 288 | + b.y = detections[box_index + 1] - b.height / 2; |
373 | 289 | }
|
374 | 290 |
|
375 |
| - for (int i = 0; i < total; ++i) { |
376 |
| - bool any = false; |
377 |
| - for (int k = 0; k < classes; ++k) any = any || (boxes[i].probs[k] > 0); |
378 |
| - if (!any) { |
379 |
| - continue; |
| 291 | + std::vector<int> indices; |
| 292 | + for (int k = 0; k < classes; ++k) |
| 293 | + { |
| 294 | + for (int i = 0; i < total; ++i) |
| 295 | + { |
| 296 | + int box_index = i * (classes + coords + 1); |
| 297 | + int class_index = box_index + 5; |
| 298 | + scores[i] = detections[class_index + k]; |
| 299 | + detections[class_index + k] = 0; |
380 | 300 | }
|
381 |
| - for (int j = i + 1; j < total; ++j) { |
382 |
| - if (box_iou(boxes[i], boxes[j]) > nms_thresh) { |
383 |
| - for (int k = 0; k < classes; ++k) { |
384 |
| - if (boxes[i].probs[k] < boxes[j].probs[k]) boxes[i].probs[k] = 0; |
385 |
| - else boxes[j].probs[k] = 0; |
386 |
| - } |
387 |
| - } |
| 301 | + NMSFast_(boxes, scores, score_thresh, nms_thresh, 1, 0, indices, rectOverlap); |
| 302 | + for (int i = 0, n = indices.size(); i < n; ++i) |
| 303 | + { |
| 304 | + int box_index = indices[i] * (classes + coords + 1); |
| 305 | + int class_index = box_index + 5; |
| 306 | + detections[class_index + k] = scores[indices[i]]; |
388 | 307 | }
|
389 | 308 | }
|
390 | 309 | }
|
|
0 commit comments