Skip to content

Commit 2cd0169

Browse files
coderzbxzhangbenxing
authored andcommitted
增加线上处理脚本
1 parent 641d26a commit 2cd0169

File tree

3 files changed

+302
-6
lines changed

3 files changed

+302
-6
lines changed

SelfScripts/data_clr.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,10 @@
2828
Label("train" , 16, (0, 80, 100)),
2929
Label("motorcycle" , 17, (0, 0, 230)),
3030
Label("bicycle" , 18, (119, 11, 32))
31-
}
31+
}
32+
33+
if __name__ == '__main__':
34+
print("{")
35+
for l in cityscapes_19_label:
36+
print('\'' + str(l.name) + "': " + str(l.color) + ",")
37+
print("}")

SelfScripts/demo.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import argparse
77

88

9-
class ModelSegNetDemo:
9+
class ModelDemo:
1010
def __init__(self, model, weights, colours, gpu_id=3):
1111
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
1212
caffe.set_mode_gpu()
@@ -65,9 +65,9 @@ def do(self, image_data):
6565
colours = ''
6666

6767
parser = argparse.ArgumentParser()
68-
parser.add_argument('--model', type=str, required=False)
69-
parser.add_argument('--weights', type=str, required=False)
70-
parser.add_argument('--colours', type=str, required=False)
68+
parser.add_argument('--model', type=str, required=True)
69+
parser.add_argument('--weights', type=str, required=True)
70+
parser.add_argument('--colours', type=str, required=True)
7171

7272
group = parser.add_mutually_exclusive_group()
7373
group.add_argument('--file', type=str, required=False)
@@ -124,7 +124,7 @@ def do(self, image_data):
124124
if args.gpu:
125125
gpu_id = args.gpu
126126

127-
seg_model = ModelSegNetDemo(model=model, weights=weights, colours=colours, gpu_id=gpu_id)
127+
seg_model = ModelDemo(model=model, weights=weights, colours=colours, gpu_id=gpu_id)
128128

129129
if procDir:
130130
result_dir = os.path.join(file_dir, 'pspnet')

