Skip to content

Commit 91c1d76

Browse files
committed
Merge pull request opencv#10156 from dkurt:refresh_torch_enet_sample
2 parents 5b27fb5 + 2af6f68 commit 91c1d76

File tree

1 file changed

+47
-81
lines changed

1 file changed

+47
-81
lines changed

samples/dnn/torch_enet.cpp

Lines changed: 47 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,35 @@ const String keys =
2020
"https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa }"
2121
"{model m || path to Torch .net model file (model_best.net) }"
2222
"{image i || path to image file }"
23-
"{c_names c || path to file with classnames for channels (optional, categories.txt) }"
2423
"{result r || path to save output blob (optional, binary format, NCHW order) }"
2524
"{show s || whether to show all output channels or not}"
26-
"{o_blob || output blob's name. If empty, last blob's name in net is used}"
27-
;
25+
"{o_blob || output blob's name. If empty, last blob's name in net is used}";
2826

29-
static void colorizeSegmentation(const Mat &score, Mat &segm,
30-
Mat &legend, vector<String> &classNames, vector<Vec3b> &colors);
31-
static vector<Vec3b> readColors(const String &filename, vector<String>& classNames);
27+
static const int kNumClasses = 20;
28+
29+
static const String classes[] = {
30+
"Background", "Road", "Sidewalk", "Building", "Wall", "Fence", "Pole",
31+
"TrafficLight", "TrafficSign", "Vegetation", "Terrain", "Sky", "Person",
32+
"Rider", "Car", "Truck", "Bus", "Train", "Motorcycle", "Bicycle"
33+
};
34+
35+
static const Vec3b colors[] = {
36+
Vec3b(0, 0, 0), Vec3b(244, 126, 205), Vec3b(254, 83, 132), Vec3b(192, 200, 189),
37+
Vec3b(50, 56, 251), Vec3b(65, 199, 228), Vec3b(240, 178, 193), Vec3b(201, 67, 188),
38+
Vec3b(85, 32, 33), Vec3b(116, 25, 18), Vec3b(162, 33, 72), Vec3b(101, 150, 210),
39+
Vec3b(237, 19, 16), Vec3b(149, 197, 72), Vec3b(80, 182, 21), Vec3b(141, 5, 207),
40+
Vec3b(189, 156, 39), Vec3b(235, 170, 186), Vec3b(133, 109, 144), Vec3b(231, 160, 96)
41+
};
42+
43+
static void showLegend();
44+
45+
static void colorizeSegmentation(const Mat &score, Mat &segm);
3246

3347
int main(int argc, char **argv)
3448
{
3549
CommandLineParser parser(argc, argv, keys);
3650

37-
if (parser.has("help"))
51+
if (parser.has("help") || argc == 1)
3852
{
3953
parser.printMessage();
4054
return 0;
@@ -49,7 +63,6 @@ int main(int argc, char **argv)
4963
return 0;
5064
}
5165

52-
String classNamesFile = parser.get<String>("c_names");
5366
String resultFile = parser.get<String>("result");
5467

5568
//! [Read model and initialize network]
@@ -63,17 +76,11 @@ int main(int argc, char **argv)
6376
exit(-1);
6477
}
6578

66-
Size origSize = img.size();
67-
Size inputImgSize = cv::Size(1024, 512);
68-
69-
if (inputImgSize != origSize)
70-
resize(img, img, inputImgSize); //Resize image to input size
71-
72-
Mat inputBlob = blobFromImage(img, 1./255); //Convert Mat to image batch
79+
Mat inputBlob = blobFromImage(img, 1./255, Size(1024, 512), Scalar(), true, false); //Convert Mat to image batch
7380
//! [Prepare blob]
7481

7582
//! [Set input blob]
76-
net.setInput(inputBlob, ""); //set the network input
83+
net.setInput(inputBlob); //set the network input
7784
//! [Set input blob]
7885

7986
TickMeter tm;
@@ -102,41 +109,47 @@ int main(int argc, char **argv)
102109

