
Commit d943078

identity code initial commit
1 parent 055a62c commit d943078


9 files changed, +887 -1 lines changed


.gitignore

Lines changed: 2 additions & 1 deletion
@@ -136,4 +136,5 @@ dmypy.json
 # files
 **/*.pdf
 **/*.svg*
-
+**/*.png
+**/*.ppm

identity/PyTorch/config.yaml

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
model name: 'identity_net'
batch size: 64
lr: 0.001
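
The code that reads this config is not part of the diff shown here, so the following is only a rough sketch of how it might be loaded, assuming PyYAML; the path and variable names are illustrative. Note that the space-separated keys become ordinary dictionary keys once parsed.

import yaml

# hypothetical loader, not from this commit
with open('identity/PyTorch/config.yaml') as f:
    cfg = yaml.safe_load(f)

model_name = cfg['model name']   # 'identity_net'
batch_size = cfg['batch size']   # 64
lr = cfg['lr']                   # 0.001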

identity/PyTorch/data_pipe_line.py

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
import torch
from model import Identity
from torchvision.utils import save_image
from general import incremental_filename, check_path
import gc
from general import set_logger
from tqdm import tqdm
import numpy as np
logger = set_logger(__name__, mode='a')


# ~~~~~~~~~~~~~~~~~~~~~ image transforms ~~~~~~~~~~~~~~~~~~~~~ #

train_t = transforms.Compose([
    transforms.RandomRotation(15),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.Resize((220, 220)),
    transforms.ToTensor(),
    transforms.Normalize([0.3417, 0.3126, 0.3216],
                         [0.168, 0.1678, 0.178])
])


trans = transforms.Compose([
    transforms.Resize((220, 220)),
    transforms.ToTensor(),
    transforms.Normalize([0.3417, 0.3126, 0.3216],
                         [0.168, 0.1678, 0.178])
])

# ~~~~~~~~~~~~~~~~~~~~~ helper functions ~~~~~~~~~~~~~~~~~~~~~ #


def create_one_hot(n, idx):
    """Return an (idx.shape[0], n) one-hot matrix for the class indices in idx."""
    one_hot = torch.zeros(idx.shape[0], n)
    one_hot.scatter_(1, idx.unsqueeze(1), 1)
    return one_hot
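
# Illustration only (not part of the original commit): for a toy batch of two
# labels the helper above produces
#   create_one_hot(4, torch.tensor([2, 0]))
#   -> tensor([[0., 0., 1., 0.],
#              [1., 0., 0., 0.]])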


class TrafficDataSet(Dataset):

    """Returns triplet images (anchor, positive and negative) along with the
    class, in [ai, pi, ni, cl] format.

    Args:
        dir: dataset directory. This should be readable by a
            torchvision.datasets.ImageFolder instance.
        model: torch.nn.Module instance used for norm-based triplet mining.
    """

    def __init__(self, dir, model):
        super().__init__()
        self.trans = transforms.Compose([
            transforms.Resize((220, 220)),
            transforms.ToTensor(),
            transforms.Normalize([0.3417, 0.3126, 0.3216],
                                 [0.168, 0.1678, 0.178])
        ])
        self.train_t = transforms.Compose([
            transforms.RandomRotation(15),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.Resize((220, 220)),
            transforms.ToTensor(),
            transforms.Normalize([0.3417, 0.3126, 0.3216],
                                 [0.168, 0.1678, 0.178])
        ])

        self.data = ImageFolder(dir, self.trans)
        self.class_len = len(self.data.classes)
        self.loader = DataLoader(self.data, 128, True)
        self.dataset = []
        self.update_dataset_random_choice()

    def update_dataset_using_norm(self, model):
        """Rebuild the triplet list by mining a hard positive and a semi-hard
        negative for every anchor, using the model's embedding distances."""
        dataset = []
        model.eval()
        with torch.no_grad():
            logger.info('started updating dataset')
            for img, label in self.loader:
                train_norm = model(img)
                p_hot = create_one_hot(self.class_len, label)
                n_hot = torch.abs(p_hot - 1)

                for cl in range(len(label)):
                    pi_entries = (label == cl).type(torch.uint8)

                    if pi_entries.sum() > 1:
                        c_p_hot = p_hot[:, cl]
                        c_n_hot = n_hot[:, cl]
                        pi_entries = torch.nonzero(
                            pi_entries, as_tuple=True)

                        for ai in pi_entries[0]:
                            # Euclidean distance from the anchor embedding to
                            # every embedding in the batch
                            norm_matrix = torch.sqrt(
                                torch.sum(
                                    (train_norm[ai]-train_norm)**2, dim=-1)
                            ).squeeze()

                            # hard positive: same-class sample farthest from
                            # the anchor
                            pi_score, pi = torch.max(
                                norm_matrix*c_p_hot, dim=-1)
                            assert torch.numel(
                                pi) == 1, (f'pi size is {torch.numel(pi)}.' +
                                           ' it should be 1')

                            # semi-hard negative: closest different-class
                            # sample that is still farther than the positive;
                            # non-candidates are pushed to +inf so argmin
                            # cannot select them
                            norm_matrix_n = torch.where(
                                (norm_matrix > pi_score) & (c_n_hot > 0),
                                norm_matrix, torch.tensor(float('inf')))

                            ni = torch.argmin(norm_matrix_n, dim=-1)
                            assert torch.numel(
                                ni) == 1, (f'ni size is {torch.numel(ni)}.' +
                                           ' it should be 1')

                            dataset.append([img[ai], img[pi], img[ni], cl])
                            logger.info(
                                f'class:{cl} |anchor:{ai} |positive:{pi}' +
                                f' |negative: {ni}')
        logger.info(f"dataset size:{len(dataset)}")

        self.dataset = dataset
        del dataset
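
    # Illustration only (not from the original commit): a toy walk-through of
    # the selection rule above with made-up distances. Suppose the anchor's
    # distances to a 4-sample batch are [0.0, 0.9, 0.4, 1.3] and the labels
    # are [c, c, x, y], with the anchor at index 0 (class c).
    #   hard positive      -> index 1 (same class, largest distance, 0.9)
    #   semi-hard negative -> index 3 (different class, smallest distance > 0.9)
    # Index 2 is skipped because its distance (0.4) is not larger than the
    # chosen positive's distance.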

    def update_dataset_random_choice(self):
        """Rebuild the triplet list by picking the positive and the negative
        for each anchor at random."""
        dataset = []
        for img, label in tqdm(self.loader):
            for cl in range(len(label)):
                pi_entries = np.array((label == cl).nonzero(as_tuple=True)[0])
                ni_entries = np.array((label != cl).nonzero(as_tuple=True)[0])

                # need at least two samples of this class in the batch to
                # form an anchor/positive pair
                if pi_entries.size > 1:

                    for ai in pi_entries:
                        pi = np.random.choice(pi_entries, 1)[0]

                        assert pi.size == 1, (f'pi size is {pi.size}.' +
                                              ' it should be 1')

                        ni = np.random.choice(ni_entries, 1)[0]

                        assert ni.size == 1, (f'ni size is {ni.size}.' +
                                              ' it should be 1')

                        # logger.info(
                        #     f'class:{cl} |anchor:{ai} |positive:{pi}' +
                        #     f'|negative: {ni}')
                        dataset.append([img[ai], img[pi], img[ni], cl])
        logger.info(f"dataset size:{len(dataset)}")

        self.dataset = dataset
        del dataset

    def __getitem__(self, idx):

        ai = self.dataset[idx][0]
        pi = self.dataset[idx][1]
        ni = self.dataset[idx][2]
        cl = self.dataset[idx][3]
        # ai = self.ai_transforms(ai)
        # pi = self.pi_transforms(pi)
        # ni = self.ni_transforms(ni)
        return [ai, pi, ni, cl]

    def __len__(self):
        return len(self.dataset)


def get_hash_matrix(model, dir, device):
    """Run the model over every image under dir and return the outputs for
    the whole directory as a single tensor."""
    check_path(dir)

    data = ImageFolder(dir, transform=trans)
    loader = DataLoader(data, 128, False)

    outputs = []
    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device)
            outputs.append(model(x))

    # concatenate the per-batch outputs so the full directory is returned,
    # not just the last batch
    return torch.cat(outputs, dim=0)
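
# Example usage (illustration only; the calling code is not part of this diff).
# Assumes a trained Identity model and a directory readable by ImageFolder:
#   model = Identity().to('cuda').eval()
#   hashes = get_hash_matrix(model, '/home/user/datasets/GTSRB', 'cuda')
#   # hashes[i] is the model output for the i-th image in the folder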


if __name__ == "__main__":

    dir = '/home/user/datasets/GTSRB'
    model = Identity()
    data = TrafficDataSet(dir, model)
    logger.info(
        'enter an index for saving the data point\n enter q to quit...\n')

    key = input("index number: ")

    while key != 'q':

        assert key.isnumeric(), f'expecting an integer, got {key!r}'
        key = int(key)
        if key >= 0 and key < len(data):
            img = torch.stack(data[key][:3], dim=0)
            file_name = incremental_filename('identity/data_point', 'entry')
            save_image(img, file_name)
            logger.info(f'file saved as {file_name}\n')
            key = input('next index:')
        else:
            logger.info(f'entered value: {key} is out of range.' +
                        f' current range is 0 to {len(data)-1}')
            key = input('next index:')

    gc.collect()
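
The training loop that consumes these triplets is not among the files shown in this diff, so the sketch below is only an illustration of how TrafficDataSet might feed a triplet loss, assuming the Identity model maps images to embedding vectors. torch.nn.TripletMarginLoss and the Adam settings are placeholder choices (batch size and learning rate taken from config.yaml), not something this commit confirms.

import torch
from torch.utils.data import DataLoader
from model import Identity
from data_pipe_line import TrafficDataSet

model = Identity()
data = TrafficDataSet('/home/user/datasets/GTSRB', model)
loader = DataLoader(data, batch_size=64, shuffle=True)

criterion = torch.nn.TripletMarginLoss(margin=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()
for ai, pi, ni, cl in loader:
    # each element is a batch: anchors, positives, negatives, class ids
    loss = criterion(model(ai), model(pi), model(ni))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()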
