Skip to content

Commit e24bfba

Browse files
Nicoshev and facebook-github-bot
authored and committed
[Caffe2] Build perfkernels targeting SVE128
Summary: We are now building perfkernels using SVE/NEON enhancements.
Test Plan: Sigrid Predictor canary.
Rollback Plan:
Differential Revision: D78902495
1 parent e10797e commit e24bfba

File tree

9 files changed

+109
-72
lines changed

9 files changed

+109
-72
lines changed

aten/src/ATen/cpu/vec/functional_base.h

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -102,31 +102,59 @@ struct VecReduceAllSIMD<float, Op> {
102102
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) &&
103103
// !defined(C10_MOBILE)
104104

105-
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
106-
!defined(CPU_CAPABILITY_SVE)
105+
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
106+
#if defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE512)
107107
template <typename Op>
108108
struct VecReduceAllSIMD<float, Op> {
109109
static inline float apply(
110110
const Op& vec_fun,
111111
const Vectorized<float>& acc_vec) {
112112
using Vec = Vectorized<float>;
113113
Vec v = acc_vec;
114+
// 128-bit shuffle
115+
svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
116+
Vec v1 = svtbl_f32(v, ind);
117+
v = vec_fun(v, v1);
118+
// 64-bit shuffle
119+
ind = svdupq_n_u32(2, 3, 0, 1);
120+
v1 = svtbl_f32(v, ind);
121+
v = vec_fun(v, v1);
122+
// 32-bit shuffle
123+
ind = svdupq_n_u32(1, 0, 2, 3);
124+
v1 = svtbl_f32(v, ind);
125+
v = vec_fun(v, v1);
126+
return svlasta(svpfalse(), v);
127+
}
128+
};
129+
#else
130+
template <typename Op>
131+
struct VecReduceAllSIMD<float, Op> {
132+
static inline float apply(
133+
const Op& vec_fun,
134+
const Vectorized<float>& acc_vec) {
135+
using Vec = Vectorized<float>;
136+
Vec v = acc_vec;
137+
float32x4_t vReg = vld1q_f32(reinterpret_cast<const float*>(acc_vec.as_bytes()));
114138

115139
// 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7,
116140
// a4+a8, a1+a5, a2+a6, -, -, -, -]
117-
float32x4_t v1_1 = vextq_f32(v, v, 2);
118-
Vec v1 = v1_1;
141+
float32x4_t v1_1 = vextq_f32(vReg, vReg, 2);
142+
143+
__at_align__ float v1[4];
144+
vst1q_f32(reinterpret_cast<float*>(&v1), v1_1);
119145
// [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -]
120-
v = vec_fun(v, v1);
146+
at::vec::Vectorized<float> vf1(v1);
147+
v = vec_fun(v, vf1);
121148

122149
// 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -,
123150
// -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -,
124151
// -]
125-
v1_1 = vrev64q_f32(v);
126-
v1 = v1_1;
152+
// Shuffle the partially reduced `v` (not the original `acc_vec`), so the
// second step combines the results of the first 64-bit step; reloading
// `acc_vec` here would leave one input lane out of v[0].
v1_1 = vrev64q_f32(vld1q_f32(reinterpret_cast<const float*>(v.as_bytes())));
153+
vst1q_f32(reinterpret_cast<float*>(&v1), v1_1);
127154
// [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8,
128155
// a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
129-
v = vec_fun(v, v1);
156+
at::vec::Vectorized<float> vf2(v1);
157+
v = vec_fun(v, vf2);
130158

131159
return v[0];
132160
}
@@ -137,38 +165,13 @@ struct VecReduceAllSIMD<float, std::plus<Vectorized<float>>> {
137165
static inline float apply(
138166
const std::plus<Vectorized<float>>& vec_fun,
139167
const Vectorized<float>& acc_vec) {
140-
return vaddvq_f32(acc_vec);
168+
return vaddvq_f32(vld1q_f32(reinterpret_cast<const float*>(acc_vec.as_bytes())));
141169
}
142170
};
171+
#endif // defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE512)
143172
#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
144-
// && !defined(CPU_CAPABILITY_SVE)
145173

146-
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
147-
defined(CPU_CAPABILITY_SVE256)
148-
template <typename Op>
149-
struct VecReduceAllSIMD<float, Op> {
150-
static inline float apply(
151-
const Op& vec_fun,
152-
const Vectorized<float>& acc_vec) {
153-
using Vec = Vectorized<float>;
154-
Vec v = acc_vec;
155-
// 128-bit shuffle
156-
svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
157-
Vec v1 = svtbl_f32(v, ind);
158-
v = vec_fun(v, v1);
159-
// 64-bit shuffle
160-
ind = svdupq_n_u32(2, 3, 0, 1);
161-
v1 = svtbl_f32(v, ind);
162-
v = vec_fun(v, v1);
163-
// 32-bit shuffle
164-
ind = svdupq_n_u32(1, 0, 2, 3);
165-
v1 = svtbl_f32(v, ind);
166-
v = vec_fun(v, v1);
167-
return svlasta(svpfalse(), v);
168-
}
169-
};
170-
#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
171-
// && defined(CPU_CAPABILITY_SVE256)
174+
172175

173176
template <typename scalar_t, typename Op>
174177
inline scalar_t vec_reduce_all(

aten/src/ATen/cpu/vec/sve/vec_common_sve.h

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
#include <ATen/cpu/vec/sve/vec_double.h>
1313
#include <ATen/cpu/vec/sve/vec_float.h>
1414
#include <ATen/cpu/vec/sve/vec_int.h>
15-
#include <ATen/cpu/vec/sve/vec_qint.h>
15+
// The SVE qint implementation is broken, so include the vec256 qint header instead.
16+
#include <ATen/cpu/vec/vec256/vec256_qint.h>
1617
#include <ATen/cpu/vec/sve/vec_bfloat16.h>
1718

1819
namespace at::vec {
@@ -74,12 +75,6 @@ DEFINE_SVE_CAST(int32_t, s32, float, f32)
7475
DEFINE_SVE_CAST(int16_t, s16, float, f32)
7576
DEFINE_SVE_CAST(float, f32, double, f64)
7677

77-
#ifdef __ARM_FEATURE_BF16
78-
DEFINE_SVE_CAST(int64_t, s64, c10::BFloat16, bf16)
79-
DEFINE_SVE_CAST(int32_t, s32, c10::BFloat16, bf16)
80-
DEFINE_SVE_CAST(int16_t, s16, c10::BFloat16, bf16)
81-
#endif // __ARM_FEATURE_BF16
82-
8378
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8479

8580
template<int64_t scale = 1>
@@ -183,9 +178,13 @@ std::pair<
183178
// group cols crossing lanes:
184179
// return {a0, b0, a1, b1, a2, b2, a3, b3}
185180
// {a4, b4, a5, b5, a6, b6, a7, b7}
186-
return std::make_pair(
187-
Vectorized<c10::BFloat16>(svzip1_bf16(a, b)),
188-
Vectorized<c10::BFloat16>(svzip2_bf16(a, b)));
181+
Vectorized<c10::BFloat16> c;
182+
Vectorized<c10::BFloat16> d;
183+
svbfloat16_t aReg = svreinterpret_bf16_u64(a.getSve());
184+
svbfloat16_t bReg = svreinterpret_bf16_u64(b.getSve());
185+
c.setSve(svreinterpret_u64_bf16(svzip1_bf16(aReg, bReg)));
186+
d.setSve(svreinterpret_u64_bf16(svzip2_bf16(aReg, bReg)));
187+
return std::make_pair(c, d);
189188
}
190189
#endif // __ARM_FEATURE_BF16
191190

@@ -234,9 +233,13 @@ std::pair<
234233
// swap lanes:
235234
// return {a0, a1, a2, a3, a4, a5, a6, a7}
236235
// {b0, b1, b2, b3, b4, b5, b6, b7}
237-
return std::make_pair(
238-
Vectorized<c10::BFloat16>(svuzp1_bf16((svbfloat16_t)a, (svbfloat16_t)b)),
239-
Vectorized<c10::BFloat16>(svuzp2_bf16((svbfloat16_t)a, (svbfloat16_t)b)));
236+
Vectorized<c10::BFloat16> c;
237+
Vectorized<c10::BFloat16> d;
238+
svbfloat16_t aReg = svreinterpret_bf16_u64(a.getSve());
239+
svbfloat16_t bReg = svreinterpret_bf16_u64(b.getSve());
240+
c.setSve(svreinterpret_u64_bf16(svuzp1_bf16(aReg, bReg)));
241+
d.setSve(svreinterpret_u64_bf16(svuzp2_bf16(aReg, bReg)));
242+
return std::make_pair(c, d);
240243
}
241244
#endif // __ARM_FEATURE_BF16
242245

aten/src/ATen/cpu/vec/vec.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#if defined(CPU_CAPABILITY_AVX512)
44
#include <ATen/cpu/vec/vec512/vec512.h>
5-
#elif defined(CPU_CAPABILITY_SVE)
5+
#elif defined(__aarch64__)
66
#include <ATen/cpu/vec/sve/vec_common_sve.h>
77
#else
88
#include <ATen/cpu/vec/vec128/vec128.h>

aten/src/ATen/cpu/vec/vec256/vec256_qint.h

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,48 +1375,64 @@ Vectorized<c10::quint8> inline maximum(
13751375
#endif // if defined(CPU_CAPABILITY_AVX2)
13761376

13771377
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
1378-
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
1378+
std::pair<at::vec::Vectorized<float>, at::vec::Vectorized<float>> inline convert_int8_to_float(
13791379
at::vec::Vectorized<int8_t> src) {
13801380
auto s8x8 = vld1_s8(src.operator const int8_t*());
13811381
auto s16x8 = vmovl_s8(s8x8);
13821382

13831383
auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
13841384
auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));
13851385

1386-
return std::make_pair(
1387-
Vectorized<float>(vcvtq_f32_s32(s32x4_lo)),
1388-
Vectorized<float>(vcvtq_f32_s32(s32x4_hi)));
1386+
__at_align__ float a[4];
1387+
__at_align__ float b[4];
1388+
1389+
vst1q_f32(reinterpret_cast<float*>(&a), vcvtq_f32_s32(s32x4_lo));
1390+
vst1q_f32(reinterpret_cast<float*>(&b), vcvtq_f32_s32(s32x4_hi));
1391+
1392+
return std::make_pair(at::vec::Vectorized<float>(a), at::vec::Vectorized<float>(b));
13891393
}
13901394

