Skip to content

Commit 7996881

Browse files
authored
Merge pull request tensorflow#305 from kirilg/branch_146182478
Upstream internal changes
2 parents 18e278e + cad5138 commit 7996881

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1186
-210
lines changed

tensorflow

Submodule tensorflow updated 967 files

tensorflow_serving/apis/BUILD

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,28 @@ filegroup(
2323
load("//tensorflow_serving:serving.bzl", "serving_proto_library")
2424
load("//tensorflow_serving:serving.bzl", "serving_proto_library_py")
2525

26+
serving_proto_library(
27+
name = "get_model_metadata_proto",
28+
srcs = ["get_model_metadata.proto"],
29+
cc_api_version = 2,
30+
go_api_version = 2,
31+
java_api_version = 2,
32+
deps = [
33+
":model_proto",
34+
"@org_tensorflow//tensorflow/core:protos_all_cc",
35+
"@protobuf//:cc_wkt_protos",
36+
],
37+
)
38+
39+
serving_proto_library_py(
40+
name = "get_model_metadata_proto_py_pb2",
41+
srcs = ["get_model_metadata.proto"],
42+
proto_library = "get_model_metadata_proto",
43+
deps = [
44+
"@org_tensorflow//tensorflow/core:protos_all_py",
45+
],
46+
)
47+
2648
serving_proto_library(
2749
name = "model_proto",
2850
srcs = ["model.proto"],
@@ -72,12 +94,16 @@ serving_proto_library(
7294
go_api_version = 2,
7395
java_api_version = 2,
7496
deps = [
97+
":get_model_metadata_proto",
7598
":predict_proto",
7699
],
77100
)
78101

79102
py_library(
80103
name = "prediction_service_proto_py_pb2",
81104
srcs = ["prediction_service_pb2.py"],
82-
deps = [":predict_proto_py_pb2"],
105+
deps = [
106+
":get_model_metadata_proto_py_pb2",
107+
":predict_proto_py_pb2",
108+
],
83109
)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
syntax = "proto3";
2+
3+
package tensorflow.serving;
4+
option cc_enable_arenas = true;
5+
6+
import "google/protobuf/any.proto";
7+
import "tensorflow/core/protobuf/meta_graph.proto";
8+
import "tensorflow_serving/apis/model.proto";
9+
10+
// Message returned for "signature_def" field.
11+
message SignatureDefMap {
12+
map<string, SignatureDef> signature_def = 1;
13+
};
14+
15+
message GetModelMetadataRequest {
16+
// Model Specification indicating which model we are querying for metadata.
17+
ModelSpec model_spec = 1;
18+
// Metadata fields to get. Currently supported: "signature_def".
19+
repeated string metadata_field = 2;
20+
}
21+
22+
message GetModelMetadataResponse {
23+
// Model Specification indicating which model this metadata belongs to.
24+
ModelSpec model_spec = 1;
25+
// Map of metadata field name to metadata field. The options for metadata
26+
// field name are listed in GetModelMetadataRequest. Currently supported:
27+
// "signature_def".
28+
map<string, google.protobuf.Any> metadata = 2;
29+
}

tensorflow_serving/apis/prediction_service.proto

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ syntax = "proto3";
33
package tensorflow.serving;
44
option cc_enable_arenas = true;
55

6+
import "tensorflow_serving/apis/get_model_metadata.proto";
67
import "tensorflow_serving/apis/predict.proto";
78

89
// open source marker; do not remove
@@ -11,4 +12,8 @@ import "tensorflow_serving/apis/predict.proto";
1112
service PredictionService {
1213
// Predict -- provides access to loaded TensorFlow model.
1314
rpc Predict(PredictRequest) returns (PredictResponse);
15+
16+
// GetModelMetadata - provides access to metadata for loaded models.
17+
rpc GetModelMetadata(GetModelMetadataRequest)
18+
returns (GetModelMetadataResponse);
1419
}

tensorflow_serving/apis/prediction_service_pb2.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,17 @@
2727
_sym_db = _symbol_database.Default()
2828

2929

30+
from tensorflow_serving.apis import get_model_metadata_pb2 as tensorflow__serving_dot_apis_dot_get__model__metadata__pb2
3031
from tensorflow_serving.apis import predict_pb2 as tensorflow__serving_dot_apis_dot_predict__pb2
3132

3233

3334
DESCRIPTOR = _descriptor.FileDescriptor(
3435
name='tensorflow_serving/apis/prediction_service.proto',
3536
package='tensorflow.serving',
3637
syntax='proto3',
37-
serialized_pb=_b('\n0tensorflow_serving/apis/prediction_service.proto\x12\x12tensorflow.serving\x1a%tensorflow_serving/apis/predict.proto2g\n\x11PredictionService\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponseB\x03\xf8\x01\x01\x62\x06proto3')
38+
serialized_pb=_b('\n0tensorflow_serving/apis/prediction_service.proto\x12\x12tensorflow.serving\x1a\x30tensorflow_serving/apis/get_model_metadata.proto\x1a%tensorflow_serving/apis/predict.proto2\xd6\x01\n\x11PredictionService\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponse\x12m\n\x10GetModelMetadata\x12+.tensorflow.serving.GetModelMetadataRequest\x1a,.tensorflow.serving.GetModelMetadataResponseB\x03\xf8\x01\x01\x62\x06proto3')
3839
,
39-
dependencies=[tensorflow__serving_dot_apis_dot_predict__pb2.DESCRIPTOR,])
40+
dependencies=[tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_predict__pb2.DESCRIPTOR,])
4041
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
4142

4243

@@ -54,7 +55,8 @@
5455

5556

5657
class PredictionServiceStub(object):
57-
"""PredictionService provides access to machine-learned models loaded by
58+
"""open source marker; do not remove
59+
PredictionService provides access to machine-learned models loaded by
5860
model_servers.
5961
"""
6062

@@ -69,10 +71,16 @@ def __init__(self, channel):
6971
request_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString,
7072
response_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.FromString,
7173
)
74+
self.GetModelMetadata = channel.unary_unary(
75+
'/tensorflow.serving.PredictionService/GetModelMetadata',
76+
request_serializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString,
77+
response_deserializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.FromString,
78+
)
7279

7380

7481
class PredictionServiceServicer(object):
75-
"""PredictionService provides access to machine-learned models loaded by
82+
"""open source marker; do not remove
83+
PredictionService provides access to machine-learned models loaded by
7684
model_servers.
7785
"""
7886

@@ -83,6 +91,13 @@ def Predict(self, request, context):
8391
context.set_details('Method not implemented!')
8492
raise NotImplementedError('Method not implemented!')
8593

94+
def GetModelMetadata(self, request, context):
95+
"""GetModelMetadata - provides access to metadata for loaded models.
96+
"""
97+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
98+
context.set_details('Method not implemented!')
99+
raise NotImplementedError('Method not implemented!')
100+
86101

87102
def add_PredictionServiceServicer_to_server(servicer, server):
88103
rpc_method_handlers = {
@@ -91,6 +106,11 @@ def add_PredictionServiceServicer_to_server(servicer, server):
91106
request_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString,
92107
response_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.SerializeToString,
93108
),
109+
'GetModelMetadata': grpc.unary_unary_rpc_method_handler(
110+
servicer.GetModelMetadata,
111+
request_deserializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.FromString,
112+
response_serializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString,
113+
),
94114
}
95115
generic_handler = grpc.method_handlers_generic_handler(
96116
'tensorflow.serving.PredictionService', rpc_method_handlers)
@@ -103,13 +123,18 @@ class BetaPredictionServiceServicer(object):
103123
It is recommended to use the GA API (classes and functions in this
104124
file not marked beta) for all further purposes. This class was generated
105125
only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0."""
106-
"""PredictionService provides access to machine-learned models loaded by
126+
"""open source marker; do not remove
127+
PredictionService provides access to machine-learned models loaded by
107128
model_servers.
108129
"""
109130
def Predict(self, request, context):
110131
"""Predict -- provides access to loaded TensorFlow model.
111132
"""
112133
context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)
134+
def GetModelMetadata(self, request, context):
135+
"""GetModelMetadata - provides access to metadata for loaded models.
136+
"""
137+
context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)
113138

