@@ -27,33 +27,32 @@ namespace {
27
27
28
28
using ::testing::ElementsAre;
29
29
30
+ using flatbuffers::Offset;
31
+ using flatbuffers::Vector;
30
32
class ImportTest : public ::testing::Test {
31
33
protected:
32
34
template <typename T>
33
- flatbuffers::Offset<flatbuffers::Vector<unsigned char >> CreateDataVector (
34
- const std::vector<T>& data) {
35
+ Offset<Vector<unsigned char >> CreateDataVector (const std::vector<T>& data) {
35
36
return builder_.CreateVector (reinterpret_cast <const uint8_t *>(data.data ()),
36
37
sizeof (T) * data.size ());
37
38
}
38
- // This is a very simplistic model. We are not interested in testing all the
39
- // details here, since tf.mini's testing framework will be exercising all the
40
- // conversions multiple times, and the conversion of operators is tested by
41
- // separate unittests.
42
- void BuildTestModel () {
43
- // The tensors
39
+ Offset<Vector<Offset<::tflite::Buffer>>> BuildBuffers () {
40
+ auto buf0 = ::tflite::CreateBuffer (builder_, CreateDataVector<float >({}));
41
+ auto buf1 =
42
+ ::tflite::CreateBuffer (builder_, CreateDataVector<float >({1 .0f , 2 .0f }));
43
+ auto buf2 =
44
+ ::tflite::CreateBuffer (builder_, CreateDataVector<float >({3 .0f }));
45
+ return builder_.CreateVector (
46
+ std::vector<Offset<::tflite::Buffer>>({buf0, buf1, buf2}));
47
+ }
48
+
49
+ Offset<Vector<Offset<::tflite::Tensor>>> BuildTensors () {
44
50
auto q = ::tflite::CreateQuantizationParameters (
45
51
builder_,
46
52
/* min=*/ builder_.CreateVector <float >({0 .1f }),
47
53
/* max=*/ builder_.CreateVector <float >({0 .2f }),
48
54
/* scale=*/ builder_.CreateVector <float >({0 .3f }),
49
55
/* zero_point=*/ builder_.CreateVector <int64_t >({100ll }));
50
- auto buf0 = ::tflite::CreateBuffer (builder_, CreateDataVector<float >({}));
51
- auto buf1 =
52
- ::tflite::CreateBuffer (builder_, CreateDataVector<float >({1 .0f , 2 .0f }));
53
- auto buf2 =
54
- ::tflite::CreateBuffer (builder_, CreateDataVector<float >({3 .0f }));
55
- auto buffers = builder_.CreateVector (
56
- std::vector<flatbuffers::Offset<::tflite::Buffer>>({buf0, buf1, buf2}));
57
56
auto t1 = ::tflite::CreateTensor (builder_,
58
57
builder_.CreateVector <int >({1 , 2 , 3 , 4 }),
59
58
::tflite::TensorType_FLOAT32, 1 ,
@@ -62,17 +61,28 @@ class ImportTest : public ::testing::Test {
62
61
::tflite::CreateTensor (builder_, builder_.CreateVector<int >({2 , 1 }),
63
62
::tflite::TensorType_FLOAT32, 2,
64
63
builder_.CreateString(" tensor_two" ), q);
65
- auto tensors = builder_.CreateVector (
66
- std::vector<flatbuffers::Offset<::tflite::Tensor>>({t1, t2}));
64
+ return builder_.CreateVector (
65
+ std::vector<Offset<::tflite::Tensor>>({t1, t2}));
66
+ }
67
67
68
- // The operator codes.
68
+ Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes () {
69
69
auto c1 =
70
70
::tflite::CreateOperatorCode (builder_, ::tflite::BuiltinOperator_CUSTOM,
71
71
builder_.CreateString(" custom_op_one" ));
72
72
auto c2 = ::tflite::CreateOperatorCode (
73
73
builder_, ::tflite::BuiltinOperator_CONV_2D, 0 );
74
- auto opcodes = builder_.CreateVector (
75
- std::vector<flatbuffers::Offset<::tflite::OperatorCode>>({c1, c2}));
74
+ return builder_.CreateVector (
75
+ std::vector<Offset<::tflite::OperatorCode>>({c1, c2}));
76
+ }
77
+
78
+ // This is a very simplistic model. We are not interested in testing all the
79
+ // details here, since tf.mini's testing framework will be exercising all the
80
+ // conversions multiple times, and the conversion of operators is tested by
81
+ // separate unittests.
82
+ void BuildTestModel () {
83
+ auto buffers = BuildBuffers ();
84
+ auto tensors = BuildTensors ();
85
+ auto opcodes = BuildOpCodes ();
76
86
77
87
auto subgraph = ::tflite::CreateSubGraph (builder_, tensors, 0 , 0 , 0 );
78
88
std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vector (
@@ -133,6 +143,19 @@ TEST_F(ImportTest, Tensors) {
133
143
EXPECT_EQ (100 , q->zero_point );
134
144
}
135
145
146
+ TEST_F (ImportTest, NoSubGraphs) {
147
+ auto buffers = BuildBuffers ();
148
+ auto opcodes = BuildOpCodes ();
149
+ auto subgraphs = 0 ; // no subgraphs in this model
150
+ auto comment = builder_.CreateString (" " );
151
+ builder_.Finish (::tflite::CreateModel (builder_, TFLITE_SCHEMA_VERSION,
152
+ opcodes, subgraphs, comment, buffers));
153
+ input_model_ = ::tflite::GetModel (builder_.GetBufferPointer ());
154
+
155
+ EXPECT_DEATH (Import (ModelFlags (), InputModelAsString ()),
156
+ " Number of subgraphs in tflite should be exactly 1." );
157
+ }
158
+
136
159
// TODO(ahentz): still need tests for Operators and IOTensors.
137
160
138
161
} // namespace
0 commit comments