1391-
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
1395+
std::pair<at::vec::Vectorized<float>, at::vec::Vectorized<float>> inline convert_int8_to_float(
13921396
at::vec::Vectorized<uint8_t> src) {
13931397
auto u8x8 = vld1_u8(src.operator const uint8_t*());
13941398
auto u16x8 = vmovl_u8(u8x8);
13951399
auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
13961400
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
13971401

1398-
return std::make_pair(
1399-
Vectorized<float>(vcvtq_f32_u32(u32x4_lo)),
1400-
Vectorized<float>(vcvtq_f32_u32(u32x4_hi)));
1402+
__at_align__ float a[4];
1403+
__at_align__ float b[4];
1404+
1405+
vst1q_f32(reinterpret_cast<float*>(&a), vcvtq_f32_u32(u32x4_lo));
1406+
vst1q_f32(reinterpret_cast<float*>(&b), vcvtq_f32_u32(u32x4_hi));
1407+
1408+
return std::make_pair(at::vec::Vectorized<float>(a), at::vec::Vectorized<float>(b));
14011409
}
14021410

1403-
Vectorized<float> inline convert_int8_half_register_to_float(
1411+
at::vec::Vectorized<float> inline convert_int8_half_register_to_float(
14041412
at::vec::Vectorized<int8_t> src) {
14051413
auto s8x8 = vld1_s8(src.operator const int8_t*());
14061414
auto s16x8 = vmovl_s8(s8x8);
14071415

14081416
auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));
14091417