114139

115140
class BetaPredictionServiceStub(object):
@@ -118,14 +143,20 @@ class BetaPredictionServiceStub(object):
118143
It is recommended to use the GA API (classes and functions in this
119144
file not marked beta) for all further purposes. This class was generated
120145
only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0."""
121-
"""PredictionService provides access to machine-learned models loaded by
146+
"""open source marker; do not remove
147+
PredictionService provides access to machine-learned models loaded by
122148
model_servers.
123149
"""
124150
def Predict(self, request, timeout, metadata=None, with_call=False, protocol_options=None):
125151
"""Predict -- provides access to loaded TensorFlow model.
126152
"""
127153
raise NotImplementedError()
128154
Predict.future = None
155+
def GetModelMetadata(self, request, timeout, metadata=None, with_call=False, protocol_options=None):
156+
"""GetModelMetadata - provides access to metadata for loaded models.
157+
"""
158+
raise NotImplementedError()
159+
GetModelMetadata.future = None
129160

130161

131162
def beta_create_PredictionService_server(servicer, pool=None, pool_size=None, default_timeout=None, maximum_timeout=None):
@@ -135,12 +166,15 @@ def beta_create_PredictionService_server(servicer, pool=None, pool_size=None, de
135166
file not marked beta) for all further purposes. This function was
136167
generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"""
137168
request_deserializers = {
169+
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.FromString,
138170
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString,
139171
}
140172
response_serializers = {
173+
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString,
141174
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.SerializeToString,
142175
}
143176
method_implementations = {
177+
('tensorflow.serving.PredictionService', 'GetModelMetadata'): face_utilities.unary_unary_inline(servicer.GetModelMetadata),
144178
('tensorflow.serving.PredictionService', 'Predict'): face_utilities.unary_unary_inline(servicer.Predict),
145179
}
146180
server_options = beta_implementations.server_options(request_deserializers=request_deserializers, response_serializers=response_serializers, thread_pool=pool, thread_pool_size=pool_size, default_timeout=default_timeout, maximum_timeout=maximum_timeout)
@@ -154,13 +188,17 @@ def beta_create_PredictionService_stub(channel, host=None, metadata_transformer=
154188
file not marked beta) for all further purposes. This function was
155189
generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"""
156190
request_serializers = {
191+
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString,
157192
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString,
158193
}
159194
response_deserializers = {
195+
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.FromString,
160196
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.FromString,
161197
}
162198
cardinalities = {
199+
'GetModelMetadata': cardinality.Cardinality.UNARY_UNARY,
163200
'Predict': cardinality.Cardinality.UNARY_UNARY,
164201
}
165202
stub_options = beta_implementations.stub_options(host=host, metadata_transformer=metadata_transformer, request_serializers=request_serializers, response_deserializers=response_deserializers, thread_pool=pool, thread_pool_size=pool_size)
166203
return beta_implementations.dynamic_stub(channel, 'tensorflow.serving.PredictionService', cardinalities, options=stub_options)
204+
# @@protoc_insertion_point(module_scope)

