@@ -77,6 +77,21 @@ class VectorizedN {
77
77
return result;
78
78
}
79
79
80
+ template <typename Op>
81
+ inline VectorizedN<T, N> ternary_op (
82
+ const VectorizedN<T, N>& other,
83
+ const VectorizedN<T, N>& other2,
84
+ Op op) const {
85
+ VectorizedN<T, N> result;
86
+ #ifndef _MSC_VER
87
+ #pragma unroll
88
+ #endif
89
+ for (int i = 0 ; i < N; ++i) {
90
+ result.values [i] = op (values[i], other.values [i], other2.values [i]);
91
+ }
92
+ return result;
93
+ }
94
+
80
95
VectorizedN () = default ;
81
96
82
97
explicit VectorizedN (T val) {
@@ -89,7 +104,8 @@ class VectorizedN {
89
104
VectorizedN (const Vectorized<T>& val) : values({val}) {}
90
105
91
106
template <int L = N, typename std::enable_if_t <L == 2 , int > = 0 >
92
- VectorizedN (const Vectorized<T>& val_0, const Vectorized<T>& val_1) : values({val_0, val_1}) {}
107
+ VectorizedN (const Vectorized<T>& val_0, const Vectorized<T>& val_1)
108
+ : values({val_0, val_1}) {}
93
109
94
110
template <int L = N, typename std::enable_if_t <L == 1 , int > = 0 >
95
111
inline operator Vectorized<T>() const {
@@ -110,7 +126,8 @@ class VectorizedN {
110
126
const VectorizedN<T, N>& b) {
111
127
VectorizedN<T, N> result;
112
128
for (int i = 0 ; i < N; ++i) {
113
- result.values [i] = Vectorized<T>::template blend<mask>(a.values [i], b.values [i]);
129
+ result.values [i] =
130
+ Vectorized<T>::template blend<mask>(a.values [i], b.values [i]);
114
131
}
115
132
return result;
116
133
}
@@ -306,6 +323,20 @@ class VectorizedN {
306
323
}); \
307
324
}
308
325
326
+ #define VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL (op ) \
327
+ template <typename T, int N> \
328
+ inline VectorizedN<T, N> op ( \
329
+ const VectorizedN<T, N>& a, \
330
+ const VectorizedN<T, N>& b, \
331
+ const VectorizedN<T, N>& c) { \
332
+ return a.ternary_op ( \
333
+ b, \
334
+ c, \
335
+ [](const Vectorized<T>& a, \
336
+ const Vectorized<T>& b, \
337
+ const Vectorized<T>& c) { return op (a, b, c); }); \
338
+ }
339
+
309
340
#define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL (op ) \
310
341
template <typename T, int N> \
311
342
inline VectorizedN<T, N>& op ( \
@@ -326,9 +357,9 @@ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<)
326
357
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (operator >>)
327
358
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (maximum)
328
359
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (minimum)
329
- VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (fmadd)
330
- VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (fmsub)
331
- VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (clamp)
360
+ VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL (fmadd)
361
+ VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL (fmsub)
362
+ VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL (clamp)
332
363
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (clamp_max)
333
364
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (clamp_min)
334
365
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (operator &)
@@ -357,5 +388,17 @@ inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN<T, N> acc_vec) {
357
388
return vec_reduce_all (vec_fun, vec_result);
358
389
}
359
390
391
+ template <typename T, int N>
392
+ std::ostream& operator <<(std::ostream& stream, const VectorizedN<T, N>& vec_n) {
393
+ stream << " vec_n[" ;
394
+ for (int i = 0 ; i < N; ++i) {
395
+ if (i != 0 ) {
396
+ stream << " , " ;
397
+ }
398
+ stream << vec_n[i];
399
+ }
400
+ stream << ' ]' ;
401
+ return stream;
402
+ }
360
403
} // namespace CPU_CAPABILITY
361
404
} // namespace at::vec
0 commit comments