1410-
return Vectorized<float>(vcvtq_f32_s32(s32x4_lo));
1418+
__at_align__ float r[4];
1419+
1420+
vst1q_f32(reinterpret_cast<float*>(&r), vcvtq_f32_s32(s32x4_lo));
1421+
1422+
return at::vec::Vectorized<float>(r);
14111423
}
14121424

1413-
Vectorized<float> inline convert_int8_half_register_to_float(
1425+
at::vec::Vectorized<float> inline convert_int8_half_register_to_float(
14141426
at::vec::Vectorized<uint8_t> src) {
14151427
auto u8x8 = vld1_u8(src.operator const uint8_t*());
14161428
auto u16x8 = vmovl_u8(u8x8);
14171429
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
14181430

1419-
return Vectorized<float>(vcvtq_f32_u32(u32x4_lo));
1431+
__at_align__ float r[4];
1432+
1433+
vst1q_f32(reinterpret_cast<float*>(&r), vcvtq_f32_u32(u32x4_lo));
1434+
1435+
return at::vec::Vectorized<float>(r);
14201436
}
14211437

14221438
#endif

aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,7 @@ dot_with_fp32_arith_main_inner_loop_bfdot(
187187
const auto temp_vec2 = vld1q_bf16(
188188
reinterpret_cast<const bfloat16_t*>(
189189
&vec2[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
190-
sum[registerPairIndex] =
191-
vbfdotq_f32(sum[registerPairIndex], temp_vec1, temp_vec2);
190+
sum[registerPairIndex].setNeon(vreinterpretq_p128_f32(vbfdotq_f32(vreinterpretq_f32_p128(sum[registerPairIndex].getNeon()), temp_vec1, temp_vec2)));
192191
}
193192

194193
TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE
@@ -200,7 +199,7 @@ void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
200199
// See NOTE[Intrinsics in bfdot variant] above.
201200
const auto temp_vec1 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec1[idx]));
202201
const auto temp_vec2 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec2[idx]));
203-
*tail_sum = vbfdotq_f32(*tail_sum, temp_vec1, temp_vec2);
202+
tail_sum->setNeon(vreinterpretq_p128_f32(vbfdotq_f32(vreinterpretq_f32_p128(tail_sum->getNeon()), temp_vec1, temp_vec2)));
204203
}
205204