tensorflow_serving/batching/batching_session_test.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,13 @@ TEST(BatchingSessionTest, TensorSignatureFromSignatureDefs) {
134134
const SignatureDef signature_def_0 =
135135
CreateSignatureDef({{"x0", "x1"}, {"y0", "y1"}});
136136
const SignatureDef signature_def_1 =
137-
CreateSignatureDef({{"x1", "x2"}, {"y1", "y2"}});
137+
CreateSignatureDef({{"x1", "x2"}, {"y1", "y3"}});
138138
const TensorSignature tensor_signature =
139139
TensorSignatureFromSignatureDefs({signature_def_0, signature_def_1});
140140
EXPECT_THAT(tensor_signature.input_tensors,
141141
UnorderedElementsAre("x0", "x1", "x2"));
142142
EXPECT_THAT(tensor_signature.output_tensors,
143-
UnorderedElementsAre("y0", "y1", "y2"));
143+
UnorderedElementsAre("y0", "y1", "y3"));
144144
}
145145

146146
TEST(BatchingSessionTest, Basic) {
@@ -188,9 +188,9 @@ TEST(BatchingSessionTest, RequestThatDoesntMatchSignatureGetsRunAnyway) {
188188
std::unique_ptr<Session> batching_session;
189189
BatchingSessionOptions batching_session_options;
190190
TF_ASSERT_OK(CreateBasicBatchingSession(
191-
schedule_options, batching_session_options, {{"x2"}, {"y2"}},
191+
schedule_options, batching_session_options, {{"x2"}, {"y3"}},
192192
CreateHalfPlusTwoSession(), &batching_session));
193-
// Issue a request using x/y, which doesn't match the x2/y2 signature.
193+
// Issue a request using x/y, which doesn't match the x2/y3 signature.
194194
TestSingleRequest(100.0f, 42.0f, batching_session.get());
195195
}
196196

@@ -288,7 +288,7 @@ TEST(BatchingSessionTest, DifferentOrderForInputAndOutputTensors) {
288288
BatchingSessionOptions batching_session_options;
289289
std::unique_ptr<Session> batching_session;
290290
TF_ASSERT_OK(CreateBasicBatchingSession(
291-
schedule_options, batching_session_options, {{"x", "x2"}, {"y", "y2"}},
291+
schedule_options, batching_session_options, {{"x", "x2"}, {"y", "y3"}},
292292
CreateHalfPlusTwoSession(), &batching_session));
293293

294294
const Tensor input0 = test::AsTensor<float>({8.0f, 6.0f}, {2});
@@ -300,7 +300,7 @@ TEST(BatchingSessionTest, DifferentOrderForInputAndOutputTensors) {
300300
Env::Default()->StartThread(ThreadOptions(), "first_request_thread", [&] {
301301
std::vector<Tensor> outputs;
302302
TF_ASSERT_OK(batching_session->Run({{"x", input0}, {"x2", input1}},
303-
{"y", "y2"} /* outputs */,
303+
{"y", "y3"} /* outputs */,
304304
{} /* target nodes */, &outputs));
305305
ASSERT_EQ(2, outputs.size());
306306
test::ExpectTensorEqual<float>(expected_output0, outputs[0]);
@@ -310,7 +310,7 @@ TEST(BatchingSessionTest, DifferentOrderForInputAndOutputTensors) {
310310
ThreadOptions(), "second_request_thread", [&] {
311311
std::vector<Tensor> outputs;
312312
TF_ASSERT_OK(batching_session->Run({{"x2", input1}, {"x", input0}},
313-
{"y2", "y"} /* outputs */,
313+
{"y3", "y"} /* outputs */,
314314
{} /* target nodes */, &outputs));
315315
ASSERT_EQ(2, outputs.size());
316316
test::ExpectTensorEqual<float>(expected_output1, outputs[0]);
@@ -320,7 +320,7 @@ TEST(BatchingSessionTest, DifferentOrderForInputAndOutputTensors) {
320320
Env::Default()->StartThread(ThreadOptions(), "third_request_thread", [&] {
321321
std::vector<Tensor> outputs;
322322
TF_ASSERT_OK(batching_session->Run({{"x2", input1}, {"x", input0}},
323-
{"y", "y2"} /* outputs */,
323+
{"y", "y3"} /* outputs */,
324324
{} /* target nodes */, &outputs));
325325
ASSERT_EQ(2, outputs.size());
326326
test::ExpectTensorEqual<float>(expected_output0, outputs[0]);
@@ -349,7 +349,7 @@ TEST(BatchingSessionTest, MultipleSignatures) {
349349
std::unique_ptr<Session> batching_session;
350350
TF_CHECK_OK(CreateBatchingSession(
351351
batching_session_options, {{{{"x"}, {"y"}}, create_scheduler},
352-
{{{"x2"}, {"y2"}}, create_scheduler}},
352+
{{{"x2"}, {"y3"}}, create_scheduler}},
353353
CreateHalfPlusTwoSession(), &batching_session));
354354
ASSERT_EQ(2, schedulers.size());
355355

@@ -367,7 +367,7 @@ TEST(BatchingSessionTest, MultipleSignatures) {
367367
Tensor input = test::AsTensor<float>({100.0f, 42.0f}, {2});
368368
Tensor expected_output = test::AsTensor<float>({53.0f, 24.0f}, {2});
369369
std::vector<Tensor> outputs;
370-
TF_ASSERT_OK(batching_session->Run({{"x2", input}}, {"y2"} /* outputs */,
370+
TF_ASSERT_OK(batching_session->Run({{"x2", input}}, {"y3"} /* outputs */,
371371
{} /* target nodes */, &outputs));
372372
ASSERT_EQ(1, outputs.size());
373373
test::ExpectTensorEqual<float>(expected_output, outputs[0]);

tensorflow_serving/core/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ cc_test(
365365
srcs = ["aspired_versions_manager_builder_test.cc"],
366366
deps = [
367367
":aspired_versions_manager_builder",
368-
":eager_load_policy",
368+
":availability_preserving_policy",
369369
":servable_data",
370370
":servable_handle",
371371
":servable_state_monitor",
@@ -534,7 +534,7 @@ cc_test(
534534
deps = [
535535
":aspired_version_policy",
536536
":aspired_versions_manager",
537-
":eager_load_policy",
537+
":availability_preserving_policy",
538538
":loader",
539539
":manager",
540540
":servable_data",

tensorflow_serving/core/aspired_versions_manager_benchmark.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ limitations under the License.
4141
#include "tensorflow/core/platform/types.h"
4242
#include "tensorflow_serving/core/aspired_version_policy.h"
4343
#include "tensorflow_serving/core/aspired_versions_manager.h"
44-
#include "tensorflow_serving/core/eager_load_policy.h"
44+
#include "tensorflow_serving/core/availability_preserving_policy.h"
4545
#include "tensorflow_serving/core/loader.h"
4646
#include "tensorflow_serving/core/manager.h"
4747
#include "tensorflow_serving/core/servable_data.h"
@@ -74,7 +74,7 @@ class BenchmarkState {
7474
AspiredVersionsManager::Options options;
7575
// Do policy thread won't be run automatically.
7676
options.manage_state_interval_micros = -1;
77-
options.aspired_version_policy.reset(new EagerLoadPolicy());
77+
options.aspired_version_policy.reset(new AvailabilityPreservingPolicy());
7878
TF_CHECK_OK(AspiredVersionsManager::Create(std::move(options), &manager_));
7979
}
8080

@@ -304,7 +304,7 @@ static void BM_GetServableHandle(const int iters) {
304304
AspiredVersionsManager::Options options;
305305
// Do policy thread won't be run automatically.
306306
options.manage_state_interval_micros = -1;
307-
options.aspired_version_policy.reset(new EagerLoadPolicy());
307+
options.aspired_version_policy.reset(new AvailabilityPreservingPolicy());
308308
std::unique_ptr<AspiredVersionsManager> manager;
309309
TF_CHECK_OK(AspiredVersionsManager::Create(std::move(options), &manager));
310310
auto aspired_versions_callback = manager->GetAspiredVersionsCallback();

tensorflow_serving/core/aspired_versions_manager_builder_test.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ limitations under the License.
1818
#include <gmock/gmock.h>
1919
#include <gtest/gtest.h>
2020
#include "tensorflow/core/lib/strings/strcat.h"
21-
#include "tensorflow_serving/core/eager_load_policy.h"
21+
#include "tensorflow_serving/core/availability_preserving_policy.h"
2222
#include "tensorflow_serving/core/servable_data.h"
2323
#include "tensorflow_serving/core/servable_handle.h"
2424
#include "tensorflow_serving/core/servable_state_monitor.h"
@@ -46,7 +46,8 @@ class AspiredVersionsManagerBuilderTest : public ::testing::Test {
4646
servable_state_monitor_(servable_event_bus_.get()) {
4747
AspiredVersionsManagerBuilder::Options manager_options;
4848
manager_options.servable_event_bus = servable_event_bus_.get();
49-
manager_options.aspired_version_policy.reset(new EagerLoadPolicy());
49+
manager_options.aspired_version_policy.reset(
50+
new AvailabilityPreservingPolicy());
5051
TF_CHECK_OK(AspiredVersionsManagerBuilder::Create(
5152
std::move(manager_options), &builder_));
5253
}

0 commit comments

Comments
 (0)