Skip to content

Commit 3358b89

Browse files
committed
Merge pull request opencv#9591 from dkurt:feature_dnn_caffe_importer_fp16
2 parents 73298ea + 8646d5f commit 3358b89

File tree

7 files changed

+935
-461
lines changed

7 files changed

+935
-461
lines changed

modules/dnn/include/opencv2/dnn/dnn.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,19 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
701701
CV_EXPORTS_W Mat blobFromImages(const std::vector<Mat>& images, double scalefactor=1.0,
702702
Size size = Size(), const Scalar& mean = Scalar(), bool swapRB=true);
703703

704+
/** @brief Convert all weights of Caffe network to half precision floating point.
705+
* @param src Path to origin model from Caffe framework contains single
706+
* precision floating point weights (usually has `.caffemodel` extension).
707+
* @param dst Path to destination model with updated weights.
708+
*
709+
* @note Shrinked model has no origin float32 weights so it can't be used
710+
* in origin Caffe framework anymore. However the structure of data
711+
* is taken from NVidia's Caffe fork: https://github.com/NVIDIA/caffe.
712+
* So the resulting model may be used there.
713+
*/
714+
CV_EXPORTS_W void shrinkCaffeModel(const String& src, const String& dst);
715+
716+
704717
//! @}
705718
CV__DNN_EXPERIMENTAL_NS_END
706719
}

modules/dnn/misc/caffe/caffe.pb.cc

Lines changed: 637 additions & 444 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

modules/dnn/misc/caffe/caffe.pb.h

Lines changed: 143 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

modules/dnn/src/caffe/caffe.proto

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,16 @@ syntax = "proto2";
5050

5151
package caffe;
5252

53+
// NVidia's Caffe feature is used to store fp16 weights, https://github.com/NVIDIA/caffe:
54+
// Math and storage types
55+
enum Type {
56+
DOUBLE = 0;
57+
FLOAT = 1;
58+
FLOAT16 = 2;
59+
INT = 3; // math not supported
60+
UINT = 4; // math not supported
61+
}
62+
5363
// Specifies the shape (dimensions) of a Blob.
5464
message BlobShape {
5565
repeated int64 dim = 1 [packed = true];
@@ -62,6 +72,11 @@ message BlobProto {
6272
repeated double double_data = 8 [packed = true];
6373
repeated double double_diff = 9 [packed = true];
6474

75+
// NVidia's Caffe fields begin.
76+
optional Type raw_data_type = 10;
77+
optional bytes raw_data = 12 [packed = false];
78+
// NVidia's Caffe fields end.
79+
6580
// 4D dimensions -- deprecated. Use "shape" instead.
6681
optional int32 num = 1 [default = 0];
6782
optional int32 channels = 2 [default = 0];

modules/dnn/src/caffe/caffe_importer.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -225,13 +225,28 @@ class CaffeImporter : public Importer
225225
blobShapeFromProto(pbBlob, shape);
226226

227227
dstBlob.create((int)shape.size(), &shape[0], CV_32F);
228-
CV_Assert(pbBlob.data_size() == (int)dstBlob.total());
229-
230-
CV_DbgAssert(pbBlob.GetDescriptor()->FindFieldByLowercaseName("data")->cpp_type() == FieldDescriptor::CPPTYPE_FLOAT);
231228
float *dstData = dstBlob.ptr<float>();
229+
if (pbBlob.data_size())
230+
{
231+
// Single precision floats.
232+
CV_Assert(pbBlob.data_size() == (int)dstBlob.total());
233+
234+
CV_DbgAssert(pbBlob.GetDescriptor()->FindFieldByLowercaseName("data")->cpp_type() == FieldDescriptor::CPPTYPE_FLOAT);
232235

233-
for (int i = 0; i < pbBlob.data_size(); i++)
234-
dstData[i] = pbBlob.data(i);
236+
for (int i = 0; i < pbBlob.data_size(); i++)
237+
dstData[i] = pbBlob.data(i);
238+
}
239+
else
240+
{
241+
// Half precision floats.
242+
CV_Assert(pbBlob.raw_data_type() == caffe::FLOAT16);
243+
std::string raw_data = pbBlob.raw_data();
244+
245+
CV_Assert(raw_data.size() / 2 == (int)dstBlob.total());
246+
247+
Mat halfs((int)shape.size(), &shape[0], CV_16SC1, (void*)raw_data.c_str());
248+
convertFp16(halfs, dstBlob);
249+
}
235250
}
236251

237252
void extractBinaryLayerParms(const caffe::LayerParameter& layer, LayerParams& layerParams)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// This file is part of OpenCV project.
2+
// It is subject to the license terms in the LICENSE file found in the top-level directory
3+
// of this distribution and at http://opencv.org/license.html.
4+
//
5+
// Copyright (C) 2017, Intel Corporation, all rights reserved.
6+
// Third party copyrights are property of their respective owners.
7+
8+
#include "../precomp.hpp"
9+
10+
#ifdef HAVE_PROTOBUF
11+
#include <fstream>
12+
#include "caffe_io.hpp"
13+
#endif
14+
15+
namespace cv { namespace dnn {
16+
CV__DNN_EXPERIMENTAL_NS_BEGIN
17+
18+
#ifdef HAVE_PROTOBUF
19+
20+
void shrinkCaffeModel(const String& src, const String& dst)
21+
{
22+
CV_TRACE_FUNCTION();
23+
24+
caffe::NetParameter net;
25+
ReadNetParamsFromBinaryFileOrDie(src.c_str(), &net);
26+
27+
for (int i = 0; i < net.layer_size(); ++i)
28+
{
29+
caffe::LayerParameter* lp = net.mutable_layer(i);
30+
for (int j = 0; j < lp->blobs_size(); ++j)
31+
{
32+
caffe::BlobProto* blob = lp->mutable_blobs(j);
33+
CV_Assert(blob->data_size() != 0); // float32 array.
34+
35+
Mat floats(1, blob->data_size(), CV_32FC1, (void*)blob->data().data());
36+
Mat halfs(1, blob->data_size(), CV_16SC1);
37+
convertFp16(floats, halfs); // Convert to float16.
38+
39+
blob->clear_data(); // Clear float32 data.
40+
41+
// Set float16 data.
42+
blob->set_raw_data(halfs.data, halfs.total() * halfs.elemSize());
43+
blob->set_raw_data_type(caffe::FLOAT16);
44+
}
45+
}
46+
size_t msgSize = net.ByteSizeLong();
47+
std::vector<uint8_t> output(msgSize);
48+
net.SerializeWithCachedSizesToArray(&output[0]);
49+
50+
std::ofstream ofs(dst.c_str(), std::ios::binary);
51+
ofs.write((const char*)&output[0], msgSize);
52+
ofs.close();
53+
}
54+
55+
#else
56+
57+
void shrinkCaffeModel(const String& src, const String& dst)
58+
{
59+
CV_Error(cv::Error::StsNotImplemented, "libprotobuf required to import data from Caffe models");
60+
}
61+
62+
#endif // HAVE_PROTOBUF
63+
64+
CV__DNN_EXPERIMENTAL_NS_END
65+
}} // namespace

0 commit comments

Comments
 (0)