Skip to content

Commit ce41a15

Browse files
committed
Import and convert FP16 weights from TensorFlow
1 parent 1caca21 commit ce41a15

File tree

2 files changed

+82
-11
lines changed

2 files changed

+82
-11
lines changed

modules/dnn/src/tensorflow/tf_importer.cpp

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,59 @@ void blobShapeFromTensor(const tensorflow::TensorProto &tensor, MatShape& shape)
6363
{
6464
const tensorflow::TensorShapeProto &_shape = tensor.tensor_shape();
6565
int i, n = _shape.dim_size();
66-
shape.resize(n);
66+
if (n)
67+
{
68+
shape.resize(n);
6769

68-
for (i = 0; i < n; i++)
69-
shape[i] = (int)_shape.dim(i).size();
70+
for (i = 0; i < n; i++)
71+
shape[i] = (int)_shape.dim(i).size();
72+
}
73+
else
74+
shape.resize(1, 1); // Scalar.
7075
}
7176
else
7277
{
7378
CV_Error(Error::StsError, "Unknown shape of input tensor");
7479
}
7580
}
7681

82+
// Decodes the payload of a TensorFlow tensor into a single-row Mat.
//
//   DT_FLOAT  -> CV_32FC1 (element type preserved)
//   DT_DOUBLE -> CV_64FC1 (element type preserved)
//   DT_HALF   -> CV_32FC1 (FP16 values are expanded to FP32)
//
// Values are taken from tensor_content() when it is non-empty; otherwise the
// typed repeated field (float_val / double_val / half_val) is used.
// The returned Mat owns its data and does not alias the protobuf message.
// Raises Error::StsError for any other data type.
static Mat getTensorContent(const tensorflow::TensorProto &tensor)
{
    // Bind by reference: the serialized buffer may hold all weights of a layer,
    // so copying it into a local string just to read it would be wasteful.
    const std::string& content = tensor.tensor_content();
    switch (tensor.dtype())
    {
        case tensorflow::DT_FLOAT:
        {
            if (content.empty() && !tensor.float_val().empty())
            {
                // Values stored in the typed field instead of the raw byte buffer.
                const RepeatedField<float>& field = tensor.float_val();
                return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone();
            }
            // clone() detaches the result from the message-owned buffer.
            return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone();
        }
        case tensorflow::DT_DOUBLE:
        {
            if (content.empty() && !tensor.double_val().empty())
            {
                // Values stored in the typed field instead of the raw byte buffer.
                const RepeatedField<double>& field = tensor.double_val();
                return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone();
            }
            return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone();
        }
        case tensorflow::DT_HALF:
        {
            Mat halfs;
            if (!content.empty())
            {
                static const int kHalfSize = 2;
                // Header over the message-owned bytes; only read before returning.
                halfs = Mat(1, content.size() / kHalfSize, CV_16UC1, (void*)content.c_str());
            }
            else
            {
                // half_val holds one 16-bit pattern per 32-bit integer element.
                const RepeatedField<int32_t>& field = tensor.half_val();
                CV_Assert(!field.empty());
                Mat ints(1, field.size(), CV_32SC1, (void*)field.data());
                ints.convertTo(halfs, CV_16UC1);
            }
            // Reinterpret as signed shorts just for the convertFp16 call.
            Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data);
            Mat floats(halfs.size(), CV_32FC1);
            convertFp16(halfsSigned, floats);
            return floats;
        }
        default:
            CV_Error(Error::StsError, "Tensor's data type is not supported");
            break;
    }
    return Mat();
}
118+
77119
template <typename T>
78120
void parseTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
79121
{
@@ -90,11 +132,12 @@ void parseTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
90132

91133
dstBlob.create(shape, CV_32F);
92134

93-
int size = tensor.tensor_content().size() / sizeof(T);
135+
Mat tensorContent = getTensorContent(tensor);
136+
int size = tensorContent.total();
94137
CV_Assert(size == (int)dstBlob.total());
95138

96139
float *dstData = dstBlob.ptr<float>();
97-
const T *data = reinterpret_cast<const T*>(tensor.tensor_content().c_str());
140+
const T *data = reinterpret_cast<const T*>(tensorContent.data);
98141

99142
if (dims == 4)
100143
{
@@ -125,6 +168,7 @@ void blobFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
125168
{
126169
switch (tensor.dtype()) {
127170
case tensorflow::DT_FLOAT:
171+
case tensorflow::DT_HALF:
128172
parseTensor<float>(tensor, dstBlob);
129173
break;
130174
case tensorflow::DT_DOUBLE:
@@ -406,7 +450,8 @@ void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &ds
406450
int dims = (int)shape.size();
407451

408452
// TODO: other blob types
409-
CV_Assert(tensor.dtype() == tensorflow::DT_FLOAT);
453+
CV_Assert(tensor.dtype() == tensorflow::DT_FLOAT ||
454+
tensor.dtype() == tensorflow::DT_HALF);
410455
CV_Assert(dims == 4);
411456

412457
// REORDER kernel HWIO to OIHW
@@ -416,11 +461,12 @@ void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &ds
416461

417462
dstBlob.create(shape, CV_32F);
418463

419-
int size = tensor.tensor_content().size() / sizeof(float);
464+
Mat tensorContent = getTensorContent(tensor);
465+
int size = tensorContent.total();
420466
CV_Assert(size == (int)dstBlob.total());
421467

422468
float *dstData = dstBlob.ptr<float>();
423-
const float *data = reinterpret_cast<const float*>(tensor.tensor_content().c_str());
469+
const float *data = reinterpret_cast<const float*>(tensorContent.data);
424470

425471
int out_c = shape[0], input_c = shape[1], height = shape[2], width = shape[3];
426472
int total = out_c*input_c*height*width;
@@ -753,7 +799,16 @@ void TFImporter::populateNet(Net dstNet)
753799
// Multiplication by constant.
754800
CV_Assert(layer.input_size() == 2);
755801

756-
float scale = getConstBlob(layer, value_id).float_val()[0];
802+
float scale;
803+
if (!getConstBlob(layer, value_id).float_val().empty())
804+
scale = getConstBlob(layer, value_id).float_val()[0];
805+
else
806+
{
807+
Mat scaleMat;
808+
blobFromTensor(getConstBlob(layer, value_id), scaleMat);
809+
CV_Assert(scaleMat.total() == 1 && scaleMat.type() == CV_32FC1);
810+
scale = scaleMat.at<float>(0, 0);
811+
}
757812
layerParams.set("scale", scale);
758813

759814
int id = dstNet.addLayer(name, "Power", layerParams);

modules/dnn/test/test_tf_importer.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ static std::string path(const std::string& file)
7676
return findDataFile("dnn/tensorflow/" + file, false);
7777
}
7878

79-
static void runTensorFlowNet(const std::string& prefix)
79+
static void runTensorFlowNet(const std::string& prefix,
80+
double l1 = 1e-5, double lInf = 1e-4)
8081
{
8182
std::string netPath = path(prefix + "_net.pb");
8283
std::string inpPath = path(prefix + "_in.npy");
@@ -89,7 +90,7 @@ static void runTensorFlowNet(const std::string& prefix)
8990

9091
net.setInput(input);
9192
cv::Mat output = net.forward();
92-
normAssert(target, output);
93+
normAssert(target, output, "", l1, lInf);
9394
}
9495

9596
TEST(Test_TensorFlow, single_conv)
@@ -130,4 +131,19 @@ TEST(Test_TensorFlow, deconvolution)
130131
runTensorFlowNet("deconvolution");
131132
}
132133

134+
// Runs each FP16 reference model through runTensorFlowNet, passing explicit
// l1 / lInf tolerances instead of relying on that helper's defaults.
TEST(Test_TensorFlow, fp16)
{
    const float l1 = 1e-3;
    const float lInf = 1e-2;

    // Model prefixes, listed in the order they are executed.
    static const char* const kModels[] = {
        "fp16_single_conv",
        "fp16_deconvolution",
        "fp16_max_pool_odd_same",
        "fp16_padding_valid",
        "fp16_eltwise_add_mul",
        "fp16_max_pool_odd_valid",
        "fp16_pad_and_concat",
        "fp16_max_pool_even",
        "fp16_padding_same"
    };
    const int numModels = (int)(sizeof(kModels) / sizeof(kModels[0]));

    for (int idx = 0; idx < numModels; ++idx)
        runTensorFlowNet(kModels[idx], l1, lInf);
}
148+
133149
}

0 commit comments

Comments
 (0)