Commit e33538d

Author: Akshay Modi (committed)
Merge commit for internal changes
2 parents: 5fa7b03 + c5436b9

100 files changed: 3746 additions & 1446 deletions

Note: large commits have some content hidden by default; only a subset of the 100 changed files appears below.

tensorflow/c/eager/c_api_test.cc

Lines changed: 80 additions & 1 deletion
@@ -142,8 +142,10 @@ void TestRemoteExecute(bool async) {
   TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
                                  status);
-  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts,
+                                             TFE_DEVICE_PLACEMENT_EXPLICIT);
   TFE_Context* ctx = TFE_NewContext(opts, status);
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteContextOptions(opts);
@@ -205,6 +207,83 @@ void TestRemoteExecute(bool async) {
 TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
 TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
 
+void TestRemoteExecuteSilentCopies(bool async) {
+  tensorflow::ServerDef server_def = GetServerDef(2);
+
+  // This server def has the task index set to 0.
+  string serialized = server_def.SerializeAsString();
+
+  server_def.set_task_index(1);
+
+  std::unique_ptr<tensorflow::eager::EagerGrpcServer> worker_server;
+  ASSERT_TRUE(
+      tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server)
+          .ok());
+  ASSERT_TRUE(worker_server->Start().ok());
+
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
+                                 status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
+  TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
+  const char remote_device_name[] =
+      "/job:localhost/replica:0/task:1/device:CPU:0";
+
+  // Handles are on task0, but op is on remote (task1).
+  TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0);
+  TFE_OpSetDevice(matmul, remote_device_name, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TFE_TensorHandle* retvals[1];
+  int num_retvals = 1;
+  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  auto* retval_task0 = TFE_TensorHandleCopyToDevice(
+      retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteTensorHandle(retval_task0);
+  float product[4] = {0};
+  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+  TF_DeleteTensor(t);
+  EXPECT_EQ(7, product[0]);
+  EXPECT_EQ(10, product[1]);
+  EXPECT_EQ(15, product[2]);
+  EXPECT_EQ(22, product[3]);
+
+  TFE_DeleteTensorHandle(h0_task0);
+  TFE_DeleteTensorHandle(h1_task0);
+  TFE_DeleteTensorHandle(retvals[0]);
+
+  TFE_DeleteOp(matmul);
+
+  TFE_ContextAsyncWait(ctx, status);
+  TFE_DeleteContext(ctx, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TF_DeleteStatus(status);
+
+  // TODO(nareshmodi): Figure out how to correctly shut the server down.
+  worker_server.release();
+}
+
+TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
+TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
+  TestRemoteExecuteSilentCopies(true);
+}
+
 TEST(CAPI, TensorHandle) {
   TFE_TensorHandle* h = TestMatrixTensorHandle();
   EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
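
The two remote-execution tests now differ mainly in the context's device placement policy: TestRemoteExecute opts into TFE_DEVICE_PLACEMENT_EXPLICIT, under which running an op whose inputs live on the wrong device is an error, while the new TestRemoteExecuteSilentCopies opts into TFE_DEVICE_PLACEMENT_SILENT, which lets the runtime copy the task0 handles over to task1 implicitly. A minimal sketch of selecting the policy, using only TFE_* calls that already appear in this diff (the MakeEagerContext helper itself is hypothetical, not part of the commit):

    // Sketch: build an eager context with a chosen placement policy.
    TFE_Context* MakeEagerContext(TFE_ContextDevicePlacementPolicy policy,
                                  TF_Status* status) {
      TFE_ContextOptions* opts = TFE_NewContextOptions();
      // TFE_DEVICE_PLACEMENT_SILENT: copy tensors across devices as needed.
      // TFE_DEVICE_PLACEMENT_EXPLICIT: fail instead of copying implicitly.
      TFE_ContextOptionsSetDevicePlacementPolicy(opts, policy);
      TFE_Context* ctx = TFE_NewContext(opts, status);
      TFE_DeleteContextOptions(opts);
      return ctx;
    }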

tensorflow/compiler/xla/rpc/BUILD

Lines changed: 3 additions & 3 deletions
@@ -42,7 +42,7 @@ tf_cc_binary(
         "//tensorflow/compiler/xla/service:cpu_plugin",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
-        "@grpc//:grpc++_unsecure",
+        "@grpc//:grpc++",
     ],
 )
@@ -61,7 +61,7 @@ tf_cc_test(
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
-        "@grpc//:grpc++_unsecure",
+        "@grpc//:grpc++",
     ],
 )
@@ -74,6 +74,6 @@ cc_library(
         "//tensorflow/compiler/xla/service",
         "//tensorflow/compiler/xla/service:platform_util",
        "//tensorflow/core/distributed_runtime/rpc:grpc_util",
-        "@grpc//:grpc++_unsecure",
+        "@grpc//:grpc++",
     ],
 )
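
All three targets swap @grpc//:grpc++_unsecure for @grpc//:grpc++. The unsecure Bazel target builds gRPC without its SSL/TLS support, so linking the full grpc++ target presumably makes secure channel credentials available to these XLA RPC binaries. A hedged client-side sketch of what the full library adds (the endpoint and empty options are placeholders; the header path is the 2018-era spelling):

    #include <grpc++/grpc++.h>  // newer gRPC releases use <grpcpp/grpcpp.h>

    int main() {
      // Available only when the full grpc++ library is linked in;
      // grpc++_unsecure provides just InsecureChannelCredentials().
      auto creds = grpc::SslCredentials(grpc::SslCredentialsOptions());
      std::shared_ptr<grpc::Channel> channel =
          grpc::CreateChannel("localhost:50051", creds);
      return channel != nullptr ? 0 : 1;
    }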

tensorflow/compiler/xla/service/BUILD

Lines changed: 1 addition & 1 deletion
@@ -2379,7 +2379,6 @@ cc_library(
         ":hlo_graph_dumper",
         ":hlo_pass",
         "//tensorflow/compiler/xla:types",
-        "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
     ],
 )
@@ -2574,6 +2573,7 @@ cc_library(
     hdrs = ["hlo_graph_dumper.h"],
     deps = [
         ":hlo",
+        ":hlo_casting_utils",
         ":hlo_execution_profile",
         ":hlo_tfgraph_builder",
         "//tensorflow/compiler/xla:literal_util",

tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc

Lines changed: 6 additions & 2 deletions
@@ -47,12 +47,16 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
         element_instr = fused_expression_root;
       }
     }
+    // Special handling of kReduce instructions -- the fusion
+    // applies to the first operand.
+    if (element_instr->opcode() == HloOpcode::kReduce) {
+      return element_instr->operand(0)->shape();
+    }
     return element_instr->shape();
   };
 
   // The elementwise output shapes must be the same (including layout)
-  return ShapeUtil::ShapeUtil::Equal(get_element_shape(instr1),
-                                     get_element_shape(instr2));
+  return ShapeUtil::Equal(get_element_shape(instr1), get_element_shape(instr2));
 }
 
 bool GpuMultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {
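
This hunk makes two fixes: the doubled ShapeUtil::ShapeUtil:: qualifier becomes a single one, and get_element_shape learns to compare reduce instructions by the shape they reduce rather than by their much smaller output shape. A simplified, free-standing approximation of the lambda after the change (the real code first unwraps fusion roots, which is elided here):

    // Approximation of get_element_shape after this patch.
    const Shape& GetElementShape(const HloInstruction* element_instr) {
      // For a reduce, judge fusion compatibility on the shape being
      // reduced, i.e. its first operand, not on the reduced output.
      if (element_instr->opcode() == HloOpcode::kReduce) {
        return element_instr->operand(0)->shape();
      }
      return element_instr->shape();
    }

Two instructions are then compatible only when ShapeUtil::Equal, which also compares layout, holds for the two resulting shapes.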

tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc

Lines changed: 33 additions & 0 deletions
@@ -36,6 +36,11 @@ const char kModulePrefix[] = R"(
       scalar_lhs = f32[] parameter(0)
       scalar_rhs = f32[] parameter(1)
       ROOT add = f32[] add(scalar_lhs, scalar_rhs)
+    }
+    scalar_mul_computation {
+      scalar_lhs = f32[] parameter(0)
+      scalar_rhs = f32[] parameter(1)
+      ROOT mul = f32[] add(scalar_lhs, scalar_rhs)
     })";
 
 TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
@@ -67,6 +72,34 @@ TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
               op::Tuple(op::Reduce(), op::Reduce()));
 }
 
+TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
+  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+    fused_computation_1 {
+      p1.1 = f32[6400]{0} parameter(1)
+      mul = f32[6400]{0} multiply(p1.1, p1.1)
+      const.1 = f32[] parameter(0)
+      ROOT reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0}, to_apply=scalar_add_computation
+    }
+
+    fused_computation_2 {
+      p1.2 = f32[6400]{0} parameter(1)
+      r1 = f32[64,100]{0,1} reshape(p1.2)
+      const.2 = f32[] parameter(0)
+      ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[] parameter(0)
+      p1 = f32[6400]{0} parameter(1)
+      const.2 = f32[] constant(1)
+      fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1
+      fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2
+      ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2)
+    })"))
+                    .ValueOrDie();
+  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
 TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) {
   // Two sibling fusions with reduce instruction roots sharing the same input
   // param.
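
The new test exercises exactly that comparison: reduce.1 consumes f32[6400]{0} while reduce.2 consumes the reshaped f32[64,100]{0,1}, so under the updated get_element_shape the two input shapes differ and the pass must leave the module unchanged, hence the ASSERT_FALSE. A condensed sketch of the driver portion (hlo_text stands in for kModulePrefix plus the computations above):

    // Parse the module, run the pass, and require that nothing changed.
    auto module = ParseHloString(hlo_text).ValueOrDie();
    bool changed = GpuMultiOutputFusion().Run(module.get()).ValueOrDie();
    ASSERT_FALSE(changed);  // f32[6400]{0} vs f32[64,100]{0,1}: no fusion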

tensorflow/compiler/xla/service/hlo_computation.cc

Lines changed: 0 additions & 1 deletion
@@ -357,7 +357,6 @@ std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
   std::list<HloInstruction*> post_order;
   std::list<HloInstruction*> trace_instructions;
   tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
-  std::vector<HloInstruction> dfs_stack;
   for (auto& instruction : instructions_) {
     if (instruction->opcode() == HloOpcode::kTrace) {
       // Trace instructions aren't handled by the DFS visitor. Add trace

tensorflow/compiler/xla/service/hlo_graph_dumper.cc

Lines changed: 10 additions & 10 deletions
@@ -28,6 +28,8 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -723,17 +725,14 @@ string HloDotDumper::DumpRootTag() {
                 to_id, node_body, node_shape, NodeColorAttributes(color));
 }
 
-static const HloInstruction* TryGetFusionParameterConstant(
+static const HloConstantInstruction* TryGetFusionParameterConstant(
     const HloInstruction* instr) {
   if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
     return nullptr;
   }
   const HloInstruction* fusion = instr->parent()->FusionInstruction();
   const HloInstruction* operand = fusion->operand(instr->parameter_number());
-  if (operand->opcode() == HloOpcode::kConstant) {
-    return operand;
-  }
-  return nullptr;
+  return DynCast<HloConstantInstruction>(operand);
 }
 
 bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
@@ -826,7 +825,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
 
 string HloDotDumper::GetInstructionNodeInlinedOperands(
     const HloInstruction* instr) {
-  auto stringify_constant = [](const HloInstruction* constant) {
+  auto stringify_constant = [](const HloConstantInstruction* constant) {
     const auto& shape = constant->shape();
 
     // If the shape has a dimension of size zero, print it as e.g.
@@ -845,7 +844,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
         *elem_count *= dim;
       }
     }
-    if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
+    if (elem_count.has_value() && *elem_count <= 8) {
       return Printf("%s (%s)", constant->literal().ToString(),
                     ShapeUtil::HumanString(constant->shape()));
     }
@@ -864,17 +863,18 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
   std::vector<string> lines;
   for (int64 i = 0; i < instr->operand_count(); ++i) {
     const HloInstruction* operand = instr->operand(i);
+    const auto* constant_operand = DynCast<HloConstantInstruction>(operand);
     optional<string> operand_str;
-    if (operand->opcode() == HloOpcode::kConstant) {
-      operand_str = stringify_constant(operand);
+    if (constant_operand != nullptr) {
+      operand_str = stringify_constant(constant_operand);
     } else if (ShouldMergeIntoUsers(operand)) {
       // Special case: If the operand is a parameter to a fusion node and it
       // always has a constant value, display it like a regular constant.
       //
       // For other parameters, use the parameter number rather than the proper
       // name, because that's generally how people think of the node.
       if (operand->opcode() == HloOpcode::kParameter) {
-        if (const HloInstruction* constant =
+        if (const HloConstantInstruction* constant =
                 TryGetFusionParameterConstant(operand)) {
           operand_str = stringify_constant(constant);
         } else {
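
The dumper now leans on hlo_casting_utils instead of matching opcodes by hand: DynCast<HloConstantInstruction>(op) yields the downcast pointer when op really is a constant and nullptr otherwise, which is why the explicit kConstant checks and the HasLiteral() guard can all go away (an HloConstantInstruction always owns its literal). A minimal sketch of the pattern, assuming only the two headers this commit adds; the DescribeOperand helper is illustrative, not part of the commit:

    #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
    #include "tensorflow/compiler/xla/service/hlo_instructions.h"

    // Inline an operand's value only when it is actually a constant.
    std::string DescribeOperand(const xla::HloInstruction* operand) {
      if (const auto* constant =
              xla::DynCast<xla::HloConstantInstruction>(operand)) {
        // No HasLiteral() check needed; the subclass carries a literal.
        return constant->literal().ToString();
      }
      return operand->name();
    }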
