Skip to content

Commit 7817ac8

Browse files
meheffernan authored and tensorflower-gardener committed
[XLA] Properly version outfeed and send operations in UserComputation.
Previously, outfeed and send operations were unconditionally emitted during UserComputation lowering, even if the outfeed/send was not in the requested version (computation snapshot). This CL versions these operations. Also, opportunistically improve logging in UserComputation, Service, and ComputationTracker, which was used to root-cause the underlying bug. Change: 148170893
1 parent f3405c2 commit 7817ac8

14 files changed

+628
-213
lines changed

tensorflow/compiler/xla/service/BUILD

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ cc_library(
107107

108108
cc_library(
109109
name = "versioned_computation_handle",
110+
srcs = ["versioned_computation_handle.cc"],
110111
hdrs = ["versioned_computation_handle.h"],
111112
deps = [
112113
"//tensorflow/compiler/xla:xla_data_proto",
@@ -147,6 +148,22 @@ cc_library(
147148
],
148149
)
149150

151+
cc_test(
152+
name = "user_computation_test",
153+
srcs = ["user_computation_test.cc"],
154+
deps = [
155+
":user_computation",
156+
"//tensorflow/compiler/xla:literal_util",
157+
"//tensorflow/compiler/xla:shape_util",
158+
"//tensorflow/compiler/xla:status_macros",
159+
"//tensorflow/compiler/xla:test_helpers",
160+
"//tensorflow/compiler/xla:xla_data_proto",
161+
"//tensorflow/compiler/xla/service:hlo",
162+
"//tensorflow/core:test",
163+
"//tensorflow/core:test_main",
164+
],
165+
)
166+
150167
cc_library(
151168
name = "platform_util",
152169
srcs = ["platform_util.cc"],

tensorflow/compiler/xla/service/computation_tracker.cc

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ limitations under the License.
2626
#include "tensorflow/compiler/xla/types.h"
2727
#include "tensorflow/compiler/xla/util.h"
2828
#include "tensorflow/core/lib/strings/strcat.h"
29+
#include "tensorflow/core/lib/strings/stringprintf.h"
2930
#include "tensorflow/core/platform/logging.h"
3031

32+
using ::tensorflow::strings::Appendf;
33+
3134
namespace xla {
3235

3336
ComputationTracker::ComputationTracker() : next_computation_(1) {}
@@ -130,7 +133,7 @@ void ComputationTracker::ComputeComputationPostOrder(
130133
std::set<VersionedComputationHandle>* visited,
131134
std::list<VersionedComputationHandle>* post_order) const {
132135
if (visited->count(versioned_handle) > 0) {
133-
DCHECK_EQ(1, visited->count(versioned_handle));
136+
CHECK_EQ(1, visited->count(versioned_handle));
134137
return;
135138
}
136139

@@ -150,9 +153,14 @@ void ComputationTracker::ComputeComputationPostOrder(
150153

151154
StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule(
152155
const VersionedComputationHandle& entry_handle,
153-
bool include_unused_parameters) const {
156+
bool include_unreachable_instructions) const {
154157
tensorflow::mutex_lock lock(computation_mutex_);
155158

159+
VLOG(1) << "BuildHloModule(" << entry_handle
160+
<< ", include_unreachable_instructions="
161+
<< include_unreachable_instructions << ")";
162+
XLA_VLOG_LINES(1, ToStringInternal());
163+
156164
TF_ASSIGN_OR_RETURN(UserComputation * entry_computation,
157165
ResolveInternal(entry_handle.handle));
158166

@@ -174,6 +182,14 @@ StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule(
174182
return hlo_computations.at(versioned_handle);
175183
};
176184

185+
// Print the post-order list for this entry computation.
186+
if (VLOG_IS_ON(2)) {
187+
VLOG(2) << "Visiting UserComputations in post order:";
188+
for (const VersionedComputationHandle& versioned_handle : post_order) {
189+
VLOG(2) << " " << versioned_handle;
190+
}
191+
}
192+
177193
string module_name =
178194
tensorflow::strings::StrCat(entry_computation->name(), "_module");
179195
auto module = MakeUnique<HloModule>(module_name, entry_handle);
@@ -184,7 +200,7 @@ StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule(
184200
TF_ASSIGN_OR_RETURN(
185201
std::unique_ptr<HloComputation> hlo_computation,
186202
computation->BuildHloComputation(versioned_handle.version, resolver,
187-
include_unused_parameters));
203+
include_unreachable_instructions));
188204

189205
// Add the newly created computation to VersionedHandle-to-HloComputation
190206
// map.
@@ -201,4 +217,23 @@ StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule(
201217
return std::move(module);
202218
}
203219

220+
string ComputationTracker::ToString() const {
221+
tensorflow::mutex_lock lock(computation_mutex_);
222+
return ToStringInternal();
223+
}
224+
225+
string ComputationTracker::ToStringInternal() const {
226+
string out;
227+
Appendf(&out, "ComputationTracker(%p):\n", this);
228+
for (const auto& handle_computation : opaque_to_computation_) {
229+
int64 handle = handle_computation.first;
230+
const std::unique_ptr<UserComputation>& computation =
231+
handle_computation.second;
232+
Appendf(&out, " %4lld : %s \"%s\"\n", handle,
233+
computation->GetVersionedHandle().ToString().c_str(),
234+
computation->name().c_str());
235+
}
236+
return out;
237+
}
238+
204239
} // namespace xla

tensorflow/compiler/xla/service/computation_tracker.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,15 @@ class ComputationTracker {
7272
// Builds an HLO module using the specified computation as the entry. The
7373
// module will include the entry computation as well as all computations which
7474
// are called directly or indirectly from the entry computation via operations
75-
// like "map". If include_unused_parameters is true, then all parameters are
76-
// lowered to HLO instructions even if they are not used. This ensures the
77-
// entry HloComputation has the same program shape (ProgramShape) as the entry
78-
// UserComputation.
75+
// like "map". If include_unreachable_instructions is true, then instructions
76+
// which are not reachable from the root are lowered into HloInstructions
77+
// including unreachable parameters. This ensures the entry HloComputation has
78+
// the same program shape (ProgramShape) as the entry UserComputation.
7979
StatusOr<std::unique_ptr<HloModule>> BuildHloModule(
8080
const VersionedComputationHandle& entry_handle,
81-
bool include_unused_parameters = true) const;
81+
bool include_unreachable_instructions = true) const;
82+
83+
string ToString() const;
8284

8385
private:
8486
// Bumps the next_computation_ number and returns the allocated number wrapped
@@ -117,6 +119,8 @@ class ComputationTracker {
117119
std::list<VersionedComputationHandle>* post_order) const
118120
EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
119121

122+
string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
123+
120124
// Guards the computation mapping. Marked mutable so that the Resolve method
121125
// can remain const; Resolve doesn't really modify the tracker in any way, but
122126
// it has to lock the mutex for safety.

tensorflow/compiler/xla/service/hlo_computation.cc

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -111,50 +111,43 @@ Status HloComputation::RemoveInstructionAndUnusedOperands(
111111

112112
TF_RET_CHECK(instruction->user_count() == 0);
113113
TF_RET_CHECK(HloComputation::IsRemovable(instruction->opcode()));
114-
std::queue<HloInstruction*> remove;
115-
remove.push(instruction);
116-
while (!remove.empty()) {
117-
HloInstruction* item = remove.front();
118-
remove.pop();
119-
if (item->user_count() != 0 || item == root_instruction_ ||
114+
std::unordered_set<HloInstruction*> removed;
115+
std::queue<HloInstruction*> worklist;
116+
worklist.push(instruction);
117+
while (!worklist.empty()) {
118+
HloInstruction* item = worklist.front();
119+
worklist.pop();
120+
121+
if (removed.count(item) != 0 || item->user_count() != 0 ||
122+
item == root_instruction() ||
120123
!HloComputation::IsRemovable(item->opcode())) {
121124
continue;
122125
}
123126
for (int i = 0; i < item->operand_count(); ++i) {
124-
remove.push(item->mutable_operand(i));
127+
worklist.push(item->mutable_operand(i));
125128
}
126129

127-
// If an instruction has the same operand more than once, we must not remove
128-
// it again.
129130
TF_RETURN_IF_ERROR(RemoveInstruction(item));
131+
removed.insert(item);
130132
}
131133
return Status::OK();
132134
}
133135

134-
StatusOr<bool> HloComputation::RemoveInstructionIfFound(
135-
HloInstruction* instruction) {
136+
Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
137+
VLOG(2) << "Removing instruction " << instruction->name()
138+
<< " from computation " << name();
136139
TF_RET_CHECK(IsRemovable(instruction->opcode()));
137140
TF_RET_CHECK(root_instruction() != instruction)
138-
<< "cannot remove root instruction";
141+
<< "cannot remove root instruction " << instruction->name();
139142
TF_RET_CHECK(instruction->user_count() == 0)
140-
<< "instruction with users cannot be removed";
143+
<< "instruction " << instruction->name()
144+
<< " has users and cannot be removed";
141145

142-
if (instruction_iterators_.count(instruction) == 0) {
143-
return false;
144-
}
145-
VLOG(2) << "Removing instruction " << instruction->name()
146-
<< " from computation " << name();
146+
TF_RET_CHECK(instruction_iterators_.count(instruction) != 0);
147147
auto inst_it = instruction_iterators_.at(instruction);
148148
(*inst_it)->set_parent(nullptr);
149149
instruction->DetachFromOperands();
150150
instructions_.erase(inst_it);
151-
return true;
152-
}
153-
154-
Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
155-
TF_ASSIGN_OR_RETURN(bool removed, RemoveInstructionIfFound(instruction));
156-
TF_RET_CHECK(removed) << instruction->ToString()
157-
<< " is not a member of computation " << name();
158151
return Status::OK();
159152
}
160153

@@ -511,17 +504,22 @@ HloComputation::ComputeTransitiveOperands() const {
511504
}
512505

513506
Status HloComputation::Accept(DfsHloVisitor* visitor) const {
514-
// Visit all dead roots.
507+
// Visit all roots. Build a vector of roots ahead of time because visitor
508+
// might delete the currently visited root which would invalidate the
509+
// iterator.
510+
std::vector<HloInstruction*> unreachable_roots;
515511
for (auto& instruction : instructions()) {
516512
if (instruction->user_count() == 0 &&
517513
instruction->control_successors().empty() &&
518514
instruction.get() != root_instruction()) {
519-
// Call FinishVisit only at the end.
520-
TF_RETURN_IF_ERROR(
521-
instruction->Accept(visitor, /*call_finish_visit=*/false));
515+
unreachable_roots.push_back(instruction.get());
522516
}
523517
}
524-
// Visit root instruction last.
518+
for (HloInstruction* root : unreachable_roots) {
519+
// Call FinishVisit only at the end.
520+
TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false));
521+
}
522+
// Visit the computation root instruction last.
525523
return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
526524
}
527525

tensorflow/compiler/xla/service/hlo_computation.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,6 @@ class HloComputation {
238238
HloInstruction* AddInstructionInternal(
239239
std::unique_ptr<HloInstruction> instruction);
240240

241-
// Remove an instruction from the computation if found. The instruction must
242-
// have no users. Instruction is deallocated with this call.
243-
// Return whether instruction was found and removed.
244-
StatusOr<bool> RemoveInstructionIfFound(HloInstruction* instruction);
245-
246241
// Fuses HLOs in instructions_to_fuse into fusion_instruction.
247242
//
248243
// Pre-condition: fusion_instruction's opcode is kFusion.

tensorflow/compiler/xla/service/hlo_computation_test.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,30 @@ TEST_F(HloComputationTest, CycleDetection) {
306306
testing::ContainsRegex("cycle is detecte"));
307307
}
308308

309+
TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) {
310+
// Test RemoveInstructionAndUnusedOperands with an instruction which has a
311+
// duplicated (dead) operand. This verifies that the operand is not deleted
312+
// twice.
313+
auto builder = HloComputation::Builder(TestName());
314+
auto constant = builder.AddInstruction(
315+
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
316+
auto dead_negate = builder.AddInstruction(
317+
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
318+
auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary(
319+
r0f32_, HloOpcode::kAdd, dead_negate, dead_negate));
320+
auto negate = builder.AddInstruction(
321+
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
322+
auto computation = builder.Build();
323+
324+
EXPECT_EQ(4, computation->instruction_count());
325+
EXPECT_EQ(negate, computation->root_instruction());
326+
327+
ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(dead_add));
328+
329+
EXPECT_EQ(2, computation->instruction_count());
330+
EXPECT_EQ(negate, computation->root_instruction());
331+
}
332+
309333
} // namespace
310334

311335
} // namespace xla

tensorflow/compiler/xla/service/local_service.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ LocalService::CompileAheadOfTime(
223223
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
224224
computation_tracker_.BuildHloModule(
225225
versioned_handle,
226-
/*include_unused_parameters=*/true));
226+
/*include_unreachable_instructions=*/true));
227227
hlo_modules.push_back(std::move(hlo_module));
228228

229229
TF_ASSIGN_OR_RETURN(

0 commit comments

Comments
 (0)