-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdecode.py
70 lines (51 loc) · 2.05 KB
/
decode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
from utils.loss.utils import _gather_feat, _tranpose_and_gather_feat
def _nms(heat, kernel=3):
    """Suppress every heatmap value that is not the maximum of its window.

    Args:
        heat: tensor accepted by ``max_pool2d`` (spatial dims last).
        kernel: side length of the square pooling window.

    Returns:
        A tensor shaped like ``heat`` in which entries equal to the max of
        their ``kernel`` x ``kernel`` neighbourhood are kept and all other
        entries are multiplied by zero.
    """
    # Padding of (kernel - 1) // 2 with stride 1 keeps the pooled map the same
    # size as the input for odd kernels, so it can be compared element-wise.
    half_window = (kernel - 1) // 2
    local_max = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=half_window
    )
    peak_mask = (local_max == heat).float()
    return heat * peak_mask
def _topk_channel(scores, K=40):
    """Return the K highest responses of every channel of a 4-D score map.

    Args:
        scores: tensor of size (batch, cat, height, width).
        K: number of entries to keep per (batch, channel) pair.

    Returns:
        A tuple ``(topk_scores, topk_inds, topk_ys, topk_xs)``, each of size
        (batch, cat, K): the selected scores, their flat index into the
        ``height * width`` plane, and the row / column of that index as
        floats.
    """
    batch, cat, height, width = scores.size()
    # reshape (rather than view) so non-contiguous inputs are accepted too.
    topk_scores, topk_inds = torch.topk(scores.reshape(batch, cat, -1), K)
    # Top-k runs over a flattened plane of length height * width, so the
    # indices are already plane-local; no modulo is needed here.
    # Integer floor division keeps the row index exact; going through a
    # float true division loses precision for very large planes and behaves
    # differently across PyTorch releases.
    topk_ys = (topk_inds // width).float()
    topk_xs = (topk_inds % width).float()
    return topk_scores, topk_inds, topk_ys, topk_xs
def _topk(scores, K=40):
    """Return the K highest responses of a 4-D score map across all channels.

    Args:
        scores: tensor of size (batch, cat, height, width).
        K: number of entries to keep per batch element.

    Returns:
        A tuple ``(topk_scores, topk_inds, topk_ys, topk_xs)``, each of size
        (batch, K): the selected scores, their flat index inside a single
        ``height * width`` plane (the channel offset is removed), and the
        row / column of that index as floats.
    """
    batch, _, height, width = scores.size()
    # reshape (rather than view) so non-contiguous inputs are accepted too.
    topk_scores, topk_inds = torch.topk(scores.reshape(batch, -1), K)
    # The flattened axis spans every channel; the modulo drops the channel
    # offset and leaves a position inside one height * width plane.
    topk_inds = topk_inds % (height * width)
    # Integer floor division keeps the row index exact; going through a
    # float true division loses precision for very large planes and behaves
    # differently across PyTorch releases.
    topk_ys = (topk_inds // width).float()
    topk_xs = (topk_inds % width).float()
    return topk_scores, topk_inds, topk_ys, topk_xs
def mot_decode(heat, reg=None, z=None, K=5):
    """Decode the K strongest heatmap responses into one row per detection.

    Args:
        heat: 4-D heatmap; its first dimension is used as the batch size.
            It is used as given: the sigmoid and NMS steps are disabled.
        reg: optional regression map. After being gathered at the selected
            positions it is viewed as (batch, K, 3): channel 0 is added to
            x, channel 1 to y, and channel 2 is passed through as ``ry``.
            When omitted, x and y are shifted by 0.5 and ``ry`` is zero.
        z: map gathered at the selected positions. Required, despite the
            ``None`` default kept for signature compatibility.
        K: number of responses to decode per batch element.

    Returns:
        The concatenation ``[xs, ys, z, ry, scores]`` along dim 2.

    Raises:
        ValueError: if ``z`` is not provided.
    """
    if z is None:
        raise ValueError("mot_decode requires `z`; got None")

    batch = heat.size(0)

    # heat = torch.sigmoid(heat)
    # perform nms on heatmaps
    # heat = _nms(heat)

    scores, inds, ys, xs = _topk(heat, K=K)
    xs = xs.view(batch, K, 1)
    ys = ys.view(batch, K, 1)
    if reg is not None:
        reg = _tranpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, K, 3)
        xs = xs + reg[:, :, 0:1]
        ys = ys + reg[:, :, 1:2]
        ry = reg[:, :, 2].unsqueeze(dim=2)
    else:
        xs = xs + 0.5
        ys = ys + 0.5
        # Previously `ry` was left unbound on this path, so the concatenation
        # below raised a NameError. A zero column keeps the output layout.
        # NOTE(review): zero is a placeholder value - confirm it is the
        # right default for downstream consumers.
        ry = torch.zeros_like(xs)
    # Presumably yields one row of z values per selected position so it can
    # be concatenated along dim 2 - TODO confirm against the helper.
    z = _tranpose_and_gather_feat(z, inds)
    scores = scores.view(batch, K, 1)
    xy_img_z = torch.cat([xs, ys, z, ry, scores], dim=2)
    return xy_img_z
if __name__ == '__main__':
    # Smoke test on random maps. Arguments are passed by keyword so they line
    # up with mot_decode(heat, reg, z, K); the earlier positional call shifted
    # every argument by one and handed a tensor to K. mot_decode returns a
    # single tensor, so it is bound to one name rather than unpacked.
    hm = torch.rand(1, 1, 36, 56)
    reg = torch.rand(1, 3, 36, 56)
    z = torch.rand(1, 1, 36, 56)
    dets = mot_decode(hm, reg=reg, z=z)
    print(dets.shape)