
Commit 36b4744

Before this change, a Tensor contained a device pointer, and a TensorInfoManager data structure held a mapping from device pointer to XlaTensorInfo object. The TensorInfoManager also had to be an Allocator so that it could be informed when a Tensor was released.

After this change, a Tensor on an XlaDevice contains an XlaTensor object, the equivalent of the old XlaTensorInfo object. This has advantages and drawbacks:

+ We don't need yet another allocator wrapper, as there is no side-band data to manage.
+ No hashtable lookups are required.
- Because XlaLocalLaunchOp can run either on an XlaDevice or on a TF-classic device, we need some way to distinguish whether a Tensor is a TF-classic tensor (holding a device pointer) or an XlaTensor; we use a tagged pointer for this.

As part of this change, ShapedBuffers are allocated using the XLA backend's allocator directly, instead of the roundabout route where we:
- wrapped the XLA allocator in an XlaDeviceAllocator, and then
- wrapped the XlaDeviceAllocator in an XlaAllocator.

This leaves less to go wrong. Ideally we would use StreamExecutor's allocator here, but it is less useful than XLA's because it doesn't provide helpful OOM messages (it just returns nullptr).

PiperOrigin-RevId: 191048184
Parent: 17d05bb
12 files changed: +310 / -290 lines
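
Note: the new xla_tensor.h / xla_tensor.cc files are among the 12 changed files but are not reproduced in the excerpt below. The following is a rough, non-authoritative sketch of the XlaTensor interface as implied by its call sites in the diffs that follow; the member types, exact signatures, and includes are assumptions, not the committed header.

// Sketch only -- reconstructed from the call sites below (FromTensor,
// ToOpaquePointer, AllocateShapedBuffer, set_host_tensor, ...); not the
// committed xla_tensor.h. Includes and signatures are approximate.
#include <memory>

#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace tensorflow {

class XlaTensor {
 public:
  // Returns the XlaTensor backing `tensor`, or nullptr if `tensor` holds a
  // TF-classic device pointer instead (distinguished via a tagged pointer).
  static XlaTensor* FromTensor(const Tensor* tensor);

  // Encode/decode an XlaTensor as the opaque "buffer pointer" stored inside a
  // Tensor allocated by XlaDeviceAllocator.
  static void* ToOpaquePointer(XlaTensor* tensor);
  static XlaTensor* FromOpaquePointer(void* ptr);

  // Returns the device memory backing `tensor`; requires that the wrapped
  // XlaTensor already has a ShapedBuffer.
  static perftools::gputools::DeviceMemoryBase DeviceMemoryFromTensor(
      const Tensor& tensor);

  // The device buffer is allocated lazily, on first use.
  bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
  Status AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
                              xla::LocalClient* client, int device_ordinal);

  // Optional host-side copy of the value, used so that arguments can be
  // treated as compile-time constants without a device-to-host copy.
  bool has_host_tensor() const { return host_tensor_ != nullptr; }
  const Tensor& host_tensor() const { return *host_tensor_; }
  void set_host_tensor(const Tensor& tensor) {
    host_tensor_.reset(new Tensor(tensor));
  }

 private:
  std::unique_ptr<xla::ShapedBuffer> shaped_buffer_;
  std::unique_ptr<Tensor> host_tensor_;
};

}  // namespace tensorflow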

tensorflow/compiler/jit/BUILD

Lines changed: 7 additions & 5 deletions
@@ -120,11 +120,13 @@ cc_library(
 )
 
 cc_library(
-    name = "xla_tensor_info",
-    srcs = ["xla_tensor_info.cc"],
-    hdrs = ["xla_tensor_info.h"],
+    name = "xla_tensor",
+    srcs = ["xla_tensor.cc"],
+    hdrs = ["xla_tensor.h"],
     deps = [
         ":common",
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_internal",
@@ -154,7 +156,7 @@ cc_library(
         ":common",
         ":jit_compilation_passes",
         ":xla_launch_util",
-        ":xla_tensor_info",
+        ":xla_tensor",
         "//tensorflow/compiler/jit/ops:xla_ops",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:dump_graph",
@@ -201,7 +203,7 @@ cc_library(
     deps = [
         ":common",
         ":xla_compilation_cache",
-        ":xla_tensor_info",
+        ":xla_tensor",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",

tensorflow/compiler/jit/kernels/xla_launch_op.cc

Lines changed: 21 additions & 10 deletions
@@ -116,11 +116,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
 
   const XlaDevice::Metadata* metadata;
   Status s = XlaDevice::GetMetadata(ctx, &metadata);
-
-  XlaTensorInfoManager* tensor_info_manager = nullptr;
-  if (s.ok()) {
-    tensor_info_manager = &metadata->tensor_info_manager();
-  }
+  bool allocate_xla_tensors = s.ok();
 
   // Get the platform_id_ for XLA_* devices.
   if (platform_id_ == nullptr) {
@@ -134,16 +130,31 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
 
   xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
 
-  // Builds an XLA allocator for the device.
-  XlaAllocator xla_allocator(client->platform(), ctx);
+  XlaAllocator local_xla_allocator(client->backend().platform(),
+                                   ctx->device()->GetAllocator({}));
+  xla::DeviceMemoryAllocator* xla_allocator;
+  // If we are on an XlaDevice, use the underlying XLA platform's allocator
+  // directly. We could use the StreamExecutor's allocator which may
+  // theoretically be more correct, but XLA returns a nice OOM message in a
+  // Status and StreamExecutor does not.
+  //
+  // Importantly we can't use ctx->device()->GetAllocator() as the allocator
+  // (which local_xla_allocator above uses) as on an XlaDevice, this is a
+  // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a
+  // real allocator to allocate real buffers.
+  if (allocate_xla_tensors) {
+    xla_allocator = client->backend().memory_allocator();
+  } else {
+    xla_allocator = &local_xla_allocator;
+  }
 
   XlaCompiler::Options options;
   options.client = client;
   options.device_type = &cache->device_type();
   options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
-  options.device_allocator = &xla_allocator;
+  options.device_allocator = xla_allocator;
 
   const XlaCompiler::CompilationResult* kernel;
   xla::LocalExecutable* executable;
@@ -159,14 +170,14 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   VLOG(1) << "Executing XLA Computation...";
 
   XlaComputationLaunchContext launch_context(
-      num_resource_args_, client, &xla_allocator, tensor_info_manager);
+      num_resource_args_, client, xla_allocator, allocate_xla_tensors);
   launch_context.PopulateInputs(ctx, kernel, variables);
 
   // Execute the computation.
   VLOG(2) << "Executing computation.";
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
-  run_options.set_allocator(&xla_allocator);
+  run_options.set_allocator(xla_allocator);
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
   Env* env = Env::Default();
   auto start_time = env->NowMicros();

tensorflow/compiler/jit/xla_compile_on_demand_op.cc

Lines changed: 8 additions & 11 deletions
@@ -51,12 +51,10 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
   int64 num_resource_args = variables.size();
 
   xla::LocalClient* client = metadata.client();
-  XlaTensorInfoManager* tensor_info_manager = &metadata.tensor_info_manager();
 
   // Builds an XLA allocator for the device.
-  XlaAllocator xla_allocator(client->platform(), ctx);
   XlaComputationLaunchContext launch_context(
-      num_resource_args, client, &xla_allocator, tensor_info_manager);
+      num_resource_args, client, client->backend().memory_allocator(), true);
 
   launch_context.PopulateInputs(ctx, result, variables);
 
@@ -67,7 +65,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
   VLOG(2) << "Executing computation.";
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
-  run_options.set_allocator(&xla_allocator);
+  run_options.set_allocator(client->backend().memory_allocator());
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
 
   auto run_result = executable->Run(launch_context.arguments(), run_options);
@@ -106,25 +104,24 @@ Status XlaCompileOnDemandOp::Compile(
     OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
     const XlaCompiler::CompilationResult** result,
     xla::LocalExecutable** executable) {
-  XlaTensorInfoManager* tensor_info_manager = &metadata.tensor_info_manager();
-
   std::map<int, Tensor> constant_arguments;
   for (int64 i = 0; i < ctx->num_inputs(); ++i) {
     const Tensor& device_tensor = ctx->input(i);
-    if (const XlaTensorInfo* tensor_info =
-            tensor_info_manager->GetTensorInfo(device_tensor)) {
-      if (tensor_info->has_host_tensor() &&
+    if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) {
+      if (xla_tensor->has_host_tensor() &&
           ShouldArgumentBeConstant(&ctx->op_kernel(), i)) {
-        constant_arguments[i] = tensor_info->host_tensor();
+        constant_arguments[i] = xla_tensor->host_tensor();
       }
     }
     if (constant_arguments.count(i) == 0 &&
        MustArgumentBeConstant(&ctx->op_kernel(), i)) {
      // Slow path; the argument is not available as a host constant so we must
      // fetch it synchronously.
      Tensor host_tensor;
+      AllocatorAttributes attrs;
+      attrs.set_on_host(true);
      TF_RETURN_IF_ERROR(ctx->allocate_temp(
-          device_tensor.dtype(), device_tensor.shape(), &host_tensor));
+          device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
      Notification n;
      ctx->op_device_context()->CopyDeviceTensorToCPU(
          &device_tensor, "ConstantArgument",

tensorflow/compiler/jit/xla_device.cc

Lines changed: 8 additions & 20 deletions
@@ -100,7 +100,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
   }
 
   std::unique_ptr<XlaDeviceAllocator> alloc =
-      xla::MakeUnique<XlaDeviceAllocator>(backend, device_ordinal);
+      xla::MakeUnique<XlaDeviceAllocator>();
   XlaDeviceAllocator* alloc_ptr = alloc.get();
   state.allocators_[{backend, device_ordinal}] = std::move(alloc);
   return alloc_ptr;
@@ -136,13 +136,11 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
   return Status::OK();
 }
 
-XlaDevice::Metadata::Metadata(
-    int device_ordinal, se::Platform* platform, const DeviceType& device_type,
-    std::unique_ptr<XlaTensorInfoManager>* tensor_info_manager)
+XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform,
+                              const DeviceType& device_type)
     : device_ordinal_(device_ordinal),
       device_type_(device_type),
-      platform_(platform),
-      tensor_info_manager_(*tensor_info_manager) {}
+      platform_(platform) {}
 
 int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
 
@@ -157,12 +155,9 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
   return device_type_;
 }
 
-XlaTensorInfoManager& XlaDevice::Metadata::tensor_info_manager() const {
-  return *tensor_info_manager_;
-}
-
 /* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
                                            const Metadata** metadata) {
+  *metadata = nullptr;
   XlaDevice* xla_device =
       dynamic_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
   if (xla_device == nullptr) {
@@ -181,15 +176,11 @@ XlaDevice::XlaDevice(const SessionOptions& options,
                      const DeviceType& jit_device_name, se::Platform* platform,
                      bool transfer_as_literal)
     : LocalDevice(options, attrs),
-      xla_metadata_(
-          device_ordinal, platform, jit_device_name,
-          // Pass tensor_info_manager_ by reference as it is initialized lazily.
-          &tensor_info_manager_),
+      xla_metadata_(device_ordinal, platform, jit_device_name),
       device_ordinal_(device_ordinal),
       jit_device_name_(jit_device_name),
      xla_allocator_(nullptr),
      platform_(platform),
-      tensor_info_manager_(nullptr),
      transfer_as_literal_(transfer_as_literal) {}
 
 XlaDevice::~XlaDevice() {}
@@ -215,7 +206,6 @@ Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
     xla::Backend* backend = client()->mutable_backend();
     xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
         backend, device_ordinal_);
-    tensor_info_manager_.reset(new XlaTensorInfoManager(xla_allocator_));
   }
   return xla_allocator_;
 }
@@ -236,8 +226,7 @@ Status XlaDevice::FillContextMap(const Graph* graph,
   // Call GetAllocator for the side-effect of ensuring the allocator and
   // XlaTensorInfoManager is created.
   (void)GetAllocator({});
-  auto ctx = new XlaDeviceContext(stream, tensor_info_manager_.get(),
-                                  transfer_as_literal_);
+  auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_);
   for (Node* n : graph->nodes()) {
     VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
     ctx->Ref();
@@ -285,8 +274,7 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
   Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
   Notification n;
   TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
-  XlaTransferManager manager(stream, tensor_info_manager_.get(),
-                             transfer_as_literal_);
+  XlaTransferManager manager(stream, client(), transfer_as_literal_);
   manager.CopyCPUTensorToDevice(&parsed, this, &copy,
                                 [&n, &status](const Status& s) {
                                   status = s;

tensorflow/compiler/jit/xla_device.h

Lines changed: 2 additions & 12 deletions
@@ -26,7 +26,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
 #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
 
-#include "tensorflow/compiler/jit/xla_tensor_info.h"
+#include "tensorflow/compiler/jit/xla_tensor.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
@@ -50,22 +50,19 @@ class XlaDevice : public LocalDevice {
   class Metadata {
    public:
     Metadata(int device_ordinal, perftools::gputools::Platform* platform,
-             const DeviceType& device_type,
-             std::unique_ptr<XlaTensorInfoManager>* tensor_info_manager);
+             const DeviceType& device_type);
 
     // The index of the device on this host.
     int device_ordinal() const;
 
     perftools::gputools::Platform* platform() const;
     xla::LocalClient* client() const;
     const DeviceType& jit_device_type() const;
-    XlaTensorInfoManager& tensor_info_manager() const;
 
    private:
     const int device_ordinal_;
     const DeviceType device_type_;
     perftools::gputools::Platform* platform_;  // Not owned.
-    std::unique_ptr<XlaTensorInfoManager>& tensor_info_manager_;
 
     TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
   };
@@ -123,13 +120,6 @@ class XlaDevice : public LocalDevice {
   // copying back and forth between CPU and the device, and
   // computations enqueued by XLA.
   xla::Backend::StreamPtr stream_;
-  // Manages sideband data about tensors, in particular the on-device shape tree
-  // if the tensor requires multiple device buffers to represent (for example,
-  // tuple shapes).
-  // This is a unique_ptr because XlaTensorInfoManager is non-copy-constructible
-  // and we need to initialize this lazily (as we also lazily initialize the
-  // underlying allocator).
-  std::unique_ptr<XlaTensorInfoManager> tensor_info_manager_;
   // Must we use XLA's transfer manager for correct host<->device transfers? if
   // false, we can use ThenMemcpy() instead.
   bool transfer_as_literal_;

tensorflow/compiler/jit/xla_device_context.cc

Lines changed: 32 additions & 29 deletions
@@ -27,37 +27,32 @@ namespace se = ::perftools::gputools;
 namespace tensorflow {
 
 // The allocator used for Tensors assigned to the XLA device.
-XlaDeviceAllocator::XlaDeviceAllocator(const xla::Backend* backend,
-                                       int device_ordinal)
-    : backend_(backend), device_ordinal_(device_ordinal) {}
-
+XlaDeviceAllocator::XlaDeviceAllocator() {}
 XlaDeviceAllocator::~XlaDeviceAllocator() = default;
 
 string XlaDeviceAllocator::Name() { return "xla"; }
 
 void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
-  se::DeviceMemoryBase dmem =
-      backend_->memory_allocator()
-          ->Allocate(device_ordinal_, num_bytes, /*retry_on_failure=*/false)
-          .ValueOrDie();
-  VLOG(2) << "Allocated XLA device tensor " << dmem.opaque() << "(" << num_bytes
-          << ")";
-  return dmem.opaque();
+  // We always return an empty XlaTensor object, encoded as an opaque tagged
+  // pointer. We can return an empty object and ignore num_bytes here because we
+  // have control over all of the uses of this device tensor, and can lazily
+  // allocate memory when used. This allows us to also know the shape of the
+  // allocated Tensor, which is useful if the device's tensor representation
+  // differs from the host.
+  return XlaTensor::ToOpaquePointer(new XlaTensor());
 }
 
 void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
-  se::DeviceMemoryBase dmem(ptr);
-  TF_CHECK_OK(backend_->memory_allocator()->Deallocate(device_ordinal_, &dmem));
-  VLOG(2) << "Deallocated XLA device tensor " << ptr;
+  delete XlaTensor::FromOpaquePointer(ptr);
 }
 
 void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
 
-XlaTransferManager::XlaTransferManager(
-    se::Stream* stream, XlaTensorInfoManager* tensor_info_manager,
-    bool transfer_as_literal)
+XlaTransferManager::XlaTransferManager(se::Stream* stream,
+                                       xla::LocalClient* client,
+                                       bool transfer_as_literal)
     : stream_(stream),
-      tensor_info_manager_(tensor_info_manager),
+      client_(client),
       transfer_as_literal_(transfer_as_literal) {}
 
 void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
@@ -74,9 +69,21 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
 
   void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
   const int64 total_bytes = cpu_tensor->TotalBytes();
-  void* dst_ptr = DMAHelper::base(device_tensor);
-  se::DeviceMemoryBase dev_dst_ptr(dst_ptr, total_bytes);
 
+  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
+  CHECK(xla_tensor);
+  if (!xla_tensor->has_shaped_buffer()) {
+    Status s = xla_tensor->AllocateShapedBuffer(
+        device_tensor->dtype(), device_tensor->shape(), client_,
+        stream_->parent()->device_ordinal());
+    if (!s.ok()) {
+      done(s);
+      return;
+    }
+  }
+
+  se::DeviceMemoryBase dev_dst_ptr =
+      XlaTensor::DeviceMemoryFromTensor(*device_tensor);
   Status status;
   if (transfer_as_literal_) {
     status = xla::Unimplemented(
@@ -92,10 +99,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                    block_status.error_message().c_str());
     }
   }
-
-  XlaTensorInfo* tensor_info =
-      tensor_info_manager_->GetOrCreateTensorInfo(*device_tensor);
-  tensor_info->set_host_tensor(*cpu_tensor);
+  xla_tensor->set_host_tensor(*cpu_tensor);
 
   done(status);
   return;
@@ -119,8 +123,8 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
           << device_tensor->NumElements();
 
   const int64 total_bytes = cpu_tensor->TotalBytes();
-  void* src_ptr = const_cast<void*>(DMAHelper::base(device_tensor));
-  se::DeviceMemoryBase dev_src_ptr(src_ptr, total_bytes);
+  se::DeviceMemoryBase dev_src_ptr =
+      XlaTensor::DeviceMemoryFromTensor(*device_tensor);
   void* dst_ptr = DMAHelper::base(cpu_tensor);
 
   Status status;
@@ -147,10 +151,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
   done(Status::OK());
 }
 
-XlaDeviceContext::XlaDeviceContext(se::Stream* stream,
-                                   XlaTensorInfoManager* tensor_info_manager,
+XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client,
                                    bool transfer_as_literal)
-    : manager_(stream, tensor_info_manager, transfer_as_literal) {}
+    : manager_(stream, client, transfer_as_literal) {}
 
 void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                              Device* device,
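
The XlaDeviceAllocator above now hands back an XlaTensor encoded as an opaque tagged pointer (XlaTensor::ToOpaquePointer / FromOpaquePointer), but xla_tensor.cc itself is not part of this excerpt. Below is a minimal, standalone sketch of one way such a tagged-pointer scheme can work; the kTag constant and the encoding details are illustrative assumptions, not the committed implementation.

// Illustrative sketch of a low-bit tagged pointer. Heap allocations are at
// least 2-byte aligned, so the least-significant bit of a valid XlaTensor* is
// zero and can be used to mark "this opaque buffer pointer is an XlaTensor"
// (as opposed to a plain TF-classic device pointer).
#include <cassert>
#include <cstdint>

class XlaTensor {
 public:
  static void* ToOpaquePointer(XlaTensor* tensor) {
    uintptr_t value = reinterpret_cast<uintptr_t>(tensor);
    assert((value & kTag) == 0);  // heap pointers have the low bit free
    return reinterpret_cast<void*>(value | kTag);
  }
  static XlaTensor* FromOpaquePointer(void* ptr) {
    uintptr_t value = reinterpret_cast<uintptr_t>(ptr);
    if ((value & kTag) == 0) return nullptr;  // not an XlaTensor
    return reinterpret_cast<XlaTensor*>(value & ~kTag);
  }

 private:
  static constexpr uintptr_t kTag = 0x1;
};

int main() {
  // Mirrors AllocateRaw/DeallocateRaw above: allocate, tag, later untag, free.
  void* opaque = XlaTensor::ToOpaquePointer(new XlaTensor());
  XlaTensor* tensor = XlaTensor::FromOpaquePointer(opaque);
  assert(tensor != nullptr);
  delete tensor;
  return 0;
}

The call sites in the diffs above (for example XlaTensor::FromTensor in xla_compile_on_demand_op.cc) rely on the nullptr result to tell a TF-classic tensor apart from an XlaTensor at runtime.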
