Commit 146ca4b

add experiments
1 parent 1ff497d commit 146ca4b

3 files changed: +158 -8 lines changed

dataset_metallic_glass.py

Lines changed: 99 additions & 7 deletions
@@ -80,12 +80,13 @@ def __getitem__(self, index):
         point_set = point_set/dist


-        choice = np.random.choice(len(seg), self.npoints, replace=True)
+        #choice = np.random.choice(len(seg), self.npoints, replace=True)
         #resample
-        point_set = point_set[choice, :]
-        point_set = point_set + 1e-5 * np.random.rand(*point_set.shape)
+        #point_set = point_set[choice, :]
+        #point_set = point_set + 1e-5 * np.random.rand(*point_set.shape)
+        #print(point_set.shape)

-        seg = seg[choice]
+        #seg = seg[choice]
         point_set = torch.from_numpy(point_set.astype(np.float32))
         seg = torch.from_numpy(seg.astype(np.int64))
         cls = torch.from_numpy(np.array([cls]).astype(np.int64))
@@ -98,14 +99,105 @@ def __len__(self):
         return len(self.datapath)


+class PartDatasetSVM(data.Dataset):
+    def __init__(self, root, npoints = 2048, classification = False, class_choice = None, train = True):
+        self.npoints = npoints
+        self.root = root
+        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
+        self.cat = {}
+
+        self.classification = classification
+
+        with open(self.catfile, 'r') as f:
+            for line in f:
+                ls = line.strip().split()
+                #print(ls)
+                self.cat[ls[0]] = ls[1]
+        #print(self.cat)
+        if not class_choice is None:
+            self.cat = {k:v for k,v in self.cat.items() if k in class_choice}
+
+        self.meta = {}
+        for item in self.cat:
+            #print('category', item)
+            self.meta[item] = []
+            dir_point = os.path.join(self.root, self.cat[item])
+            #dir_point = os.path.join(self.root, self.cat[item], 'points')
+            #dir_seg = os.path.join(self.root, self.cat[item], 'points_label')
+            #print(dir_point, dir_seg)
+            fns = sorted(os.listdir(dir_point))
+            if train:
+                fns = fns[:int(len(fns) * 0.9)]
+            else:
+                fns = fns[int(len(fns) * 0.9):]
+
+            #print(os.path.basename(fns))
+            for fn in fns:
+                token = (os.path.splitext(os.path.basename(fn))[0])
+                pth = os.path.join(dir_point, token + '.npy')
+                self.meta[item].append((pth, pth))
+
+        self.datapath = []
+        for item in self.cat:
+            for fn in self.meta[item]:
+                self.datapath.append((item, fn[0], fn[1]))
+
+        self.classes = dict(zip(self.cat, range(len(self.cat))))
+        print(self.classes)
+        self.num_seg_classes = 0
+
+
+    def __getitem__(self, index):
+        fn = self.datapath[index]
+        cls = self.classes[self.datapath[index][0]]
+        cluster_data = np.load(fn[1])
+        #print (cluster_data.dtype)
+        point_set = cluster_data['delta']
+        seg = cluster_data['type_j']
+        #print(point_set.shape, seg.shape)
+
+        #point_set = point_set - np.expand_dims(np.mean(point_set, axis = 0), 0)
+        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis = 1)),0)
+        dist = np.expand_dims(np.expand_dims(dist, 0), 1)
+        point_set = point_set/dist
+
+
+
+        dist = np.sum(point_set**2,1)
+        bins = np.arange(0,1 + 1e-4,1/30.0)
+        feat1 = np.histogram(dist[seg == 1], bins, density = True)[0]
+        feat2 = np.histogram(dist[seg == 2], bins, density = True)[0]
+
+        feat = np.concatenate([feat1, feat2])
+
+        #from IPython import embed; embed()
+
+        #choice = np.random.choice(len(seg), self.npoints, replace=True)
+        #resample
+        #point_set = point_set[choice, :]
+        #point_set = point_set + 1e-5 * np.random.rand(*point_set.shape)
+        #print(point_set.shape)
+
+        #seg = seg[choice]
+        #point_set = torch.from_numpy(point_set.astype(np.float32))
+        #seg = torch.from_numpy(seg.astype(np.int64))
+        #cls = torch.from_numpy(np.array([cls]).astype(np.int64))
+
+        return feat, cls
+
+    def __len__(self):
+        return len(self.datapath)
+
+
+
 if __name__ == '__main__':
     print('test')
-    d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', class_choice = ['Chair'])
+    d = PartDataset(root = 'mg', classification = True)
     print(len(d))
     ps, seg = d[0]
     print(ps.size(), ps.type(), seg.size(),seg.type())

-    d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True)
+    d = PartDatasetSVM(root = 'mg', classification = True)
     print(len(d))
     ps, cls = d[0]
-    print(ps.size(), ps.type(), cls.size(),cls.type())
+    print(ps.shape, ps.dtype, cls)
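
Note: the first hunk disables the random resampling and jitter of the point cloud (the choice/resample lines are commented out), and the new PartDatasetSVM.__getitem__ replaces the raw point cloud with a fixed-length handcrafted descriptor: after normalizing each cluster to the unit ball, it histograms the squared radial distances of the two atom types (seg == 1 and seg == 2) into 30 bins each and concatenates them into a 60-dimensional feature vector. A minimal standalone sketch of that computation follows; the synthetic arrays and their shapes are illustrative assumptions, not data from the repository.

import numpy as np

# Illustrative stand-ins for cluster_data['delta'] and cluster_data['type_j'].
point_set = np.random.randn(200, 3)        # relative atom positions
seg = np.random.randint(1, 3, size=200)    # atom type labels, 1 or 2

# Normalize to the unit ball, as in PartDatasetSVM.__getitem__.
point_set = point_set / np.max(np.sqrt(np.sum(point_set ** 2, axis=1)))

# 30-bin histograms of squared radial distance, one per atom type, concatenated.
dist = np.sum(point_set ** 2, axis=1)
bins = np.arange(0, 1 + 1e-4, 1 / 30.0)
feat1 = np.histogram(dist[seg == 1], bins, density=True)[0]
feat2 = np.histogram(dist[seg == 2], bins, density=True)[0]
feat = np.concatenate([feat1, feat2])
print(feat.shape)                          # (60,)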

svm.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+from dataset_metallic_glass import PartDataset, PartDatasetSVM
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import torch.optim as optim
+
+
+d = PartDatasetSVM(root = 'mg', classification = False)
+dt = PartDatasetSVM(root = 'mg', classification = False, train = False)
+
+
+l = len(d)
+print(len(d.classes), l)
+
+
+lt = len(dt)
+print(lt)
+
+train_set = []
+train_label = []
+for i in range(l):
+    idx = i
+    sample, label = d[idx]
+    train_set.append(sample)
+    train_label.append(label)
+    if i%100 == 0:
+        print i
+
+test_set = []
+test_label = []
+for i in range(lt):
+    idx = i
+    sample, label = dt[idx]
+    test_set.append(sample)
+    test_label.append(label)
+    if i%100 == 0:
+        print i
+
+
+train_set = np.array(train_set)
+train_label = np.array(train_label)
+test_set = np.array(test_set)
+test_label = np.array(test_label)
+
+print(train_set.shape, train_label.shape)
+
+from sklearn.svm import SVC
+clf = SVC()
+
+clf.fit(train_set, train_label)
+
+accuracy = np.sum(clf.predict(train_set) == train_label)/float(l)
+print(accuracy)
+
+accuracy_test = np.sum(clf.predict(test_set) == test_label)/float(lt)
+print(accuracy_test)
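
Note: svm.py builds the train and test feature matrices by iterating over PartDatasetSVM and fits scikit-learn's SVC with default parameters. The two progress prints use Python 2 statement syntax (print i), so the script as committed runs only under Python 2; under Python 3 they would need to be print(i). The hand-rolled accuracy arithmetic is equivalent to SVC's built-in mean-accuracy scoring, as in the following sketch, which assumes the same clf, train_set, train_label, test_set, and test_label as above:

print(clf.score(train_set, train_label))   # same value as np.sum(clf.predict(train_set) == train_label) / float(l)
print(clf.score(test_set, test_label))     # same value as np.sum(clf.predict(test_set) == test_label) / float(lt)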

train_MG2.py

Lines changed: 1 addition & 1 deletion
@@ -226,5 +226,5 @@ def split_ps_reuse(point_set, level, pos, tree, cutdim):
         print('batch: %d, loss: %f, correct %d/%d' %( it, np.mean(losses), correct, bt))

         if it % 1000 == 0:
-            torch.save(net.state_dict(), 'mg2_model_cuda_%d.pth' % (it))
+            torch.save(net.state_dict(), 'mg2_model_cuda_new_%d.pth' % (it))
