@@ -80,12 +80,13 @@ def __getitem__(self, index):
80
80
point_set = point_set / dist
81
81
82
82
83
- choice = np .random .choice (len (seg ), self .npoints , replace = True )
83
+ # choice = np.random.choice(len(seg), self.npoints, replace=True)
84
84
#resample
85
- point_set = point_set [choice , :]
86
- point_set = point_set + 1e-5 * np .random .rand (* point_set .shape )
85
+ #point_set = point_set[choice, :]
86
+ #point_set = point_set + 1e-5 * np.random.rand(*point_set.shape)
87
+ #print(point_set.shape)
87
88
88
- seg = seg [choice ]
89
+ # seg = seg[choice]
89
90
point_set = torch .from_numpy (point_set .astype (np .float32 ))
90
91
seg = torch .from_numpy (seg .astype (np .int64 ))
91
92
cls = torch .from_numpy (np .array ([cls ]).astype (np .int64 ))
@@ -98,14 +99,105 @@ def __len__(self):
98
99
return len (self .datapath )
99
100
100
101
102
class PartDatasetSVM(data.Dataset):
    """Dataset that maps each point-cloud cluster to a fixed-length
    radial-histogram feature vector plus an integer class id, suitable
    for feeding a classical classifier such as an SVM.

    Expects ``root`` to contain ``synsetoffset2category.txt`` with one
    ``<category> <folder>`` pair per line, and one sub-folder of ``.npy``
    files per category.  Each file must be loadable such that indexing
    the loaded array with ``'delta'`` yields an (N, 3) float array of
    point offsets and ``'type_j'`` an (N,) integer label array (e.g. a
    NumPy structured array) -- TODO confirm against the data generator.
    """

    def __init__(self, root, npoints=2048, classification=False,
                 class_choice=None, train=True):
        # npoints and classification are accepted (and stored) only for
        # interface compatibility with PartDataset; the histogram feature
        # extraction below does not use them.
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
        self.classification = classification

        # Map category name -> data sub-folder name.
        self.cat = {}
        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]
        if class_choice is not None:
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}

        # Collect the per-category file lists, split 90/10 into train/test.
        self.meta = {}
        for item in self.cat:
            self.meta[item] = []
            dir_point = os.path.join(self.root, self.cat[item])
            fns = sorted(os.listdir(dir_point))
            split = int(len(fns) * 0.9)
            fns = fns[:split] if train else fns[split:]
            for fn in fns:
                token = os.path.splitext(os.path.basename(fn))[0]
                pth = os.path.join(dir_point, token + '.npy')
                # (points path, labels path): both fields come from the
                # same .npy file, so the pair points at one path twice.
                self.meta[item].append((pth, pth))

        # Flatten to (category, points path, labels path) tuples.
        self.datapath = []
        for item in self.cat:
            for fn in self.meta[item]:
                self.datapath.append((item, fn[0], fn[1]))

        # Category name -> integer class id, in category-insertion order.
        self.classes = dict(zip(self.cat, range(len(self.cat))))
        print(self.classes)
        # No segmentation labels are produced by this dataset variant.
        self.num_seg_classes = 0

    def __getitem__(self, index):
        """Return ``(feat, cls)``: a (60,) float feature vector and the
        integer class id of the sample at ``index``."""
        fn = self.datapath[index]
        cls = self.classes[self.datapath[index][0]]
        cluster_data = np.load(fn[1])
        point_set = cluster_data['delta']   # (N, 3) point offsets
        seg = cluster_data['type_j']        # (N,) per-point type labels

        # Scale so the farthest point lies on the unit sphere.
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        dist = np.expand_dims(np.expand_dims(dist, 0), 1)
        point_set = point_set / dist

        # Radial feature: 30-bin density histogram of squared radii
        # (which lie in [0, 1] after normalization), computed separately
        # for type-1 and type-2 points and concatenated.
        dist = np.sum(point_set ** 2, 1)
        bins = np.arange(0, 1 + 1e-4, 1 / 30.0)
        feat1 = np.histogram(dist[seg == 1], bins, density=True)[0]
        feat2 = np.histogram(dist[seg == 2], bins, density=True)[0]
        feat = np.concatenate([feat1, feat2])

        return feat, cls

    def __len__(self):
        return len(self.datapath)
192
+
101
193
if __name__ == '__main__':
    # Quick smoke test: load one sample from each dataset variant and
    # print its shape/type information.
    print('test')

    dataset = PartDataset(root='mg', classification=True)
    print(len(dataset))
    ps, seg = dataset[0]
    print(ps.size(), ps.type(), seg.size(), seg.type())

    svm_dataset = PartDatasetSVM(root='mg', classification=True)
    print(len(svm_dataset))
    ps, cls = svm_dataset[0]
    print(ps.shape, ps.dtype, cls)
0 commit comments