Skip to content

Commit d021ab5

Browse files
coderzbxzhangbenxing
authored andcommitted
增加自有脚本
1 parent 4b53f1c commit d021ab5

File tree

3 files changed

+460
-0
lines changed

3 files changed

+460
-0
lines changed

SelfScripts/data_clr.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# -*-coding:utf-8-*-
2+
3+
from collections import namedtuple
4+
import argparse
5+
import os
6+
import cv2
7+
8+
Label = namedtuple(
9+
'Label', ['name', 'id', 'color'])
10+
11+
cityscapes_19_label = {
12+
Label("road" , 0, (128, 64, 128)),
13+
Label("sidewalk" , 1, (244, 35, 232)),
14+
Label("building" , 2, (70, 70, 70)),
15+
Label("wall" , 3, (102, 102, 156)),
16+
Label("fence" , 4, (190, 153, 153)),
17+
Label("pole" , 5, (153, 153, 153)),
18+
Label("traffic light" , 6, (250, 170, 30)),
19+
Label("traffic sign" , 7, (220, 220, 0)),
20+
Label("vegetation" , 8, (107, 142, 35)),
21+
Label("terrain" , 9, (152, 251, 152)),
22+
Label("sky" , 10, (70, 130, 180)),
23+
Label("person" , 11, (220, 20, 60)),
24+
Label("rider" , 12, (255, 0, 0)),
25+
Label("car" , 13, (0, 0, 142)),
26+
Label("truck" , 14, (0, 0, 70)),
27+
Label("bus" , 15, (0, 60, 100)),
28+
Label("train" , 16, (0, 80, 100)),
29+
Label("motorcycle" , 17, (0, 0, 230)),
30+
Label("bicycle" , 18, (119, 11, 32))
31+
}

