@@ -63,17 +63,59 @@ void blobShapeFromTensor(const tensorflow::TensorProto &tensor, MatShape& shape)
63
63
{
64
64
const tensorflow::TensorShapeProto &_shape = tensor.tensor_shape ();
65
65
int i, n = _shape.dim_size ();
66
- shape.resize (n);
66
+ if (n)
67
+ {
68
+ shape.resize (n);
67
69
68
- for (i = 0 ; i < n; i++)
69
- shape[i] = (int )_shape.dim (i).size ();
70
+ for (i = 0 ; i < n; i++)
71
+ shape[i] = (int )_shape.dim (i).size ();
72
+ }
73
+ else
74
+ shape.resize (1 , 1 ); // Scalar.
70
75
}
71
76
else
72
77
{
73
78
CV_Error (Error::StsError, " Unknown shape of input tensor" );
74
79
}
75
80
}
76
81
82
+ static Mat getTensorContent (const tensorflow::TensorProto &tensor)
83
+ {
84
+ std::string content = tensor.tensor_content ();
85
+ switch (tensor.dtype ())
86
+ {
87
+ case tensorflow::DT_FLOAT:
88
+ return Mat (1 , content.size () / sizeof (float ), CV_32FC1, (void *)content.c_str ()).clone ();
89
+ case tensorflow::DT_DOUBLE:
90
+ return Mat (1 , content.size () / sizeof (double ), CV_64FC1, (void *)content.c_str ()).clone ();
91
+ case tensorflow::DT_HALF:
92
+ {
93
+ Mat halfs;
94
+ if (!content.empty ())
95
+ {
96
+ static const int kHalfSize = 2 ;
97
+ halfs = Mat (1 , content.size () / kHalfSize , CV_16UC1, (void *)content.c_str ());
98
+ }
99
+ else
100
+ {
101
+ const RepeatedField<int32_t >& field = tensor.half_val ();
102
+ CV_Assert (!field.empty ());
103
+ Mat ints (1 , field.size (), CV_32SC1, (void *)field.data ());
104
+ ints.convertTo (halfs, CV_16UC1);
105
+ }
106
+ // Reinterpret as a signed shorts just for a convertFp16 call.
107
+ Mat halfsSigned (halfs.size (), CV_16SC1, halfs.data );
108
+ Mat floats (halfs.size (), CV_32FC1);
109
+ convertFp16 (halfsSigned, floats);
110
+ return floats;
111
+ }
112
+ default :
113
+ CV_Error (Error::StsError, " Tensor's data type is not supported" );
114
+ break ;
115
+ }
116
+ return Mat ();
117
+ }
118
+
77
119
template <typename T>
78
120
void parseTensor (const tensorflow::TensorProto &tensor, Mat &dstBlob)
79
121
{
@@ -90,11 +132,12 @@ void parseTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
90
132
91
133
dstBlob.create (shape, CV_32F);
92
134
93
- int size = tensor.tensor_content ().size () / sizeof (T);
135
+ Mat tensorContent = getTensorContent (tensor);
136
+ int size = tensorContent.total ();
94
137
CV_Assert (size == (int )dstBlob.total ());
95
138
96
139
float *dstData = dstBlob.ptr <float >();
97
- const T *data = reinterpret_cast <const T*>(tensor. tensor_content (). c_str () );
140
+ const T *data = reinterpret_cast <const T*>(tensorContent. data );
98
141
99
142
if (dims == 4 )
100
143
{
@@ -125,6 +168,7 @@ void blobFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
125
168
{
126
169
switch (tensor.dtype ()) {
127
170
case tensorflow::DT_FLOAT:
171
+ case tensorflow::DT_HALF:
128
172
parseTensor<float >(tensor, dstBlob);
129
173
break ;
130
174
case tensorflow::DT_DOUBLE:
@@ -406,7 +450,8 @@ void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &ds
406
450
int dims = (int )shape.size ();
407
451
408
452
// TODO: other blob types
409
- CV_Assert (tensor.dtype () == tensorflow::DT_FLOAT);
453
+ CV_Assert (tensor.dtype () == tensorflow::DT_FLOAT ||
454
+ tensor.dtype () == tensorflow::DT_HALF);
410
455
CV_Assert (dims == 4 );
411
456
412
457
// REORDER kernel HWIO to OIHW
@@ -416,11 +461,12 @@ void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &ds
416
461
417
462
dstBlob.create (shape, CV_32F);
418
463
419
- int size = tensor.tensor_content ().size () / sizeof (float );
464
+ Mat tensorContent = getTensorContent (tensor);
465
+ int size = tensorContent.total ();
420
466
CV_Assert (size == (int )dstBlob.total ());
421
467
422
468
float *dstData = dstBlob.ptr <float >();
423
- const float *data = reinterpret_cast <const float *>(tensor. tensor_content (). c_str () );
469
+ const float *data = reinterpret_cast <const float *>(tensorContent. data );
424
470
425
471
int out_c = shape[0 ], input_c = shape[1 ], height = shape[2 ], width = shape[3 ];
426
472
int total = out_c*input_c*height*width;
@@ -753,7 +799,16 @@ void TFImporter::populateNet(Net dstNet)
753
799
// Multiplication by constant.
754
800
CV_Assert (layer.input_size () == 2 );
755
801
756
- float scale = getConstBlob (layer, value_id).float_val ()[0 ];
802
+ float scale;
803
+ if (!getConstBlob (layer, value_id).float_val ().empty ())
804
+ scale = getConstBlob (layer, value_id).float_val ()[0 ];
805
+ else
806
+ {
807
+ Mat scaleMat;
808
+ blobFromTensor (getConstBlob (layer, value_id), scaleMat);
809
+ CV_Assert (scaleMat.total () == 1 && scaleMat.type () == CV_32FC1);
810
+ scale = scaleMat.at <float >(0 , 0 );
811
+ }
757
812
layerParams.set (" scale" , scale);
758
813
759
814
int id = dstNet.addLayer (name, " Power" , layerParams);
0 commit comments