Skip to content

Commit 62c521d

Browse files
committed
kdnet finished
1 parent 5272b6d commit 62c521d

File tree

2 files changed

+223
-31
lines changed

2 files changed

+223
-31
lines changed

playground.ipynb

Lines changed: 99 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,28 @@
3535
"d = PartDataset(root = '../unsupervised3d/shapenetcore_partanno_segmentation_benchmark_v0', classification = True)"
3636
]
3737
},
38+
{
39+
"cell_type": "code",
40+
"execution_count": 151,
41+
"metadata": {
42+
"collapsed": false
43+
},
44+
"outputs": [
45+
{
46+
"data": {
47+
"text/plain": [
48+
"16"
49+
]
50+
},
51+
"execution_count": 151,
52+
"metadata": {},
53+
"output_type": "execute_result"
54+
}
55+
],
56+
"source": [
57+
"len(d.classes)"
58+
]
59+
},
3860
{
3961
"cell_type": "code",
4062
"execution_count": 3,
@@ -60,7 +82,7 @@
6082
},
6183
{
6284
"cell_type": "code",
63-
"execution_count": 95,
85+
"execution_count": 5,
6486
"metadata": {
6587
"collapsed": false
6688
},
@@ -95,7 +117,7 @@
95117
},
96118
{
97119
"cell_type": "code",
98-
"execution_count": 145,
120+
"execution_count": 34,
99121
"metadata": {
100122
"collapsed": false
101123
},
@@ -104,8 +126,8 @@
104126
"name": "stdout",
105127
"output_type": "stream",
106128
"text": [
107-
"CPU times: user 1.23 s, sys: 12 ms, total: 1.24 s\n",
108-
"Wall time: 1.24 s\n"
129+
"CPU times: user 1.28 s, sys: 0 ns, total: 1.28 s\n",
130+
"Wall time: 1.28 s\n"
109131
]
110132
}
111133
],
@@ -124,7 +146,7 @@
124146
" tree[level+1].append(right_ps)\n",
125147
" cutdim[level].append(dim) \n",
126148
" cutdim[level].append(dim) \n",
127-
" cutdim = [Variable(torch.from_numpy(np.array(item).astype(np.int64))) for item in cutdim]\n",
149+
" cutdim = [(torch.from_numpy(np.array(item).astype(np.int64))) for item in cutdim]\n",
128150
" points = torch.stack(tree[-1])\n",
129151
" \n",
130152
" \n",
@@ -133,33 +155,62 @@
133155
},
134156
{
135157
"cell_type": "code",
136-
"execution_count": 165,
158+
"execution_count": 174,
137159
"metadata": {
138160
"collapsed": false
139161
},
140162
"outputs": [],
141163
"source": [
142164
"class KDNet(nn.Module):\n",
143-
" def __init__(self):\n",
165+
" def __init__(self, k = 16):\n",
144166
" super(KDNet, self).__init__()\n",
145167
" self.conv1 = nn.Conv1d(3,8 * 3,1,1)\n",
146-
" \n",
168+
" self.conv2 = nn.Conv1d(8,32 * 3,1,1)\n",
169+
" self.conv3 = nn.Conv1d(32,64 * 3,1,1)\n",
170+
" self.conv4 = nn.Conv1d(64,64 * 3,1,1)\n",
171+
" self.conv5 = nn.Conv1d(64,64 * 3,1,1)\n",
172+
" self.conv6 = nn.Conv1d(64,128 * 3,1,1)\n",
173+
" self.conv7 = nn.Conv1d(128,256 * 3,1,1)\n",
174+
" self.conv8 = nn.Conv1d(256,512 * 3,1,1)\n",
175+
" self.conv9 = nn.Conv1d(512,512 * 3,1,1)\n",
176+
" self.conv10 = nn.Conv1d(512,512 * 3,1,1)\n",
177+
" self.conv11 = nn.Conv1d(512,1024 * 3,1,1) \n",
178+
" self.fc = nn.Linear(1024, k)\n",
147179
"\n",
148180
" def forward(self, x, c):\n",
149-
" x1 = self.conv1(x)\n",
150-
" #x1 = x1.view(-1, 3, 8, 2048)\n",
151-
" #sel = c[-1]\n",
152-
" \n",
153-
" #x1 = torch.index_select(x1, dim = 1, index = sel)\n",
181+
" def kdconv(x, dim, featdim, sel, conv):\n",
182+
" x = F.relu(conv(x))\n",
183+
" x = x.view(-1, featdim, 3, dim)\n",
184+
" x = x.view(-1, featdim, 3 * dim)\n",
185+
" sel = Variable(sel + (torch.arange(0,dim) * 3).long())\n",
186+
" if x.is_cuda:\n",
187+
" sel = sel.cuda() \n",
188+
" x = torch.index_select(x, dim = 2, index = sel)\n",
189+
" x = x.view(-1, featdim, dim/2, 2)\n",
190+
" x = torch.squeeze(torch.max(x, dim = -1)[0], 3)\n",
191+
" return x \n",
154192
" \n",
155-
" return x1\n",
193+
" x1 = kdconv(x, 2048, 8, c[-1], self.conv1)\n",
194+
" x2 = kdconv(x1, 1024, 32, c[-2], self.conv2)\n",
195+
" x3 = kdconv(x2, 512, 64, c[-3], self.conv3)\n",
196+
" x4 = kdconv(x3, 256, 64, c[-4], self.conv4)\n",
197+
" x5 = kdconv(x4, 128, 64, c[-5], self.conv5)\n",
198+
" x6 = kdconv(x5, 64, 128, c[-6], self.conv6)\n",
199+
" x7 = kdconv(x6, 32, 256, c[-7], self.conv7)\n",
200+
" x8 = kdconv(x7, 16, 512, c[-8], self.conv8)\n",
201+
" x9 = kdconv(x8, 8, 512, c[-9], self.conv9)\n",
202+
" x10 = kdconv(x9, 4, 512, c[-10], self.conv10)\n",
203+
" x11 = kdconv(x10, 2, 1024, c[-11], self.conv11)\n",
204+
" x11 = x11.view(-1,1024)\n",
205+
" out = F.log_softmax(self.fc(x11))\n",
206+
" return out\n",
156207
" \n",
157208
"net = KDNet()"
158209
]
159210
},
160211
{
161212
"cell_type": "code",
162-
"execution_count": 166,
213+
"execution_count": 175,
163214
"metadata": {
164215
"collapsed": false
165216
},
@@ -170,42 +221,59 @@
170221
},
171222
{
172223
"cell_type": "code",
173-
"execution_count": null,
224+
"execution_count": 176,
174225
"metadata": {
175-
"collapsed": true
226+
"collapsed": false
227+
},
228+
"outputs": [
229+
{
230+
"data": {
231+
"text/plain": [
232+
"torch.Size([1, 3, 2048])"
233+
]
234+
},
235+
"execution_count": 176,
236+
"metadata": {},
237+
"output_type": "execute_result"
238+
}
239+
],
240+
"source": [
241+
"points_v.size()"
242+
]
243+
},
244+
{
245+
"cell_type": "code",
246+
"execution_count": 177,
247+
"metadata": {
248+
"collapsed": false
176249
},
177250
"outputs": [],
178-
"source": []
251+
"source": [
252+
"torch.sum(x).backward()"
253+
]
179254
},
180255
{
181256
"cell_type": "code",
182-
"execution_count": 167,
257+
"execution_count": 178,
183258
"metadata": {
184259
"collapsed": false
185260
},
186261
"outputs": [
187262
{
188263
"data": {
189264
"text/plain": [
190-
"Variable containing:\n",
191-
"( 0 ,.,.) = \n",
192-
" -1.0363e+00 -1.0316e+00 -1.0316e+00 ... -1.3211e-02 -1.3213e-02 -1.3207e-02\n",
193-
" 3.9648e-01 3.9941e-01 3.9941e-01 ... -7.5476e-01 -7.5476e-01 -7.5477e-01\n",
194-
" 3.3743e-01 3.2779e-01 3.2779e-01 ... -2.3070e-01 -2.3070e-01 -2.3070e-01\n",
195-
" ... ⋱ ... \n",
196-
" -8.4430e-01 -8.3342e-01 -8.3342e-01 ... -1.9406e-01 -1.9406e-01 -1.9405e-01\n",
197-
" 2.6654e-02 2.4352e-02 2.4351e-02 ... 1.4608e-01 1.4608e-01 1.4608e-01\n",
198-
" -7.8591e-01 -7.9823e-01 -7.9823e-01 ... 2.9584e-02 2.9581e-02 2.9582e-02\n",
199-
"[torch.FloatTensor of size 1x24x2048]"
265+
"\n",
266+
" 0\n",
267+
"[torch.LongTensor of size 1]"
200268
]
201269
},
202-
"execution_count": 167,
270+
"execution_count": 178,
203271
"metadata": {},
204272
"output_type": "execute_result"
205273
}
206274
],
207275
"source": [
208-
"x"
276+
"class_label"
209277
]
210278
},
211279
{

train.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from datasets import PartDataset
2+
import numpy as np
3+
import torch
4+
import torch.nn as nn
5+
import torch.nn.functional as F
6+
from torch.autograd import Variable
7+
import torch.optim as optim
8+
9+
10+
class KDNet(nn.Module):
    """Kd-network classifier for point clouds of exactly 2048 points.

    The network runs eleven levels, starting with 2048 nodes and halving the
    node count at every level.  At each level a 1x1 convolution produces three
    candidate feature vectors per node, one of the three is kept according to
    the matching entry of ``c``, and adjacent pairs of nodes are max-pooled
    into a single node.  The single remaining 1024-d feature is mapped to
    ``k`` class log-probabilities.

    Args:
        k: number of output classes.
    """

    # (number of nodes entering the level, features per node leaving it),
    # ordered from the first level (2048 nodes) to the last (2 nodes).
    _LEVEL_SHAPES = (
        (2048, 8),
        (1024, 32),
        (512, 64),
        (256, 64),
        (128, 64),
        (64, 128),
        (32, 256),
        (16, 512),
        (8, 512),
        (4, 512),
        (2, 1024),
    )

    def __init__(self, k=16):
        super(KDNet, self).__init__()
        in_features = 3
        # Layers keep the names conv1 .. conv11 so existing state dicts load.
        # Every conv emits 3x the level's feature width: one candidate
        # feature vector per selectable axis.
        for level, (_, out_features) in enumerate(self._LEVEL_SHAPES, start=1):
            conv = nn.Conv1d(in_features, out_features * 3, 1, 1)
            setattr(self, "conv%d" % level, conv)
            in_features = out_features
        self.fc = nn.Linear(in_features, k)

    @staticmethod
    def _kdconv(x, dim, featdim, sel, conv):
        """Apply one level: convolve, pick one candidate per node, pool pairs.

        Args:
            x: tensor of shape (batch, in_features, dim).
            dim: number of nodes entering this level.
            featdim: features per node produced by this level.
            sel: integer tensor with ``dim`` entries in {0, 1, 2}; entry i
                chooses which of the three candidates node i keeps.
            conv: the level's ``nn.Conv1d`` with ``featdim * 3`` outputs.

        Returns:
            Tensor of shape (batch, featdim, dim // 2).
        """
        x = F.relu(conv(x))
        # (batch, featdim * 3, dim) -> (batch, featdim, 3 * dim).  In this
        # layout candidate ``a`` of node ``i`` sits at flat index a * dim + i.
        x = x.view(-1, featdim, 3 * dim)
        node_ids = torch.arange(dim, dtype=torch.long, device=sel.device)
        # The index must follow the layout above.  The previous formula
        # ``sel + 3 * node_id`` addressed a (dim, 3) layout that the view
        # never produces, so it mixed features of unrelated nodes.
        index = (sel.long() * dim + node_ids).to(x.device)
        x = torch.index_select(x, dim=2, index=index)
        # Integer division: ``dim / 2`` is a float on Python 3.
        x = x.view(-1, featdim, dim // 2, 2)
        # max() already drops the pooled axis, no squeeze required.
        return torch.max(x, dim=-1)[0]

    def forward(self, x, c):
        """Classify a point cloud.

        Args:
            x: tensor of shape (batch, 3, 2048).
            c: sequence of at least 11 integer tensors.  ``c[-n]`` is used at
                the n-th level and needs one entry in {0, 1, 2} per node of
                that level (2048 entries for ``c[-1]`` down to 2 entries for
                ``c[-11]``).  The same ``c`` is applied to every sample of
                the batch.

        Returns:
            Tensor of shape (batch, k) holding log-probabilities.
        """
        for level, (dim, featdim) in enumerate(self._LEVEL_SHAPES, start=1):
            conv = getattr(self, "conv%d" % level)
            x = self._kdconv(x, dim, featdim, c[-level], conv)
        x = x.view(-1, self.fc.in_features)
        # Explicit dim: relying on the implicit dimension is deprecated.
        return F.log_softmax(self.fc(x), dim=1)
53+
54+
def split_ps(point_set):
    """Split a point set into two equally sized halves along its widest axis.

    The split axis is the coordinate with the largest extent (max - min) and
    the cut value is the median of the points along that axis.  Points strictly
    above the cut form the first half and points strictly below it form the
    second half.  A half that ends up with fewer than ``N // 2`` points is
    padded with copies of the first point lying exactly on the cut.

    Args:
        point_set: 2-D tensor of shape (N, C), one point per row.

    Returns:
        Tuple ``(left_ps, right_ps, dim)``: the points above the cut, the
        points below the cut (each with ``N // 2`` rows) and the split axis
        as a Python int.
    """
    # Integer division: ``/`` gives a float on Python 3.
    num_points = point_set.size(0) // 2
    extent = point_set.max(dim=0)[0] - point_set.min(dim=0)[0]
    dim = int(torch.argmax(extent))
    coords = point_set[:, dim]
    cut = torch.median(coords)

    # view(-1) instead of squeeze(): squeeze() turns a single hit into a
    # 0-d tensor, which cat() and index_select() reject.
    left_idx = torch.nonzero(coords > cut).view(-1)
    right_idx = torch.nonzero(coords < cut).view(-1)
    middle_idx = torch.nonzero(coords == cut).view(-1)

    def pad(idx):
        # Fill a short half with the first point that sits on the cut.
        missing = num_points - idx.numel()
        if missing > 0:
            idx = torch.cat([idx, middle_idx[0:1].repeat(missing)], 0)
        return idx

    left_ps = torch.index_select(point_set, dim=0, index=pad(left_idx))
    right_ps = torch.index_select(point_set, dim=0, index=pad(right_idx))
    return left_ps, right_ps, dim
79+
80+
81+
82+
83+
# Training script: builds a kd-tree for one sample at a time and takes one
# SGD step on the KDNet classifier per sample.

NUM_POINTS = 2048       # points per cloud; KDNet is hard-wired to this size
NUM_ITERATIONS = 1000   # number of single-sample training steps

d = PartDataset(root = '../unsupervised3d/shapenetcore_partanno_segmentation_benchmark_v0', classification = True)

print(len(d.classes))

# Tree depth.  log2 of a power of two is exact, unlike truncating the ratio
# of two natural logarithms.
levels = int(np.log2(NUM_POINTS))

net = KDNet()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# NOTE(review): assumes the dataset holds at least NUM_ITERATIONS samples and
# that every sample has exactly NUM_POINTS points - confirm against PartDataset.
for i in range(NUM_ITERATIONS):

    point_set, class_label = d[i]
    target = class_label

    # tree[level] holds the point subsets at that depth; cutdim[level] holds
    # the split axis of each parent, stored once per child so that it lines
    # up with the nodes of tree[level + 1].
    tree = [[] for _ in range(levels + 1)]
    cutdim = [[] for _ in range(levels)]
    tree[0].append(point_set)
    for level in range(levels):
        for item in tree[level]:
            left_ps, right_ps, dim = split_ps(item)
            tree[level + 1].extend([left_ps, right_ps])
            cutdim[level].extend([dim, dim])
    cutdim = [torch.from_numpy(np.array(item).astype(np.int64)) for item in cutdim]

    # Leaves in tree order, reshaped to (1, coordinates, points) for Conv1d.
    points = torch.stack(tree[-1])
    points_v = torch.unsqueeze(torch.squeeze(points), 0).transpose(2, 1)

    optimizer.zero_grad()
    pred = net(points_v, cutdim)

    loss = F.nll_loss(pred, target)
    loss.backward()
    optimizer.step()

    print(loss)
124+

0 commit comments

Comments
 (0)