
Commit b69a12b

Merge commit for internal changes

2 parents: 6da3177 + 2261679

31 files changed: +364 −329 lines

tensorflow/compiler/jit/BUILD

Lines changed: 8 additions & 5 deletions
@@ -120,11 +120,13 @@ cc_library(
 )
 
 cc_library(
-    name = "xla_tensor_info",
-    srcs = ["xla_tensor_info.cc"],
-    hdrs = ["xla_tensor_info.h"],
+    name = "xla_tensor",
+    srcs = ["xla_tensor.cc"],
+    hdrs = ["xla_tensor.h"],
     deps = [
         ":common",
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_internal",
@@ -154,7 +156,7 @@ cc_library(
         ":common",
         ":jit_compilation_passes",
         ":xla_launch_util",
-        ":xla_tensor_info",
+        ":xla_tensor",
         "//tensorflow/compiler/jit/ops:xla_ops",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:dump_graph",
@@ -201,7 +203,7 @@ cc_library(
     deps = [
         ":common",
         ":xla_compilation_cache",
-        ":xla_tensor_info",
+        ":xla_tensor",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
@@ -359,6 +361,7 @@ tf_cc_test(
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",

tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc

Lines changed: 3 additions & 2 deletions
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/util/equal_graph_def.h"
 
@@ -771,7 +772,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
   Graph* graph = graph_ptr->get();
   for (const Node* n : graph->nodes()) {
     if (n->type_string() == "_Arg" &&
-        StringPiece(n->name()).starts_with("const")) {
+        str_util::StartsWith(n->name(), "const")) {
       ++guaranteed_consts;
       EXPECT_TRUE(HasGuaranteeConstAttr(*n));
     } else {
@@ -816,7 +817,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
   Graph* graph = graph_ptr->get();
   for (const Node* n : graph->nodes()) {
     if (n->type_string() == "_Arg" &&
-        StringPiece(n->name()).starts_with("const")) {
+        str_util::StartsWith(n->name(), "const")) {
      ++guaranteed_consts;
      EXPECT_TRUE(HasGuaranteeConstAttr(*n));
    } else {
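
Note: this commit migrates prefix checks from StringPiece::starts_with to the free function str_util::StartsWith (and, later in the commit, StringPiece::contains to str_util::StrContains). A minimal sketch of the new pattern, assuming only the TensorFlow headers already added in the diff above; the helper name IsGuaranteedConstArg is illustrative and not part of the commit:

#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Returns true if the node name carries the "const" prefix used by the tests.
bool IsGuaranteedConstArg(const string& node_name) {
  // Old form: StringPiece(node_name).starts_with("const")
  return str_util::StartsWith(node_name, "const");
}

}  // namespace tensorflow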

tensorflow/compiler/jit/graph_to_functiondef.cc

Lines changed: 2 additions & 1 deletion
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 
 namespace tensorflow {
@@ -229,7 +230,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
   for (int n_index = 0; n_index < fdef->node_def_size(); ++n_index) {
     NodeDef* node_def = fdef->mutable_node_def(n_index);
     for (int i = 0; i < node_def->input_size(); ++i) {
-      if (StringPiece(node_def->input(i)).starts_with("^")) {
+      if (str_util::StartsWith(node_def->input(i), "^")) {
         // Control input
         const string normalized =
             node_names.Renormalize(node_def->input(i).substr(1));

tensorflow/compiler/jit/kernels/xla_launch_op.cc

Lines changed: 21 additions & 10 deletions
@@ -116,11 +116,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
 
   const XlaDevice::Metadata* metadata;
   Status s = XlaDevice::GetMetadata(ctx, &metadata);
-
-  XlaTensorInfoManager* tensor_info_manager = nullptr;
-  if (s.ok()) {
-    tensor_info_manager = &metadata->tensor_info_manager();
-  }
+  bool allocate_xla_tensors = s.ok();
 
   // Get the platform_id_ for XLA_* devices.
   if (platform_id_ == nullptr) {
@@ -134,16 +130,31 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
 
   xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
 
-  // Builds an XLA allocator for the device.
-  XlaAllocator xla_allocator(client->platform(), ctx);
+  XlaAllocator local_xla_allocator(client->backend().platform(),
+                                   ctx->device()->GetAllocator({}));
+  xla::DeviceMemoryAllocator* xla_allocator;
+  // If we are on an XlaDevice, use the underlying XLA platform's allocator
+  // directly. We could use the StreamExecutor's allocator which may
+  // theoretically be more correct, but XLA returns a nice OOM message in a
+  // Status and StreamExecutor does not.
+  //
+  // Importantly we can't use ctx->device()->GetAllocator() as the allocator
+  // (which local_xla_allocator above uses) as on an XlaDevice, this is a
+  // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a
+  // real allocator to allocate real buffers.
+  if (allocate_xla_tensors) {
+    xla_allocator = client->backend().memory_allocator();
+  } else {
+    xla_allocator = &local_xla_allocator;
+  }
 
   XlaCompiler::Options options;
   options.client = client;
   options.device_type = &cache->device_type();
   options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
-  options.device_allocator = &xla_allocator;
+  options.device_allocator = xla_allocator;
 
   const XlaCompiler::CompilationResult* kernel;
   xla::LocalExecutable* executable;
@@ -159,14 +170,14 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   VLOG(1) << "Executing XLA Computation...";
 
   XlaComputationLaunchContext launch_context(
-      num_resource_args_, client, &xla_allocator, tensor_info_manager);
+      num_resource_args_, client, xla_allocator, allocate_xla_tensors);
   launch_context.PopulateInputs(ctx, kernel, variables);
 
   // Execute the computation.
   VLOG(2) << "Executing computation.";
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
-  run_options.set_allocator(&xla_allocator);
+  run_options.set_allocator(xla_allocator);
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
   Env* env = Env::Default();
   auto start_time = env->NowMicros();
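
Note: the comment block in the hunk above carries the key design choice: on an XlaDevice the op context's allocator only hands out XlaTensor wrappers, so real device buffers must come from the XLA backend's allocator instead. A hedged sketch of that selection, using only types and calls that appear in this diff; the header paths and the helper name PickDeviceMemoryAllocator are assumptions for illustration, not part of the commit:

#include "tensorflow/compiler/jit/xla_launch_util.h"       // XlaAllocator (assumed location)
#include "tensorflow/compiler/xla/client/local_client.h"   // xla::LocalClient

namespace tensorflow {

// Illustrative only. XlaAllocator derives from xla::DeviceMemoryAllocator,
// which is what makes the fallback branch below legal.
xla::DeviceMemoryAllocator* PickDeviceMemoryAllocator(
    bool allocate_xla_tensors, xla::LocalClient* client,
    XlaAllocator* local_xla_allocator) {
  if (allocate_xla_tensors) {
    // On an XlaDevice: take real buffers from the backend's allocator, which
    // also reports OOM through a Status instead of failing in StreamExecutor.
    return client->backend().memory_allocator();
  }
  // Off an XlaDevice: wrap the op's device allocator, as before this change.
  return local_xla_allocator;
}

}  // namespace tensorflow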

tensorflow/compiler/jit/mark_for_compilation_pass_test.cc

Lines changed: 6 additions & 5 deletions
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/graph/graph_def_builder_util.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -519,11 +520,11 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
 
   Status status = MarkForCompilation(&graph);
   EXPECT_FALSE(status.ok());
-  EXPECT_TRUE(StringPiece(status.ToString())
-                  .contains("Edge from c to a would create a cycle.\n"
-                            "+-> a\n"
-                            "| b\n"
-                            "+-- c\n"));
+  EXPECT_TRUE(str_util::StrContains(status.ToString(),
+                                    "Edge from c to a would create a cycle.\n"
+                                    "+-> a\n"
+                                    "| b\n"
+                                    "+-- c\n"));
 }
 
 TEST(XlaCompilationTest, Retval) {

tensorflow/compiler/jit/xla_compile_on_demand_op.cc

Lines changed: 8 additions & 11 deletions
@@ -51,12 +51,10 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
   int64 num_resource_args = variables.size();
 
   xla::LocalClient* client = metadata.client();
-  XlaTensorInfoManager* tensor_info_manager = &metadata.tensor_info_manager();
 
   // Builds an XLA allocator for the device.
-  XlaAllocator xla_allocator(client->platform(), ctx);
   XlaComputationLaunchContext launch_context(
-      num_resource_args, client, &xla_allocator, tensor_info_manager);
+      num_resource_args, client, client->backend().memory_allocator(), true);
 
   launch_context.PopulateInputs(ctx, result, variables);
 
@@ -67,7 +65,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
   VLOG(2) << "Executing computation.";
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
-  run_options.set_allocator(&xla_allocator);
+  run_options.set_allocator(client->backend().memory_allocator());
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
 
   auto run_result = executable->Run(launch_context.arguments(), run_options);
@@ -106,25 +104,24 @@ Status XlaCompileOnDemandOp::Compile(
     OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
     const XlaCompiler::CompilationResult** result,
     xla::LocalExecutable** executable) {
-  XlaTensorInfoManager* tensor_info_manager = &metadata.tensor_info_manager();
-
   std::map<int, Tensor> constant_arguments;
   for (int64 i = 0; i < ctx->num_inputs(); ++i) {
     const Tensor& device_tensor = ctx->input(i);
-    if (const XlaTensorInfo* tensor_info =
-            tensor_info_manager->GetTensorInfo(device_tensor)) {
-      if (tensor_info->has_host_tensor() &&
+    if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) {
+      if (xla_tensor->has_host_tensor() &&
           ShouldArgumentBeConstant(&ctx->op_kernel(), i)) {
-        constant_arguments[i] = tensor_info->host_tensor();
+        constant_arguments[i] = xla_tensor->host_tensor();
       }
     }
     if (constant_arguments.count(i) == 0 &&
         MustArgumentBeConstant(&ctx->op_kernel(), i)) {
       // Slow path; the argument is not available as a host constant so we must
       // fetch it synchronously.
       Tensor host_tensor;
+      AllocatorAttributes attrs;
+      attrs.set_on_host(true);
       TF_RETURN_IF_ERROR(ctx->allocate_temp(
-          device_tensor.dtype(), device_tensor.shape(), &host_tensor));
+          device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
       Notification n;
       ctx->op_device_context()->CopyDeviceTensorToCPU(
           &device_tensor, "ConstantArgument",
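
Note: the replacement of XlaTensorInfoManager lookups with XlaTensor::FromTensor is the heart of this commit: sideband data (device shape tree, cached host copy) now travels with the tensor's own buffer instead of living in a per-device side table. A minimal sketch of the new lookup path, using only the calls visible in the diff above; the helper name TryGetCachedHostCopy is illustrative and not part of the commit:

#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

// Returns true and fills *host_copy if the device tensor is backed by an
// XlaTensor that already carries a host-side copy of its value.
bool TryGetCachedHostCopy(const Tensor& device_tensor, Tensor* host_copy) {
  if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) {
    if (xla_tensor->has_host_tensor()) {
      *host_copy = xla_tensor->host_tensor();
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow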

tensorflow/compiler/jit/xla_device.cc

Lines changed: 8 additions & 20 deletions
@@ -100,7 +100,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
   }
 
   std::unique_ptr<XlaDeviceAllocator> alloc =
-      xla::MakeUnique<XlaDeviceAllocator>(backend, device_ordinal);
+      xla::MakeUnique<XlaDeviceAllocator>();
   XlaDeviceAllocator* alloc_ptr = alloc.get();
   state.allocators_[{backend, device_ordinal}] = std::move(alloc);
   return alloc_ptr;
@@ -136,13 +136,11 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
   return Status::OK();
 }
 
-XlaDevice::Metadata::Metadata(
-    int device_ordinal, se::Platform* platform, const DeviceType& device_type,
-    std::unique_ptr<XlaTensorInfoManager>* tensor_info_manager)
+XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform,
+                              const DeviceType& device_type)
     : device_ordinal_(device_ordinal),
       device_type_(device_type),
-      platform_(platform),
-      tensor_info_manager_(*tensor_info_manager) {}
+      platform_(platform) {}
 
 int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
 
@@ -157,12 +155,9 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
   return device_type_;
 }
 
-XlaTensorInfoManager& XlaDevice::Metadata::tensor_info_manager() const {
-  return *tensor_info_manager_;
-}
-
 /* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
                                            const Metadata** metadata) {
+  *metadata = nullptr;
   XlaDevice* xla_device =
       dynamic_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
   if (xla_device == nullptr) {
@@ -181,15 +176,11 @@ XlaDevice::XlaDevice(const SessionOptions& options,
                      const DeviceType& jit_device_name, se::Platform* platform,
                      bool transfer_as_literal)
     : LocalDevice(options, attrs),
-      xla_metadata_(
-          device_ordinal, platform, jit_device_name,
-          // Pass tensor_info_manager_ by reference as it is initialized lazily.
-          &tensor_info_manager_),
+      xla_metadata_(device_ordinal, platform, jit_device_name),
       device_ordinal_(device_ordinal),
       jit_device_name_(jit_device_name),
       xla_allocator_(nullptr),
       platform_(platform),
-      tensor_info_manager_(nullptr),
       transfer_as_literal_(transfer_as_literal) {}
 
 XlaDevice::~XlaDevice() {}
@@ -215,7 +206,6 @@ Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
     xla::Backend* backend = client()->mutable_backend();
     xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
         backend, device_ordinal_);
-    tensor_info_manager_.reset(new XlaTensorInfoManager(xla_allocator_));
   }
   return xla_allocator_;
 }
@@ -236,8 +226,7 @@ Status XlaDevice::FillContextMap(const Graph* graph,
   // Call GetAllocator for the side-effect of ensuring the allocator and
   // XlaTensorInfoManager is created.
   (void)GetAllocator({});
-  auto ctx = new XlaDeviceContext(stream, tensor_info_manager_.get(),
-                                  transfer_as_literal_);
+  auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_);
   for (Node* n : graph->nodes()) {
     VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
     ctx->Ref();
@@ -285,8 +274,7 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
   Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
   Notification n;
   TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
-  XlaTransferManager manager(stream, tensor_info_manager_.get(),
-                             transfer_as_literal_);
+  XlaTransferManager manager(stream, client(), transfer_as_literal_);
   manager.CopyCPUTensorToDevice(&parsed, this, &copy,
                                 [&n, &status](const Status& s) {
                                   status = s;
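
Note: one behavioural detail worth calling out is that GetMetadata now clears *metadata before attempting the dynamic_cast, so callers can rely on the status alone to tell whether they are on an XlaDevice; this is the pattern the updated XlaLocalLaunchOp::Compute uses earlier in this commit. A hedged sketch; the helper name OnXlaDevice is illustrative only:

#include "tensorflow/compiler/jit/xla_device.h"

namespace tensorflow {

// Returns true iff the op is running on an XlaDevice. On failure, metadata is
// left null by GetMetadata, so no extra null check is needed by the caller.
bool OnXlaDevice(OpKernelContext* ctx) {
  const XlaDevice::Metadata* metadata = nullptr;
  Status s = XlaDevice::GetMetadata(ctx, &metadata);
  return s.ok();
}

}  // namespace tensorflow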

tensorflow/compiler/jit/xla_device.h

Lines changed: 2 additions & 12 deletions
@@ -26,7 +26,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
 #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
 
-#include "tensorflow/compiler/jit/xla_tensor_info.h"
+#include "tensorflow/compiler/jit/xla_tensor.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
@@ -50,22 +50,19 @@ class XlaDevice : public LocalDevice {
   class Metadata {
    public:
     Metadata(int device_ordinal, perftools::gputools::Platform* platform,
-             const DeviceType& device_type,
-             std::unique_ptr<XlaTensorInfoManager>* tensor_info_manager);
+             const DeviceType& device_type);
 
     // The index of the device on this host.
     int device_ordinal() const;
 
     perftools::gputools::Platform* platform() const;
     xla::LocalClient* client() const;
     const DeviceType& jit_device_type() const;
-    XlaTensorInfoManager& tensor_info_manager() const;
 
    private:
     const int device_ordinal_;
     const DeviceType device_type_;
     perftools::gputools::Platform* platform_;  // Not owned.
-    std::unique_ptr<XlaTensorInfoManager>& tensor_info_manager_;
 
     TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
   };
@@ -123,13 +120,6 @@ class XlaDevice : public LocalDevice {
   // copying back and forth between CPU and the device, and
   // computations enqueued by XLA.
   xla::Backend::StreamPtr stream_;
-  // Manages sideband data about tensors, in particular the on-device shape tree
-  // if the tensor requires multiple device buffers to represent (for example,
-  // tuple shapes).
-  // This is a unique_ptr because XlaTensorInfoManager is non-copy-constructible
-  // and we need to initialize this lazily (as we also lazily initialize the
-  // underlying allocator).
-  std::unique_ptr<XlaTensorInfoManager> tensor_info_manager_;
   // Must we use XLA's transfer manager for correct host<->device transfers? if
   // false, we can use ThenMemcpy() instead.
   bool transfer_as_literal_;
