Skip to content

Commit 39b5dcb

Browse files
Create load_model_tensorflow_cpp.md
1 parent 87bc175 commit 39b5dcb

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

load_model_tensorflow_cpp.md

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Load model with TensorFlow C++ API
2+
3+
## 1. Build neural network and save model
4+
5+
```
6+
from sklearn.datasets import make_classification
7+
from tensorflow.keras import Sequential
8+
from tensorflow.keras.layers import Dense
9+
from tensorflow.keras.optimizers import SGD
10+
# create the dataset
11+
X, y = make_classification(n_samples=1000, n_features=4, n_classes=2, random_state=1)
12+
# determine the number of input features
13+
n_features = X.shape[1]
14+
# define model
15+
model = Sequential()
16+
model.add(Dense(10, activation='relu', kernel_initializer='he_normal', input_shape=(n_features,)))
17+
model.add(Dense(1, activation='sigmoid'))
18+
# compile the model
19+
sgd = SGD(learning_rate=0.001, momentum=0.8)
20+
model.compile(optimizer=sgd, loss='binary_crossentropy')
21+
# fit the model
22+
model.fit(X, y, epochs=100, batch_size=32, verbose=1, validation_split=0.3)
23+
# save model to file
24+
model.save('model')
25+
```
26+
27+
This will create a folder `model` and save model as protobuf files:
28+
```
29+
ls model/
30+
assets keras_metadata.pb saved_model.pb variables
31+
```
32+
33+
## 2. Load model with C++ API
34+
35+
Create a new file, e.g., `load_model.cpp`
36+
```
37+
#include <tensorflow/cc/saved_model/loader.h>
38+
#include <tensorflow/cc/saved_model/tag_constants.h>
39+
40+
using namespace tensorflow;
41+
42+
int main() {
43+
44+
const std::string export_dir = "./model/";
45+
46+
// Load
47+
SavedModelBundle model_bundle;
48+
SessionOptions session_options = SessionOptions();
49+
RunOptions run_options = RunOptions();
50+
Status status = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &model_bundle);
51+
52+
if (!status.ok()) {
53+
std::cerr << "Failed: " << status;
54+
}
55+
return 0;
56+
}
57+
```
58+
59+
Compile source code
60+
```
61+
g++ -Wall -fPIC -D_GLIBCXX_USE_CXX11_ABI=0 \
62+
load_model.cpp -o load_model.o \
63+
-I/usr/local/tensorflow/include/ -L/usr/local/tensorflow/lib -ltensorflow_cc -ltensorflow_framework
64+
```
65+
66+
Add TensorFlow lib into lib env var
67+
```
68+
export LD_LIBRARY_PATH=/usr/local/tensorflow/lib/:$LD_LIBRARY_PATH
69+
```
70+
71+
Run the executable
72+
```
73+
./load_model.o
74+
75+
2021-12-30 22:14:10.621434: I tensorflow/cc/saved_model/reader.cc:38] Reading SavedModel from: ./model/
76+
2021-12-30 22:14:10.630099: I tensorflow/cc/saved_model/reader.cc:90] Reading meta graph with tags { serve }
77+
2021-12-30 22:14:10.630299: I tensorflow/cc/saved_model/reader.cc:132] Reading SavedModel debug info (if present) from: ./model/
78+
2021-12-30 22:14:10.714968: I tensorflow/cc/saved_model/loader.cc:211] Restoring SavedModel bundle.
79+
2021-12-30 22:14:10.760375: I tensorflow/cc/saved_model/loader.cc:195] Running initialization op on SavedModel bundle at path: ./model/
80+
2021-12-30 22:14:10.765069: I tensorflow/cc/saved_model/loader.cc:283] SavedModel load for tags { serve }; Status: success: OK. Took 143842 microseconds.
81+
```

0 commit comments

Comments
 (0)