SelfScripts/demo.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import caffe
2+
import numpy as np
3+
import cv2
4+
import os
5+
import time
6+
import argparse
7+
8+
9+
class ModelSegNetDemo:
10+
def __init__(self, model, weights, colours, gpu_id=3):
11+
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
12+
caffe.set_mode_gpu()
13+
14+
self.weights = weights
15+
self.model = model
16+
self.colours = colours
17+
18+
self.net = caffe.Net(self.model,
19+
self.weights,
20+
caffe.TEST)
21+
22+
def do(self, image_data):
23+
24+
input_shape = self.net.blobs['data'].data.shape
25+
label_colours = cv2.imread(self.colours).astype(np.uint8)
26+
27+
start = time.time()
28+
29+
image = np.asarray(bytearray(image_data), dtype="uint8")
30+
origin_frame = cv2.imdecode(image, cv2.IMREAD_COLOR)
31+
32+
width = origin_frame.shape[1]
33+
height = origin_frame.shape[0]
34+
35+
frame = cv2.resize(origin_frame, (input_shape[3], input_shape[2]))
36+
input_image = frame.transpose((2, 0, 1))
37+
input_image = np.asarray([input_image])
38+
self.net.forward_all(data=input_image)
39+
40+
predict = self.net.blobs['conv6_interp'].data
41+
out_pred = np.resize(predict, (3, input_shape[2], input_shape[3]))
42+
out_pred = out_pred.transpose(1, 2, 0).astype(np.uint8)
43+
for j in range(0, 713):
44+
for k in range(0, 713):
45+
x = -1
46+
label = 0
47+
for i in range(0, 19):
48+
if predict[0][i][j][k] > x:
49+
x = predict[0][i][j][k]
50+
label = i
51+
out_pred[j][k][0] = out_pred[j][k][1] = out_pred[j][k][2] = label
52+
out_rgb = np.zeros(out_pred.shape, dtype=np.uint8)
53+
54+
cv2.LUT(out_pred, label_colours, out_rgb)
55+
rgb_frame = cv2.resize(out_rgb, (width, height), interpolation=cv2.INTER_NEAREST)
56+
57+
img_array = cv2.imencode('.png', rgb_frame)
58+
img_data = img_array[1]
59+
pred_data = img_data.tostring()
60+
61+
end = time.time()
62+
print('%30s' % 'Processed results in ', str((end - start) * 1000), 'ms\n')
63+
64+
return pred_data
65+
66+
67+
if __name__ == '__main__':
68+
69+
weights = ''
70+
model = ''
71+
colours = ''
72+
73+
parser = argparse.ArgumentParser()
74+
parser.add_argument('--model', type=str, required=False)
75+
parser.add_argument('--weights', type=str, required=False)
76+
parser.add_argument('--colours', type=str, required=False)
77+
78+
group = parser.add_mutually_exclusive_group()
79+
group.add_argument('--file', type=str, required=False)
80+
group.add_argument('--dir', type=str, required=False)
81+
82+
parser.add_argument('--gpu', type=str, required=False)
83+
args = parser.parse_args()
84+
85+
if args.model and args.model != '' and os.path.exists(args.model):
86+
model = args.model
87+
print(model)
88+
89+
if not os.path.exists(model):
90+
print("model file [{}] is not exist\n".format(model))
91+
exit(1)
92+
93+
if args.weights and args.weights != '' and os.path.exists(args.weights):
94+
weights = args.weights
95+
print(weights)
96+
97+
if not os.path.exists(weights):
98+
print("weights file [{}] is not exist\n".format(weights))
99+
exit(1)
100+
101+
if args.colours and args.colours != '' and os.path.exists(args.colours):
102+
colours = args.colours
103+
print(colours)
104+
105+
if not os.path.exists(colours):
106+
print("colours file [{}] is not exist\n".format(colours))
107+
exit(1)
108+
109+
procFile = False
110+
file_path = ''
111+
if args.file and args.file != '' and os.path.exists(args.file):
112+
procFile = True
113+
file_path = args.file
114+
115+
procDir = False
116+
file_dir = ''
117+
if args.dir and args.dir != '' and os.path.exists(args.dir):
118+
procDir = True
119+
file_dir = args.dir
120+
121+
if procFile and not os.path.exists(file_path):
122+
print("image file [{}] is not exist\n".format(file_path))
123+
exit(1)
124+
125+
if procDir and not os.path.exists(file_dir):
126+
print("image dir [{}] is not exist\n".format(file_dir))
127+
exit(1)
128+
129+
gpu_id = 0
130+
if args.gpu:
131+
gpu_id = args.gpu
132+
133+
seg_model = ModelSegNetDemo(model=model, weights=weights, colours=colours, gpu_id=gpu_id)
134+
135+
if procDir:
136+
result_dir = os.path.join(file_dir, 'pspnet')
137+
if not os.path.exists(result_dir):
138+
os.makedirs(result_dir)
139+
140+
origin_list = os.listdir(file_dir)
141+
142+
for _image in origin_list:
143+
image_path = os.path.join(file_dir, _image)
144+
name_list = _image.split('.')
145+
if (len(name_list) < 2):
146+
print(image_path)
147+
continue
148+
file_name = name_list[0]
149+
ext_name = name_list[1]
150+
if ext_name == 'jpg' or ext_name == 'png':
151+
recog_path = os.path.join(result_dir, file_name + '.png')
152+
with open(image_path, 'rb') as f:
153+
image_data = f.read()
154+
recog_data = seg_model.do(image_data=image_data)
155+
156+
with open(recog_path, 'wb') as w:
157+
w.write(recog_data)
158+
159+
if procFile:
160+
name_list = file_path.split('/')
161+
part_count = len(name_list)
162+
if part_count < 2:
163+
exit(0)
164+
165+
file_name = name_list[part_count - 1]
166+
name_len = len(file_name)
167+
168+
file_dir = file_path[:(-1)*name_len]
169+
result_dir = os.path.join(file_dir, 'results')
170+
if not os.path.exists(result_dir):
171+
os.makedirs(result_dir)
172+
173+
name_list = file_name.split('.')
174+
if (len(name_list) < 2):
175+
print(file_path)
176+
177+
file_name = name_list[0]
178+
ext_name = name_list[1]
179+
if ext_name == 'jpg' or ext_name == 'png':
180+
recog_path = os.path.join(result_dir, file_name + '.png')
181+
with open(file_path, 'rb') as f:
182+
image_data = f.read()
183+
recog_data = seg_model.do(image_data=image_data)
184+
185+
with open(recog_path, 'wb') as w:
186+
w.write(recog_data)

0 commit comments

Comments
 (0)