Skip to content

Commit f9ae897

Browse files
d0k authored and tensorflower-gardener committed
[XLA:GPU] Check the reduce input shape when multi-output fusing reduces
Otherwise we can end up in a situation where incompatible reduces that happen to have the same output shape are fused.

PiperOrigin-RevId: 200180013
1 parent 51f2b9e commit f9ae897

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,16 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
4747
element_instr = fused_expression_root;
4848
}
4949
}
50+
// Special handling of kReduce instructions -- the fusion
51+
// applies to the first operand.
52+
if (element_instr->opcode() == HloOpcode::kReduce) {
53+
return element_instr->operand(0)->shape();
54+
}
5055
return element_instr->shape();
5156
};
5257

5358
// The elementwise output shapes must be the same (including layout)
54-
return ShapeUtil::ShapeUtil::Equal(get_element_shape(instr1),
55-
get_element_shape(instr2));
59+
return ShapeUtil::Equal(get_element_shape(instr1), get_element_shape(instr2));
5660
}
5761

5862
bool GpuMultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {

tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ const char kModulePrefix[] = R"(
3636
scalar_lhs = f32[] parameter(0)
3737
scalar_rhs = f32[] parameter(1)
3838
ROOT add = f32[] add(scalar_lhs, scalar_rhs)
39+
}
40+
scalar_mul_computation {
41+
scalar_lhs = f32[] parameter(0)
42+
scalar_rhs = f32[] parameter(1)
43+
ROOT mul = f32[] add(scalar_lhs, scalar_rhs)
3944
})";
4045

4146
TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
@@ -67,6 +72,34 @@ TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
6772
op::Tuple(op::Reduce(), op::Reduce()));
6873
}
6974

75+
TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
76+
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
77+
fused_computation_1 {
78+
p1.1 = f32[6400]{0} parameter(1)
79+
mul = f32[6400]{0} multiply(p1.1, p1.1)
80+
const.1 = f32[] parameter(0)
81+
ROOT reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0}, to_apply=scalar_add_computation
82+
}
83+
84+
fused_computation_2 {
85+
p1.2 = f32[6400]{0} parameter(1)
86+
r1 = f32[64,100]{0,1} reshape(p1.2)
87+
const.2 = f32[] parameter(0)
88+
ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation
89+
}
90+
91+
ENTRY entry {
92+
p0 = f32[] parameter(0)
93+
p1 = f32[6400]{0} parameter(1)
94+
const.2 = f32[] constant(1)
95+
fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1
96+
fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2
97+
ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2)
98+
})"))
99+
.ValueOrDie();
100+
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
101+
}
102+
70103
TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) {
71104
// Two sibling fusions with reduce instruction roots sharing the same input
72105
// param.

0 commit comments

Comments (0)