206205
#else
@@ -214,8 +213,14 @@ std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
214213
const vec::Vectorized<c10::Half>& b,
215214
const vec::Vectorized<float>& acc_low,
216215
const vec::Vectorized<float>& acc_high) {
217-
#if defined(__ARM_FEATURE_FP16_FML) && !defined(CPU_CAPABILITY_SVE)
218-
return std::make_pair(vfmlalq_low_f16(acc_low, a, b), vfmlalq_high_f16(acc_high, a, b));
216+
#if defined(__aarch64__) && ((defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)) || (defined(CPU_CAPABILITY_SVE128)))
217+
float16x8_t aReg = vreinterpretq_f16_p128(a.getNeon());
218+
float16x8_t bReg = vreinterpretq_f16_p128(b.getNeon());
219+
vec::Vectorized<float> c;
220+
vec::Vectorized<float> d;
221+
c.setNeon(vreinterpretq_p128_f32(vfmlalq_low_f16(vreinterpretq_f32_p128(acc_low.getNeon()), aReg, bReg)));
222+
d.setNeon(vreinterpretq_p128_f32(vfmlalq_high_f16(vreinterpretq_f32_p128(acc_high.getNeon()), aReg, bReg)));
223+
return std::make_pair(c, d);
219224
#else
220225
const auto [a_float_low, a_float_high] = convert_half_float(a);
221226
const auto [b_float_low, b_float_high] = convert_half_float(b);
@@ -235,13 +240,18 @@ std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
235240

236241
// Return a + b_low * c_low + b_high * c_high
237242
vec::Vectorized<float> fmadd(vec::Vectorized<float> a, vec::Vectorized<Half> b, vec::Vectorized<Half> c) {
238-
#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)
243+
#if defined(__aarch64__) && ((defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)) || (defined(CPU_CAPABILITY_SVE128)))
239244
// NOTE: this instruction is an optional instruction in ARM v8.2 and
240245
// v8.3, but mandatory in v8.4 per
241246
// https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLAL--FMLAL2--vector---Floating-point-fused-Multiply-Add-Long-to-accumulator--vector--?lang=en
242247
// I'm not certain that I have the right feature test macro.
243-
vec::Vectorized<float> first = vfmlalq_low_f16(a, b, c);
244-
return vfmlalq_high_f16(first, b, c);
248+
float32x4_t aReg = vreinterpretq_f32_p128(a.getNeon());
249+
float32x4_t bReg = vreinterpretq_f32_p128(b.getNeon());
250+
float16x8_t cReg = vreinterpretq_f16_p128(c.getNeon());
251+
vec::Vectorized<float> res;
252+
res.setNeon(vreinterpretq_p128_f32(vfmlalq_low_f16(aReg, bReg, cReg)));
253+
res.setNeon(vreinterpretq_p128_f32(vfmlalq_high_f16(vreinterpretq_f32_p128(res.getNeon()), bReg, cReg)));
254+
return res;
245255
#else
246256
const auto [b_float_low, b_float_high] = convert_half_float(b);
247257
const auto [c_float_low, c_float_high] = convert_half_float(c);

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
#if defined(__ARM_NEON__) || defined(__aarch64__)
3535
#include <ATen/quantized/Quantizer.h>
3636
#include <arm_neon.h>
37+
#if defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE128)
38+
#include <ATen/cpu/vec/sve/vec_common_sve.h>
39+
#endif
3740
#endif
3841

3942

aten/src/ATen/native/quantized/cpu/qrelu.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
#include <ATen/ops/relu_native.h>
2727
#endif
2828

29+
#if defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE128)
30+
#include <ATen/cpu/vec/sve/vec_common_sve.h>
31+
#endif
32+
2933
#include <algorithm>
3034

3135
namespace at::native {

torch/csrc/jit/runtime/static/ops.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include <ATen/Parallel.h>
77
#include <ATen/ScalarOps.h>
88
#include <ATen/TensorUtils.h>
9-
#include <ATen/cpu/vec/functional.h>
109
#include <ATen/cpu/vec/vec.h>
1110
#include <ATen/native/Fill.h>
1211
#include <ATen/native/IndexingUtils.h>

torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include <ATen/Parallel.h>
77
#include <ATen/ScalarOps.h>
88
#include <ATen/TensorUtils.h>
9-
#include <ATen/cpu/vec/functional.h>
109
#include <ATen/cpu/vec/vec.h>
1110
#include <ATen/native/EmbeddingBag.h>
1211
#include <ATen/native/Fill.h>

0 commit comments

Comments
 (0)