103110
if (parser.has("show"))
104111
{
105-
std::vector<String> classNames;
106-
vector<cv::Vec3b> colors;
107-
if(!classNamesFile.empty()) {
108-
colors = readColors(classNamesFile, classNames);
109-
}
110-
Mat segm, legend;
111-
colorizeSegmentation(result, segm, legend, classNames, colors);
112+
Mat segm, show;
113+
colorizeSegmentation(result, segm);
114+
showLegend();
112115

113-
Mat show;
116+
cv::resize(segm, segm, img.size(), 0, 0, cv::INTER_NEAREST);
114117
addWeighted(img, 0.1, segm, 0.9, 0.0, show);
115118

116-
cv::resize(show, show, origSize, 0, 0, cv::INTER_NEAREST);
117119
imshow("Result", show);
118-
if(classNames.size())
119-
imshow("Legend", legend);
120120
waitKey();
121121
}
122-
123122
return 0;
124123
} //main
125124

126-
static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vector<String> &classNames, vector<Vec3b> &colors)
125+
static void showLegend()
126+
{
127+
static const int kBlockHeight = 30;
128+
129+
cv::Mat legend(kBlockHeight * kNumClasses, 200, CV_8UC3);
130+
for(int i = 0; i < kNumClasses; i++)
131+
{
132+
cv::Mat block = legend.rowRange(i * kBlockHeight, (i + 1) * kBlockHeight);
133+
block.setTo(colors[i]);
134+
putText(block, classes[i], Point(0, kBlockHeight / 2), FONT_HERSHEY_SIMPLEX, 0.5, Vec3b(255, 255, 255));
135+
}
136+
imshow("Legend", legend);
137+
}
138+
139+
static void colorizeSegmentation(const Mat &score, Mat &segm)
127140
{
128141
const int rows = score.size[2];
129142
const int cols = score.size[3];
130143
const int chns = score.size[1];
131144

132-
cv::Mat maxCl(rows, cols, CV_8UC1);
133-
cv::Mat maxVal(rows, cols, CV_32FC1);
134-
for (int ch = 0; ch < chns; ch++)
145+
Mat maxCl = Mat::zeros(rows, cols, CV_8UC1);
146+
Mat maxVal(rows, cols, CV_32FC1, score.data);
147+
for (int ch = 1; ch < chns; ch++)
135148
{
136149
for (int row = 0; row < rows; row++)
137150
{
138151
const float *ptrScore = score.ptr<float>(0, ch, row);
139-
uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
152+
uint8_t *ptrMaxCl = maxCl.ptr<uint8_t>(row);
140153
float *ptrMaxVal = maxVal.ptr<float>(row);
141154
for (int col = 0; col < cols; col++)
142155
{
@@ -153,57 +166,10 @@ static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vecto
153166
for (int row = 0; row < rows; row++)
154167
{
155168
const uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
156-
cv::Vec3b *ptrSegm = segm.ptr<cv::Vec3b>(row);
169+
Vec3b *ptrSegm = segm.ptr<Vec3b>(row);
157170
for (int col = 0; col < cols; col++)
158171
{
159172
ptrSegm[col] = colors[ptrMaxCl[col]];
160173
}
161174
}
162-
163-
if (classNames.size() == colors.size())
164-
{
165-
int blockHeight = 30;
166-
legend.create(blockHeight*(int)classNames.size(), 200, CV_8UC3);
167-
for(int i = 0; i < (int)classNames.size(); i++)
168-
{
169-
cv::Mat block = legend.rowRange(i*blockHeight, (i+1)*blockHeight);
170-
block = colors[i];
171-
putText(block, classNames[i], Point(0, blockHeight/2), FONT_HERSHEY_SIMPLEX, 0.5, Scalar());
172-
}
173-
}
174-
}
175-
176-
static vector<Vec3b> readColors(const String &filename, vector<String>& classNames)
177-
{
178-
vector<cv::Vec3b> colors;
179-
classNames.clear();
180-
181-
ifstream fp(filename.c_str());
182-
if (!fp.is_open())
183-
{
184-
cerr << "File with colors not found: " << filename << endl;
185-
exit(-1);
186-
}
187-
188-
string line;
189-
while (!fp.eof())
190-
{
191-
getline(fp, line);
192-
if (line.length())
193-
{
194-
stringstream ss(line);
195-
196-
string name; ss >> name;
197-
int temp;
198-
cv::Vec3b color;
199-
ss >> temp; color[0] = (uchar)temp;
200-
ss >> temp; color[1] = (uchar)temp;
201-
ss >> temp; color[2] = (uchar)temp;
202-
classNames.push_back(name);
203-
colors.push_back(color);
204-
}
205-
}
206-
207-
fp.close();
208-
return colors;
209175
}

0 commit comments

Comments
 (0)