@@ -36,6 +36,11 @@ const char kModulePrefix[] = R"(
36
36
scalar_lhs = f32[] parameter(0)
37
37
scalar_rhs = f32[] parameter(1)
38
38
ROOT add = f32[] add(scalar_lhs, scalar_rhs)
39
+ }
40
+ scalar_mul_computation {
41
+ scalar_lhs = f32[] parameter(0)
42
+ scalar_rhs = f32[] parameter(1)
43
+ ROOT mul = f32[] add(scalar_lhs, scalar_rhs)
39
44
})" ;
40
45
41
46
TEST_F (InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
@@ -67,6 +72,34 @@ TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
67
72
op::Tuple (op::Reduce (), op::Reduce ()));
68
73
}
69
74
75
+ TEST_F (InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
76
+ auto module = ParseHloString (tensorflow::strings::StrCat (kModulePrefix , R"(
77
+ fused_computation_1 {
78
+ p1.1 = f32[6400]{0} parameter(1)
79
+ mul = f32[6400]{0} multiply(p1.1, p1.1)
80
+ const.1 = f32[] parameter(0)
81
+ ROOT reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0}, to_apply=scalar_add_computation
82
+ }
83
+
84
+ fused_computation_2 {
85
+ p1.2 = f32[6400]{0} parameter(1)
86
+ r1 = f32[64,100]{0,1} reshape(p1.2)
87
+ const.2 = f32[] parameter(0)
88
+ ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation
89
+ }
90
+
91
+ ENTRY entry {
92
+ p0 = f32[] parameter(0)
93
+ p1 = f32[6400]{0} parameter(1)
94
+ const.2 = f32[] constant(1)
95
+ fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1
96
+ fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2
97
+ ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2)
98
+ })" ))
99
+ .ValueOrDie ();
100
+ ASSERT_FALSE (GpuMultiOutputFusion ().Run (module .get ()).ValueOrDie ());
101
+ }
102
+
70
103
TEST_F (InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) {
71
104
// Two sibling fusions with reduce instruction roots sharing the same input
72
105
// param.
0 commit comments