Skip to content

Commit c1f3a2d

Browse files
committed
[PyTorch] Unbreak VectorizedN fmadd/fmsub/clamp
These are ternary ops, not binary ops. Differential Revision: [D64794253](https://our.internmc.facebook.com/intern/diff/D64794253/) [ghstack-poisoned]
1 parent 7eeef66 commit c1f3a2d

File tree

2 files changed

+78
-6
lines changed

2 files changed

+78
-6
lines changed

aten/src/ATen/cpu/vec/vec_n.h

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,21 @@ class VectorizedN {
7777
return result;
7878
}
7979

80+
template <typename Op>
81+
inline VectorizedN<T, N> ternary_op(
82+
const VectorizedN<T, N>& other,
83+
const VectorizedN<T, N>& other2,
84+
Op op) const {
85+
VectorizedN<T, N> result;
86+
#ifndef _MSC_VER
87+
#pragma unroll
88+
#endif
89+
for (int i = 0; i < N; ++i) {
90+
result.values[i] = op(values[i], other.values[i], other2.values[i]);
91+
}
92+
return result;
93+
}
94+
8095
VectorizedN() = default;
8196

8297
explicit VectorizedN(T val) {
@@ -89,7 +104,8 @@ class VectorizedN {
89104
VectorizedN(const Vectorized<T>& val) : values({val}) {}
90105

91106
template <int L = N, typename std::enable_if_t<L == 2, int> = 0>
92-
VectorizedN(const Vectorized<T>& val_0, const Vectorized<T>& val_1) : values({val_0, val_1}) {}
107+
VectorizedN(const Vectorized<T>& val_0, const Vectorized<T>& val_1)
108+
: values({val_0, val_1}) {}
93109

94110
template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
95111
inline operator Vectorized<T>() const {
@@ -110,7 +126,8 @@ class VectorizedN {
110126
const VectorizedN<T, N>& b) {
111127
VectorizedN<T, N> result;
112128
for (int i = 0; i < N; ++i) {
113-
result.values[i] = Vectorized<T>::template blend<mask>(a.values[i], b.values[i]);
129+
result.values[i] =
130+
Vectorized<T>::template blend<mask>(a.values[i], b.values[i]);
114131
}
115132
return result;
116133
}
@@ -306,6 +323,20 @@ class VectorizedN {
306323
}); \
307324
}
308325

326+
#define VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(op) \
327+
template <typename T, int N> \
328+
inline VectorizedN<T, N> op( \
329+
const VectorizedN<T, N>& a, \
330+
const VectorizedN<T, N>& b, \
331+
const VectorizedN<T, N>& c) { \
332+
return a.ternary_op( \
333+
b, \
334+
c, \
335+
[](const Vectorized<T>& a, \
336+
const Vectorized<T>& b, \
337+
const Vectorized<T>& c) { return op(a, b, c); }); \
338+
}
339+
309340
#define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(op) \
310341
template <typename T, int N> \
311342
inline VectorizedN<T, N>& op( \
@@ -326,9 +357,9 @@ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<)
326357
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator>>)
327358
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(maximum)
328359
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(minimum)
329-
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmadd)
330-
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmsub)
331-
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp)
360+
VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmadd)
361+
VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmsub)
362+
VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(clamp)
332363
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_max)
333364
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_min)
334365
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator&)
@@ -357,5 +388,17 @@ inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN<T, N> acc_vec) {
357388
return vec_reduce_all(vec_fun, vec_result);
358389
}
359390

391+
template <typename T, int N>
392+
std::ostream& operator<<(std::ostream& stream, const VectorizedN<T, N>& vec_n) {
393+
stream << "vec_n[";
394+
for (int i = 0; i < N; ++i) {
395+
if (i != 0) {
396+
stream << ", ";
397+
}
398+
stream << vec_n[i];
399+
}
400+
stream << ']';
401+
return stream;
402+
}
360403
} // namespace CPU_CAPABILITY
361404
} // namespace at::vec

aten/src/ATen/test/vec_test_all_types.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,17 @@ namespace {
821821
createDefaultTernaryTestCase<vec>(TestSeed()),
822822
RESOLVE_OVERLOAD(filter_clamp));
823823
}
824+
TYPED_TEST(MinMax, ClampVecN) {
825+
using VT = ValueType<TypeParam>;
826+
using vec = at::vec::VectorizedN<VT, 1>;
827+
test_ternary<vec>(
828+
NAME_INFO(clamp), clamp<VT>,
829+
[](const vec& v0, const vec& v1, const vec& v2) {
830+
return clamp(v0, v1, v2);
831+
},
832+
createDefaultTernaryTestCase<vec>(TestSeed()),
833+
RESOLVE_OVERLOAD(filter_clamp));
834+
}
824835
TYPED_TEST(BitwiseFloatsAdditional, ZeroMask) {
825836
using vec = TypeParam;
826837
using VT = ValueType<TypeParam>;
@@ -895,7 +906,25 @@ namespace {
895906
.setTestSeed(TestSeed());
896907

897908
test_ternary<vec>(
898-
NAME_INFO(clamp), RESOLVE_OVERLOAD(local_fmadd),
909+
NAME_INFO(fmadd), RESOLVE_OVERLOAD(local_fmadd),
910+
[](const vec& v0, const vec& v1, const vec& v2) {
911+
return at::vec::fmadd(v0, v1, v2);
912+
},
913+
test_case,
914+
RESOLVE_OVERLOAD(filter_fmadd));
915+
}
916+
TYPED_TEST(BitwiseFloatsAdditional, FmaddVecN) {
917+
using VT = ValueType<TypeParam>;
918+
using vec = at::vec::VectorizedN<VT, 1>;
919+
920+
auto test_case = TestingCase<vec>::getBuilder()
921+
.addDomain(CheckWithinDomains<VT>{
922+
{{(VT)-1000, (VT)1000}, {(VT)-1000, (VT)1000}, {(VT)-1000, (VT)1000}},
923+
true, getDefaultTolerance<VT>()})
924+
.setTestSeed(TestSeed());
925+
926+
test_ternary<vec>(
927+
NAME_INFO(fmadd), RESOLVE_OVERLOAD(local_fmadd),
899928
[](const vec& v0, const vec& v1, const vec& v2) {
900929
return at::vec::fmadd(v0, v1, v2);
901930
},

0 commit comments

Comments
 (0)