SelfScripts/demo_online.py

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
import caffe
2+
import numpy as np
3+
import cv2
4+
import os
5+
import json
6+
import time
7+
import argparse
8+
import requests
9+
from PIL import Image
10+
from io import BytesIO
11+
import multiprocessing
12+
from multiprocessing import Manager
13+
14+
15+
class DownloadTask:
16+
def __init__(self, track_point_id, exit_flag=False):
17+
self.track_point_id = track_point_id
18+
self.exit_flag = exit_flag
19+
20+
21+
class RecogTask:
22+
def __init__(self, track_point_id, image_data, exit_flag=False):
23+
self.track_point_id = track_point_id
24+
self.image_data = image_data
25+
self.exit_flag = exit_flag
26+
27+
28+
class SaveTask:
29+
def __init__(self, track_point_id, pred_data, exit_flag=False):
30+
self.track_point_id = track_point_id
31+
self.pred_data = pred_data
32+
self.exit_flag = exit_flag
33+
34+
35+
class ModelDemo:
36+
def __init__(self, model, weights, colours, manager, gpu_id=3):
37+
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
38+
caffe.set_mode_gpu()
39+
40+
self.manager = manager
41+
self.download_queue = self.manager.Queue()
42+
self.recog_queue = self.manager.Queue()
43+
self.save_queue = self.manager.Queue()
44+
45+
self.weights = weights
46+
self.model = model
47+
self.colours = colours
48+
49+
self.net = caffe.Net(self.model,
50+
self.weights,
51+
caffe.TEST)
52+
53+
def start_queue(self, url, track_id):
54+
_url = url + "/track/get"
55+
56+
try:
57+
res = requests.post(url=_url, data={'trackId': track_id})
58+
track_info = res.text
59+
60+
track_data = json.loads(track_info)
61+
code = track_data["code"]
62+
63+
if code != "0":
64+
return False
65+
66+
point_data = track_data["result"]["pointList"]
67+
68+
for point in point_data:
69+
track_point_id = point["trackPointId"]
70+
next_task = DownloadTask(track_point_id=track_point_id, exit_flag=False)
71+
self.download_queue.put(next_task)
72+
73+
except Exception as e:
74+
print(e.args[0])
75+
76+
next_task = DownloadTask(track_point_id=None, exit_flag=True)
77+
self.download_queue.put(next_task)
78+
next_task = DownloadTask(track_point_id=None, exit_flag=True)
79+
self.download_queue.put(next_task)
80+
81+
return
82+
83+
def download(self, url):
84+
if self.download_queue.empty():
85+
time.sleep(1)
86+
87+
while True:
88+
task = self.download_queue.get()
89+
if not isinstance(task, DownloadTask):
90+
break
91+
92+
if task.exit_flag:
93+
next_task = RecogTask(track_point_id=None, image_data=None, exit_flag=True)
94+
self.recog_queue.put(next_task)
95+
break
96+
97+
image_data = None
98+
99+
_url = url + "image/get"
100+
data = {
101+
"trackPointId": task.track_point_id,
102+
"type": "00",
103+
"seq": "004",
104+
"imageType": "jpg"
105+
}
106+
try:
107+
res_data = requests.post(url=_url, data=data)
108+
i = Image.open(BytesIO(res_data.content))
109+
output = BytesIO()
110+
i.save(output, format='JPEG')
111+
112+
image_data = output.getvalue()
113+
except Exception as e:
114+
print e.args[0]
115+
116+
next_task = RecogTask(track_point_id=task.track_point_id, image_data=image_data)
117+
self.recog_queue.put(next_task)
118+
119+
return
120+
121+
def recognition(self):
122+
if self.recog_queue.empty():
123+
time.sleep(1)
124+
125+
while True:
126+
task = self.recog_queue.get()
127+
if not isinstance(task, RecogTask):
128+
break
129+
130+
if task.exit_flag:
131+
next_task = SaveTask(track_point_id=None, pred_data=None, exit_flag=True)
132+
self.save_queue.put(next_task)
133+
next_task = SaveTask(track_point_id=None, pred_data=None, exit_flag=True)
134+
self.save_queue.put(next_task)
135+
136+
break
137+
138+
if task.image_data is None:
139+
continue
140+
141+
pred_data = self.do(image_data=task.image_data)
142+
next_task = SaveTask(track_point_id=task.track_point_id, pred_data=pred_data)
143+
self.save_queue.put(next_task)
144+
return
145+
146+
def save(self, dir):
147+
if self.save_queue.empty():
148+
time.sleep(1)
149+
150+
while True:
151+
task = self.save_queue.get()
152+
if not isinstance(task, SaveTask):
153+
break
154+
155+
if task.exit_flag:
156+
break
157+
158+
file_name = task.track_point_id + ".png"
159+
recog_path = os.path.join(dir, file_name)
160+
161+
with open(recog_path, 'wb') as w:
162+
w.write(task.pred_data)
163+
164+
return
165+
166+
def do(self, image_data):
167+
168+
input_shape = self.net.blobs['data'].data.shape
169+
label_colours = cv2.imread(self.colours).astype(np.uint8)
170+
171+
start = time.time()
172+
173+
image = np.asarray(bytearray(image_data), dtype="uint8")
174+
origin_frame = cv2.imdecode(image, cv2.IMREAD_COLOR)
175+
176+
width = origin_frame.shape[1]
177+
height = origin_frame.shape[0]
178+
179+
frame = cv2.resize(origin_frame, (input_shape[3], input_shape[2]))
180+
input_image = frame.transpose((2, 0, 1))
181+
input_image = np.asarray([input_image])
182+
self.net.forward_all(data=input_image)
183+
184+
predict = self.net.blobs['conv6_interp'].data[0, :, :, :]
185+
ind = np.argmax(predict, axis=0)
186+
out_pred = np.resize(ind, (3, input_shape[2], input_shape[3]))
187+
out_pred = out_pred.transpose(1, 2, 0).astype(np.uint8)
188+
out_rgb = np.zeros(out_pred.shape, dtype=np.uint8)
189+
190+
cv2.LUT(out_pred, label_colours, out_rgb)
191+
rgb_frame = cv2.resize(out_rgb, (width, height), interpolation=cv2.INTER_NEAREST)
192+
193+
img_array = cv2.imencode('.png', rgb_frame)
194+
img_data = img_array[1]
195+
pred_data = img_data.tostring()
196+
197+
end = time.time()
198+
print('%30s' % 'Processed results in ', str((end - start) * 1000), 'ms\n')
199+
200+
return pred_data
201+
202+
203+
if __name__ == '__main__':
204+
205+
time1 = time.time()
206+
207+
weights = ''
208+
model = ''
209+
colours = ''
210+
211+
parser = argparse.ArgumentParser()
212+
parser.add_argument('--model', type=str, required=True)
213+
parser.add_argument('--weights', type=str, required=True)
214+
parser.add_argument('--colours', type=str, required=True)
215+
parser.add_argument('--track_id', type=str, required=True)
216+
parser.add_argument('--url', type=str, required=True)
217+
parser.add_argument('--dir', type=str, required=True)
218+
parser.add_argument('--gpu', type=str, required=False)
219+
args = parser.parse_args()
220+
221+
if args.model and args.model != '' and os.path.exists(args.model):
222+
model = args.model
223+
print(model)
224+
225+
if not os.path.exists(model):
226+
print("model file [{}] is not exist\n".format(model))
227+
exit(1)
228+
229+
if args.weights and args.weights != '' and os.path.exists(args.weights):
230+
weights = args.weights
231+
print(weights)
232+
233+
if not os.path.exists(weights):
234+
print("weights file [{}] is not exist\n".format(weights))
235+
exit(1)
236+
237+
if args.colours and args.colours != '' and os.path.exists(args.colours):
238+
colours = args.colours
239+
print(colours)
240+
241+
if not os.path.exists(colours):
242+
print("colours file [{}] is not exist\n".format(colours))
243+
exit(1)
244+
245+
gpu_id = 0
246+
if args.gpu:
247+
gpu_id = args.gpu
248+
249+
save_dir = args.dir
250+
if not os.path.exists(save_dir):
251+
os.makedirs(save_dir)
252+
253+
manager = Manager()
254+
seg_model = ModelDemo(model=model, weights=weights, colours=colours, manager=manager, gpu_id=gpu_id)
255+
256+
seg_model.start_queue(url=args.url, track_id=args.track_id)
257+
seg_model.download(url=args.url)
258+
seg_model.recognition()
259+
seg_model.save(dir=args.dir)
260+
261+
# all_process = []
262+
#
263+
# download_process1 = multiprocessing.Process(target=seg_model.download, args=(args.url,))
264+
# download_process2 = multiprocessing.Process(target=seg_model.download, args=(args.url,))
265+
# all_process.append(download_process1)
266+
# all_process.append(download_process2)
267+
#
268+
# # recognition only one process
269+
# recog_process = multiprocessing.Process(target=seg_model.recognition)
270+
# all_process.append(recog_process)
271+
#
272+
# save_process1 = multiprocessing.Process(target=seg_model.save, args=(args.dir, ))
273+
# save_process2 = multiprocessing.Process(target=seg_model.save, args=(args.dir,))
274+
# all_process.append(save_process1)
275+
# all_process.append(save_process2)
276+
#
277+
# for proc_ in all_process:
278+
# if not isinstance(proc_, multiprocessing.Process):
279+
# break
280+
# proc_.start()
281+
#
282+
# for proc_ in all_process:
283+
# if not isinstance(proc_, multiprocessing.Process):
284+
# break
285+
# proc_.join()
286+
287+
time2 = time.time()
288+
289+
print("finish in {} s\n".format(time2 - time1))
290+

0 commit comments

Comments
 (0)