|
35 | 35 | "d = PartDataset(root = '../unsupervised3d/shapenetcore_partanno_segmentation_benchmark_v0', classification = True)"
|
36 | 36 | ]
|
37 | 37 | },
|
| 38 | + { |
| 39 | + "cell_type": "code", |
| 40 | + "execution_count": 151, |
| 41 | + "metadata": { |
| 42 | + "collapsed": false |
| 43 | + }, |
| 44 | + "outputs": [ |
| 45 | + { |
| 46 | + "data": { |
| 47 | + "text/plain": [ |
| 48 | + "16" |
| 49 | + ] |
| 50 | + }, |
| 51 | + "execution_count": 151, |
| 52 | + "metadata": {}, |
| 53 | + "output_type": "execute_result" |
| 54 | + } |
| 55 | + ], |
| 56 | + "source": [ |
| 57 | + "len(d.classes)" |
| 58 | + ] |
| 59 | + }, |
38 | 60 | {
|
39 | 61 | "cell_type": "code",
|
40 | 62 | "execution_count": 3,
|
|
60 | 82 | },
|
61 | 83 | {
|
62 | 84 | "cell_type": "code",
|
63 |
| - "execution_count": 95, |
| 85 | + "execution_count": 5, |
64 | 86 | "metadata": {
|
65 | 87 | "collapsed": false
|
66 | 88 | },
|
|
95 | 117 | },
|
96 | 118 | {
|
97 | 119 | "cell_type": "code",
|
98 |
| - "execution_count": 145, |
| 120 | + "execution_count": 34, |
99 | 121 | "metadata": {
|
100 | 122 | "collapsed": false
|
101 | 123 | },
|
|
104 | 126 | "name": "stdout",
|
105 | 127 | "output_type": "stream",
|
106 | 128 | "text": [
|
107 |
| - "CPU times: user 1.23 s, sys: 12 ms, total: 1.24 s\n", |
108 |
| - "Wall time: 1.24 s\n" |
| 129 | + "CPU times: user 1.28 s, sys: 0 ns, total: 1.28 s\n", |
| 130 | + "Wall time: 1.28 s\n" |
109 | 131 | ]
|
110 | 132 | }
|
111 | 133 | ],
|
|
124 | 146 | " tree[level+1].append(right_ps)\n",
|
125 | 147 | " cutdim[level].append(dim) \n",
|
126 | 148 | " cutdim[level].append(dim) \n",
|
127 |
| - " cutdim = [Variable(torch.from_numpy(np.array(item).astype(np.int64))) for item in cutdim]\n", |
| 149 | + " cutdim = [(torch.from_numpy(np.array(item).astype(np.int64))) for item in cutdim]\n", |
128 | 150 | " points = torch.stack(tree[-1])\n",
|
129 | 151 | " \n",
|
130 | 152 | " \n",
|
|
133 | 155 | },
|
134 | 156 | {
|
135 | 157 | "cell_type": "code",
|
136 |
| - "execution_count": 165, |
| 158 | + "execution_count": 174, |
137 | 159 | "metadata": {
|
138 | 160 | "collapsed": false
|
139 | 161 | },
|
140 | 162 | "outputs": [],
|
141 | 163 | "source": [
|
142 | 164 | "class KDNet(nn.Module):\n",
|
143 |
| - " def __init__(self):\n", |
| 165 | + " def __init__(self, k = 16):\n", |
144 | 166 | " super(KDNet, self).__init__()\n",
|
145 | 167 | " self.conv1 = nn.Conv1d(3,8 * 3,1,1)\n",
|
146 |
| - " \n", |
| 168 | + " self.conv2 = nn.Conv1d(8,32 * 3,1,1)\n", |
| 169 | + " self.conv3 = nn.Conv1d(32,64 * 3,1,1)\n", |
| 170 | + " self.conv4 = nn.Conv1d(64,64 * 3,1,1)\n", |
| 171 | + " self.conv5 = nn.Conv1d(64,64 * 3,1,1)\n", |
| 172 | + " self.conv6 = nn.Conv1d(64,128 * 3,1,1)\n", |
| 173 | + " self.conv7 = nn.Conv1d(128,256 * 3,1,1)\n", |
| 174 | + " self.conv8 = nn.Conv1d(256,512 * 3,1,1)\n", |
| 175 | + " self.conv9 = nn.Conv1d(512,512 * 3,1,1)\n", |
| 176 | + " self.conv10 = nn.Conv1d(512,512 * 3,1,1)\n", |
| 177 | + " self.conv11 = nn.Conv1d(512,1024 * 3,1,1) \n", |
| 178 | + " self.fc = nn.Linear(1024, k)\n", |
147 | 179 | "\n",
|
148 | 180 | " def forward(self, x, c):\n",
|
149 |
| - " x1 = self.conv1(x)\n", |
150 |
| - " #x1 = x1.view(-1, 3, 8, 2048)\n", |
151 |
| - " #sel = c[-1]\n", |
152 |
| - " \n", |
153 |
| - " #x1 = torch.index_select(x1, dim = 1, index = sel)\n", |
| 181 | + " def kdconv(x, dim, featdim, sel, conv):\n", |
| 182 | + " x = F.relu(conv(x))\n", |
| 183 | + " x = x.view(-1, featdim, 3, dim)\n", |
| 184 | + " x = x.view(-1, featdim, 3 * dim)\n", |
| 185 | + " sel = Variable(sel + (torch.arange(0,dim) * 3).long())\n", |
| 186 | + " if x.is_cuda:\n", |
| 187 | + " sel = sel.cuda() \n", |
| 188 | + " x = torch.index_select(x, dim = 2, index = sel)\n", |
| 189 | + " x = x.view(-1, featdim, dim/2, 2)\n", |
| 190 | + " x = torch.squeeze(torch.max(x, dim = -1)[0], 3)\n", |
| 191 | + " return x \n", |
154 | 192 | " \n",
|
155 |
| - " return x1\n", |
| 193 | + " x1 = kdconv(x, 2048, 8, c[-1], self.conv1)\n", |
| 194 | + " x2 = kdconv(x1, 1024, 32, c[-2], self.conv2)\n", |
| 195 | + " x3 = kdconv(x2, 512, 64, c[-3], self.conv3)\n", |
| 196 | + " x4 = kdconv(x3, 256, 64, c[-4], self.conv4)\n", |
| 197 | + " x5 = kdconv(x4, 128, 64, c[-5], self.conv5)\n", |
| 198 | + " x6 = kdconv(x5, 64, 128, c[-6], self.conv6)\n", |
| 199 | + " x7 = kdconv(x6, 32, 256, c[-7], self.conv7)\n", |
| 200 | + " x8 = kdconv(x7, 16, 512, c[-8], self.conv8)\n", |
| 201 | + " x9 = kdconv(x8, 8, 512, c[-9], self.conv9)\n", |
| 202 | + " x10 = kdconv(x9, 4, 512, c[-10], self.conv10)\n", |
| 203 | + " x11 = kdconv(x10, 2, 1024, c[-11], self.conv11)\n", |
| 204 | + " x11 = x11.view(-1,1024)\n", |
| 205 | + " out = F.log_softmax(self.fc(x11))\n", |
| 206 | + " return out\n", |
156 | 207 | " \n",
|
157 | 208 | "net = KDNet()"
|
158 | 209 | ]
|
159 | 210 | },
|
160 | 211 | {
|
161 | 212 | "cell_type": "code",
|
162 |
| - "execution_count": 166, |
| 213 | + "execution_count": 175, |
163 | 214 | "metadata": {
|
164 | 215 | "collapsed": false
|
165 | 216 | },
|
|
170 | 221 | },
|
171 | 222 | {
|
172 | 223 | "cell_type": "code",
|
173 |
| - "execution_count": null, |
| 224 | + "execution_count": 176, |
174 | 225 | "metadata": {
|
175 |
| - "collapsed": true |
| 226 | + "collapsed": false |
| 227 | + }, |
| 228 | + "outputs": [ |
| 229 | + { |
| 230 | + "data": { |
| 231 | + "text/plain": [ |
| 232 | + "torch.Size([1, 3, 2048])" |
| 233 | + ] |
| 234 | + }, |
| 235 | + "execution_count": 176, |
| 236 | + "metadata": {}, |
| 237 | + "output_type": "execute_result" |
| 238 | + } |
| 239 | + ], |
| 240 | + "source": [ |
| 241 | + "points_v.size()" |
| 242 | + ] |
| 243 | + }, |
| 244 | + { |
| 245 | + "cell_type": "code", |
| 246 | + "execution_count": 177, |
| 247 | + "metadata": { |
| 248 | + "collapsed": false |
176 | 249 | },
|
177 | 250 | "outputs": [],
|
178 |
| - "source": [] |
| 251 | + "source": [ |
| 252 | + "torch.sum(x).backward()" |
| 253 | + ] |
179 | 254 | },
|
180 | 255 | {
|
181 | 256 | "cell_type": "code",
|
182 |
| - "execution_count": 167, |
| 257 | + "execution_count": 178, |
183 | 258 | "metadata": {
|
184 | 259 | "collapsed": false
|
185 | 260 | },
|
186 | 261 | "outputs": [
|
187 | 262 | {
|
188 | 263 | "data": {
|
189 | 264 | "text/plain": [
|
190 |
| - "Variable containing:\n", |
191 |
| - "( 0 ,.,.) = \n", |
192 |
| - " -1.0363e+00 -1.0316e+00 -1.0316e+00 ... -1.3211e-02 -1.3213e-02 -1.3207e-02\n", |
193 |
| - " 3.9648e-01 3.9941e-01 3.9941e-01 ... -7.5476e-01 -7.5476e-01 -7.5477e-01\n", |
194 |
| - " 3.3743e-01 3.2779e-01 3.2779e-01 ... -2.3070e-01 -2.3070e-01 -2.3070e-01\n", |
195 |
| - " ... ⋱ ... \n", |
196 |
| - " -8.4430e-01 -8.3342e-01 -8.3342e-01 ... -1.9406e-01 -1.9406e-01 -1.9405e-01\n", |
197 |
| - " 2.6654e-02 2.4352e-02 2.4351e-02 ... 1.4608e-01 1.4608e-01 1.4608e-01\n", |
198 |
| - " -7.8591e-01 -7.9823e-01 -7.9823e-01 ... 2.9584e-02 2.9581e-02 2.9582e-02\n", |
199 |
| - "[torch.FloatTensor of size 1x24x2048]" |
| 265 | + "\n", |
| 266 | + " 0\n", |
| 267 | + "[torch.LongTensor of size 1]" |
200 | 268 | ]
|
201 | 269 | },
|
202 |
| - "execution_count": 167, |
| 270 | + "execution_count": 178, |
203 | 271 | "metadata": {},
|
204 | 272 | "output_type": "execute_result"
|
205 | 273 | }
|
206 | 274 | ],
|
207 | 275 | "source": [
|
208 |
| - "x" |
| 276 | + "class_label" |
209 | 277 | ]
|
210 | 278 | },
|
211 | 279 | {
|
|
0 commit comments