diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp index 7239f357fdd6..5c4ed5215c1d 100644 --- a/aten/src/ATen/Version.cpp +++ b/aten/src/ATen/Version.cpp @@ -103,9 +103,13 @@ std::string get_cpu_capability() { #elif defined(HAVE_ZVECTOR_CPU_DEFINITION) case native::CPUCapability::ZVECTOR: return "Z VECTOR"; -#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION) +#elif defined(HAVE_SVE_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION) + case native::CPUCapability::SVE128: + return "SVE128"; case native::CPUCapability::SVE256: return "SVE256"; + case native::CPUCapability::SVE512: + return "SVE512"; #else case native::CPUCapability::AVX2: return "AVX2"; diff --git a/aten/src/ATen/cpu/vec/functional_base.h b/aten/src/ATen/cpu/vec/functional_base.h index 112121b29705..bb7033c97c21 100644 --- a/aten/src/ATen/cpu/vec/functional_base.h +++ b/aten/src/ATen/cpu/vec/functional_base.h @@ -102,8 +102,31 @@ struct VecReduceAllSIMD { #endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && // !defined(C10_MOBILE) -#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ - !defined(CPU_CAPABILITY_SVE) +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) +#if defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE512) +template +struct VecReduceAllSIMD { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + // 128-bit shuffle + svuint32_t ind = svdupq_n_u32(4, 5, 6, 7); + Vec v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + // 64-bit shuffle + ind = svdupq_n_u32(2, 3, 0, 1); + v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + // 32-bit shuffle + ind = svdupq_n_u32(1, 0, 2, 3); + v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + return svlasta(svpfalse(), v); + } +}; +#else template struct VecReduceAllSIMD { static inline float apply( @@ -140,35 +163,10 @@ struct VecReduceAllSIMD>> { return vaddvq_f32(acc_vec); } }; +#endif // defined(CPU_CAPABILITY_SVE256) #endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) - // && !defined(CPU_CAPABILITY_SVE) -#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ - defined(CPU_CAPABILITY_SVE256) -template -struct VecReduceAllSIMD { - static inline float apply( - const Op& vec_fun, - const Vectorized& acc_vec) { - using Vec = Vectorized; - Vec v = acc_vec; - // 128-bit shuffle - svuint32_t ind = svdupq_n_u32(4, 5, 6, 7); - Vec v1 = svtbl_f32(v, ind); - v = vec_fun(v, v1); - // 64-bit shuffle - ind = svdupq_n_u32(2, 3, 0, 1); - v1 = svtbl_f32(v, ind); - v = vec_fun(v, v1); - // 32-bit shuffle - ind = svdupq_n_u32(1, 0, 2, 3); - v1 = svtbl_f32(v, ind); - v = vec_fun(v, v1); - return svlasta(svpfalse(), v); - } -}; -#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) - // && defined(CPU_CAPABILITY_SVE256) + template inline scalar_t vec_reduce_all( diff --git a/aten/src/ATen/cpu/vec/sve/sve_helper.h b/aten/src/ATen/cpu/vec/sve/sve_helper.h index f3786019064c..144573fb4047 100644 --- a/aten/src/ATen/cpu/vec/sve/sve_helper.h +++ b/aten/src/ATen/cpu/vec/sve/sve_helper.h @@ -1,10 +1,21 @@ #pragma once +#include +#include + #include #include -#if defined(CPU_CAPABILITY_SVE) +#if defined(__aarch64__) && (defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) || defined(AT_BUILD_ARM_VECSVE_WITH_SLEEF)) +#define SLEEF_STATIC_LIBS +#include +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, 
non_sleef_code) non_sleef_code +#endif + +#if defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE128) // Define the data type of VLS(vector-length specific). typedef svbool_t vls_pred_t diff --git a/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h b/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h index d269e1073959..95733228e682 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h +++ b/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h @@ -2,10 +2,6 @@ #include #include -#include -#include -#include -#include #include namespace at { namespace vec { @@ -19,7 +15,8 @@ namespace vec { // accessed as `at::vec`. inline namespace CPU_CAPABILITY { -#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16) +#if (defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE512)) && \ + defined(__ARM_FEATURE_BF16) template <> struct is_vec_specialized_for : std::bool_constant {}; @@ -39,6 +36,8 @@ class Vectorized { Vectorized(); Vectorized(svbfloat16_t v) : values(v) {} + Vectorized(float val); + Vectorized(double val); Vectorized(int val); Vectorized(BFloat16 val); @@ -586,7 +585,15 @@ Vectorized inline fmadd( return a * b + c; } -#endif // defined(CPU_CAPABILITY_SVE) && defined(__ARM_FEATURE_BF16) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-qualifiers" + +CONVERT_NON_VECTORIZED_INIT(Half, half); +LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16); + +#pragma GCC diagnostic pop + +#else // defined(CPU_CAPABILITY_SVE) && defined(__ARM_FEATURE_BF16) } // namespace CPU_CAPABILITY } // namespace vec diff --git a/aten/src/ATen/cpu/vec/sve/vec_double.h b/aten/src/ATen/cpu/vec/sve/vec_double.h index 474652be17a1..67d714c32882 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_double.h +++ b/aten/src/ATen/cpu/vec/sve/vec_double.h @@ -3,13 +3,8 @@ #include #include #include +#include #include -#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) -#include -#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code -#else -#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code -#endif namespace at::vec { // Note [CPU_CAPABILITY namespace] @@ -178,6 +173,180 @@ class Vectorized { Vectorized conj() const { return *this; } +#ifdef CPU_CAPABILITY_SVE128 + + Vectorized(float64x2_t v) : values(svset_neonq(svundef_f64(), v)) {} + + operator float64x2_t() const { + return svget_neonq(values); + } + + Vectorized acos() const { + return USE_SLEEF( + Vectorized(Sleef_acosd2_u10(svget_neonq(values))), map(std::acos)); + } + Vectorized acosh() const { + return USE_SLEEF( + Vectorized(Sleef_acoshd2_u10(svget_neonq(values))), map(std::acosh)); + } + Vectorized asin() const { + return USE_SLEEF( + Vectorized(Sleef_asind2_u10(svget_neonq(values))), map(std::asin)); + } + Vectorized asinh() const { + return USE_SLEEF( + Vectorized(Sleef_asinhd2_u10(svget_neonq(values))), map(std::asinh)); + } + Vectorized atan() const { + return USE_SLEEF( + Vectorized(Sleef_atand2_u10(svget_neonq(values))), map(std::atan)); + } + Vectorized atanh() const { + return USE_SLEEF( + Vectorized(Sleef_atanhd2_u10(svget_neonq(values))), map(std::atanh)); + } + Vectorized atan2(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_atan2d2_u10(svget_neonq(values), b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::atan2(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} + Vectorized copysign(const Vectorized& sign) const { + USE_SLEEF( + { return Vectorized(Sleef_copysignd2(svget_neonq(values), sign)); 
}, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_sign[size()]; + store(tmp); + sign.store(tmp_sign); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::copysign(tmp[i], tmp_sign[i]); + } + return loadu(tmp); + })} Vectorized erf() const { + return USE_SLEEF( + Vectorized(Sleef_erfd2_u10(svget_neonq(values))), map(std::erf)); + } + Vectorized erfc() const { + return USE_SLEEF( + Vectorized(Sleef_erfcd2_u15(svget_neonq(values))), map(std::erfc)); + } + Vectorized exp() const { + return USE_SLEEF( + Vectorized(Sleef_expd2_u10(svget_neonq(values))), map(std::exp)); + } + Vectorized exp2() const { + return USE_SLEEF( + Vectorized(Sleef_exp2d2_u10(svget_neonq(values))), map(std::exp2)); + } + Vectorized expm1() const { + return USE_SLEEF( + Vectorized(Sleef_expm1d2_u10(svget_neonq(values))), map(std::expm1)); + } + Vectorized fmod(const Vectorized& q) const {USE_SLEEF( + { return Vectorized(Sleef_fmodd2(svget_neonq(values), q)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_q[size()]; + store(tmp); + q.store(tmp_q); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::fmod(tmp[i], tmp_q[i]); + } + return loadu(tmp); + })} + Vectorized hypot(const Vectorized& b) const { + USE_SLEEF( + { return Vectorized(Sleef_hypotd2_u05(svget_neonq(values), b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::hypot(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized i0() const { + return map(calc_i0); + } + Vectorized nextafter(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_nextafterd2(svget_neonq(values), b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::nextafter(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized log() const { + return USE_SLEEF( + Vectorized(Sleef_logd2_u10(svget_neonq(values))), map(std::log)); + } + Vectorized log2() const { + return USE_SLEEF( + Vectorized(Sleef_log2d2_u10(svget_neonq(values))), map(std::log2)); + } + Vectorized log10() const { + return USE_SLEEF( + Vectorized(Sleef_log10d2_u10(svget_neonq(values))), map(std::log10)); + } + Vectorized log1p() const { + return USE_SLEEF( + Vectorized(Sleef_log1pd2_u10(svget_neonq(values))), map(std::log1p)); + } + Vectorized frac() const; + Vectorized sin() const { + return USE_SLEEF( + Vectorized(Sleef_sind2_u10(svget_neonq(values))), map(std::sin)); + } + Vectorized sinh() const { + return USE_SLEEF( + Vectorized(Sleef_sinhd2_u10(svget_neonq(values))), map(std::sinh)); + } + Vectorized cos() const { + return USE_SLEEF( + Vectorized(Sleef_cosd2_u10(svget_neonq(values))), map(std::cos)); + } + Vectorized cosh() const { + return USE_SLEEF( + Vectorized(Sleef_coshd2_u10(svget_neonq(values))), map(std::cosh)); + } + Vectorized pow(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_powd2_u10(svget_neonq(values), b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::pow(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} // Comparison using the _CMP_**_OQ predicate. 
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized tan() const { + return USE_SLEEF( + Vectorized(Sleef_tand2_u10(svget_neonq(values))), map(std::tan)); + } + Vectorized tanh() const { + return USE_SLEEF( + Vectorized(Sleef_tanhd2_u10(svget_neonq(values))), map(std::tanh)); + } + Vectorized lgamma() const { + return USE_SLEEF( + Vectorized(Sleef_lgammad2_u10(svget_neonq(values))), map(std::lgamma)); + } +#else Vectorized acos() const { return USE_SLEEF( Vectorized(Sleef_acosdx_u10sve(values)), map(std::acos)); @@ -213,7 +382,8 @@ class Vectorized { tmp[i] = std::atan2(tmp[i], tmp_b[i]); } return loadu(tmp); - })} Vectorized copysign(const Vectorized& sign) const { + })} + Vectorized copysign(const Vectorized& sign) const { USE_SLEEF( { return Vectorized(Sleef_copysigndx_sve(values, sign)); }, { @@ -233,9 +403,6 @@ class Vectorized { return USE_SLEEF( Vectorized(Sleef_erfcdx_u15sve(values)), map(std::erfc)); } - Vectorized erfinv() const { - return map(calc_erfinv); - } Vectorized exp() const { return USE_SLEEF( Vectorized(Sleef_expdx_u10sve(values)), map(std::exp)); @@ -248,12 +415,6 @@ class Vectorized { return USE_SLEEF( Vectorized(Sleef_expm1dx_u10sve(values)), map(std::expm1)); } - Vectorized exp_u20() const { - return exp(); - } - Vectorized fexp_u20() const { - return exp(); - } Vectorized fmod(const Vectorized& q) const {USE_SLEEF( { return Vectorized(Sleef_fmoddx_sve(values, q)); }, { @@ -265,7 +426,8 @@ class Vectorized { tmp[i] = std::fmod(tmp[i], tmp_q[i]); } return loadu(tmp); - })} Vectorized hypot(const Vectorized& b) const { + })} + Vectorized hypot(const Vectorized& b) const { USE_SLEEF( { return Vectorized(Sleef_hypotdx_u05sve(values, b)); }, { @@ -280,32 +442,6 @@ class Vectorized { })} Vectorized i0() const { return map(calc_i0); } - Vectorized i0e() const { - return map(calc_i0e); - } - Vectorized digamma() const { - return map(calc_digamma); - } - Vectorized igamma(const Vectorized& x) const { - __at_align__ double tmp[size()]; - __at_align__ double tmp_x[size()]; - store(tmp); - x.store(tmp_x); - for (int64_t i = 0; i < size(); i++) { - tmp[i] = calc_igamma(tmp[i], tmp_x[i]); - } - return loadu(tmp); - } - Vectorized igammac(const Vectorized& x) const { - __at_align__ double tmp[size()]; - __at_align__ double tmp_x[size()]; - store(tmp); - x.store(tmp_x); - for (int64_t i = 0; i < size(); i++) { - tmp[i] = calc_igammac(tmp[i], tmp_x[i]); - } - return loadu(tmp); - } Vectorized nextafter(const Vectorized& b) const {USE_SLEEF( { return Vectorized(Sleef_nextafterdx_sve(values, b)); }, { @@ -350,6 +486,68 @@ class Vectorized { return USE_SLEEF( Vectorized(Sleef_coshdx_u10sve(values)), map(std::cosh)); } + Vectorized pow(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_powdx_u10sve(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::pow(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} // Comparison using the _CMP_**_OQ predicate. 
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized tan() const { + return USE_SLEEF( + Vectorized(Sleef_tandx_u10sve(values)), map(std::tan)); + } + Vectorized tanh() const { + return USE_SLEEF( + Vectorized(Sleef_tanhdx_u10sve(values)), map(std::tanh)); + } + Vectorized lgamma() const { + return USE_SLEEF( + Vectorized(Sleef_lgammadx_u10sve(values)), map(std::lgamma)); + } +#endif + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fexp_u20() const { + return exp(); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } Vectorized ceil() const { return svrintp_f64_x(ptrue, values); } @@ -362,21 +560,10 @@ class Vectorized { Vectorized round() const { return svrinti_f64_x(ptrue, values); } - Vectorized tan() const { - return USE_SLEEF( - Vectorized(Sleef_tandx_u10sve(values)), map(std::tan)); - } - Vectorized tanh() const { - return USE_SLEEF( - Vectorized(Sleef_tanhdx_u10sve(values)), map(std::tanh)); - } + Vectorized trunc() const { return svrintz_f64_x(ptrue, values); } - Vectorized lgamma() const { - return USE_SLEEF( - Vectorized(Sleef_lgammadx_u10sve(values)), map(std::lgamma)); - } Vectorized sqrt() const { return svsqrt_f64_x(ptrue, values); } @@ -386,20 +573,7 @@ class Vectorized { Vectorized rsqrt() const { return svdivr_f64_x(ptrue, svsqrt_f64_x(ptrue, values), ONE_F64); } - Vectorized pow(const Vectorized& b) const {USE_SLEEF( - { return Vectorized(Sleef_powdx_u10sve(values, b)); }, - { - __at_align__ double tmp[size()]; - __at_align__ double tmp_b[size()]; - store(tmp); - b.store(tmp_b); - for (int64_t i = 0; i < size(); i++) { - tmp[i] = std::pow(tmp[i], tmp_b[i]); - } - return loadu(tmp); - })} // Comparison using the _CMP_**_OQ predicate. - // `O`: get false if an operand is NaN - // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { svbool_t mask = svcmpeq_f64(ptrue, values, other); return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); diff --git a/aten/src/ATen/cpu/vec/sve/vec_float.h b/aten/src/ATen/cpu/vec/sve/vec_float.h index 89bce507c484..1559f0a00f74 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_float.h +++ b/aten/src/ATen/cpu/vec/sve/vec_float.h @@ -3,13 +3,8 @@ #include #include #include +#include #include -#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) -#include -#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code -#else -#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code -#endif namespace at::vec { // Note [CPU_CAPABILITY namespace] @@ -22,7 +17,7 @@ namespace at::vec { // accessed as `at::vec`. 
inline namespace CPU_CAPABILITY { -#if defined(CPU_CAPABILITY_SVE) +#if defined(CPU_CAPABILITY_SVE) && !defined(CPU_CAPABILITY_SVE128) template <> struct is_vec_specialized_for : std::bool_constant {}; diff --git a/aten/src/ATen/cpu/vec/sve/vec_int.h b/aten/src/ATen/cpu/vec/sve/vec_int.h index f0bc42caa950..b246c4d59377 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_int.h +++ b/aten/src/ATen/cpu/vec/sve/vec_int.h @@ -3,6 +3,8 @@ #include #include #include +#include +#include namespace at::vec { // Note [CPU_CAPABILITY namespace] @@ -312,8 +314,9 @@ inline void convert(const int32_t* src, int64_t* dst, int64_t n) { svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); #pragma unroll - for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i))); + } #pragma unroll for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { pg_32 = svwhilelt_b32(i, n); @@ -345,6 +348,8 @@ inline void convert(const int64_t* src, float* dst, int64_t n) { } } +#ifndef CPU_CAPABILITY_SVE128 + template <> inline void convert(const int32_t* src, float* dst, int64_t n) { const int64_t fraction = n % Vectorized::size(); @@ -362,6 +367,8 @@ inline void convert(const int32_t* src, float* dst, int64_t n) { } } +#endif + template <> inline void convert(const bool* src, int64_t* dst, int64_t n) { const int64_t fraction = n % Vectorized::size(); diff --git a/aten/src/ATen/cpu/vec/sve/vec_qint.h b/aten/src/ATen/cpu/vec/sve/vec_qint.h index 61cb63cb1e12..575ad7e0105b 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_qint.h +++ b/aten/src/ATen/cpu/vec/sve/vec_qint.h @@ -6,11 +6,13 @@ #include #include #include +#include #include #include #include #include +#include // This file defines Vectorized<> for the quantized types. // diff --git a/aten/src/ATen/cpu/vec/vec.h b/aten/src/ATen/cpu/vec/vec.h index 0bfe65cd1959..aeb7cf029177 100644 --- a/aten/src/ATen/cpu/vec/vec.h +++ b/aten/src/ATen/cpu/vec/vec.h @@ -1,6 +1,8 @@ #pragma once -#if defined(CPU_CAPABILITY_AVX512) +#if defined(__aarch64__) +#include +#elif defined(CPU_CAPABILITY_AVX512) #include #else #include diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h index 02f64af3bb08..e32453480c00 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h @@ -2,6 +2,7 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! 
// See Note [Do not compile initializers with AVX] +#include #include #include #include @@ -262,6 +263,13 @@ class Vectorized : public Vectorized16< c10::bit_cast(val6.x), c10::bit_cast(val7.x)}) {} +#ifdef CPU_CAPABILITY_SVE128 + Vectorized(svbfloat16_t v) : Vectorized16(svget_neonq(v)) {} + operator svbfloat16_t() const { + return svset_neonq(svundef_bf16(), values); + } +#endif + static Vectorized blendv( const Vectorized& a, const Vectorized& b, @@ -374,6 +382,22 @@ class Vectorized : public Vectorized16< Vectorized ge(const Vectorized& other) const; Vectorized lt(const Vectorized& other) const; Vectorized le(const Vectorized& other) const; + +#if defined(CPU_CAPABILITY_SVE) && defined(CPU_CAPABILITY_SVE128) + + template + static Vectorized arange( + BFloat16 base = 0.f, + step_t step = static_cast(1)) { + __at_align__ BFloat16 buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svget_neonq(svld1_bf16(ptrue, reinterpret_cast(buffer))); + } + +#endif // CPU_CAPABILITY_SVE128 + }; // Vectorized inline std::tuple, Vectorized> convert_bfloat16_float( @@ -397,6 +421,24 @@ inline Vectorized convert_float_bfloat16( return Vectorized(at_vcombine_bf16(x1, x2)); } +inline void load_fp32_from_bf16(const BFloat16* data, Vectorized& out) { + __at_align__ float values[Vectorized::size()]; + for (const auto k : c10::irange(Vectorized::size())) { + values[k] = data[k]; + } + out = Vectorized::loadu(values); +} + +inline void load_fp32_from_bf16( + const BFloat16* data, + Vectorized& out1, + Vectorized& out2) { + Vectorized bf16_vec = Vectorized::loadu(data); + auto floats = convert_bfloat16_float(bf16_vec); + out1 = std::get<0>(floats); + out2 = std::get<1>(floats); +} + template Vectorized binary_operator_via_float( Op op, @@ -579,6 +621,12 @@ Vectorized inline fnmsub( return -a * b - c; } +#else // + +CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16) + +LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16) + #endif // !defined(C10_MOBILE) && defined(__aarch64__) } // namespace CPU_CAPABILITY diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_convert.h b/aten/src/ATen/cpu/vec/vec128/vec128_convert.h index 0ad0c892b06c..5b7df3feee34 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_convert.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_convert.h @@ -4,7 +4,7 @@ namespace at::vec { inline namespace CPU_CAPABILITY { -#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) +#if defined(__aarch64__) && (!defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE128)) template struct VecConvert< float, @@ -60,6 +60,6 @@ struct VecConvert { } }; -#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) +#endif // defined(__aarch64__) && (!defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE128)) } // namespace CPU_CAPABILITY } // namespace at::vec diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h index c6c34222c5cf..7fe15f5a1569 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h @@ -4,12 +4,10 @@ // See Note [Do not compile initializers with AVX] #include +#include #include #include - -#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) -#include -#endif +#include // Sleef offers vectorized versions of some transcedentals // such as sin, cos, tan etc.. @@ -35,12 +33,6 @@ inline namespace CPU_CAPABILITY { #error "Big endian is not supported." 
#endif -#if defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) -#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code -#else -#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code -#endif - template struct BlendRegs { static float32x4_t impl( @@ -94,19 +86,25 @@ class Vectorized { operator float32x4_t() const { return values; } +#ifdef CPU_CAPABILITY_SVE128 + Vectorized(svfloat32_t v) : values(svget_neonq(v)) {} + operator svfloat32_t() const { + return svset_neonq(svundef_f32(), values); + } +#endif template static Vectorized blend( const Vectorized& a, const Vectorized& b) { Vectorized vec; - vec.values = BlendRegs < 0, - (mask & 0x01) != 0 > ::impl(a.values, b.values, vec.values); - vec.values = BlendRegs < 1, - (mask & 0x02) != 0 > ::impl(a.values, b.values, vec.values); - vec.values = BlendRegs < 2, - (mask & 0x04) != 0 > ::impl(a.values, b.values, vec.values); - vec.values = BlendRegs < 3, - (mask & 0x08) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = + BlendRegs<0, (mask & 0x01) != 0>::impl(a.values, b.values, vec.values); + vec.values = + BlendRegs<1, (mask & 0x02) != 0>::impl(a.values, b.values, vec.values); + vec.values = + BlendRegs<2, (mask & 0x04) != 0>::impl(a.values, b.values, vec.values); + vec.values = + BlendRegs<3, (mask & 0x08) != 0>::impl(a.values, b.values, vec.values); return vec; } static Vectorized blendv( diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h index ab4a5a89cba7..fb5879c95887 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h @@ -4,7 +4,6 @@ // See Note [Do not compile initializers with AVX] #include -#include #include #include #include @@ -25,7 +24,6 @@ inline namespace CPU_CAPABILITY { // https://bugs.llvm.org/show_bug.cgi?id=45824 // Most likely we will do aarch32 support with inline asm. #if !defined(C10_MOBILE) && defined(__aarch64__) - #ifdef __BIG_ENDIAN__ #error "Big endian is not supported." 
#endif @@ -421,6 +419,24 @@ Vectorized inline operator+( #endif } +inline void load_fp32_from_fp16(const c10::Half* data, Vectorized& out) { + __at_align__ float values[Vectorized::size()]; + for (const auto k : c10::irange(Vectorized::size())) { + values[k] = data[k]; + } + out = Vectorized::loadu(values); +} + +inline void load_fp32_from_fp16( + const c10::Half* data, + Vectorized& out1, + Vectorized& out2) { + Vectorized f16_vec = Vectorized::loadu(data); + auto floats = convert_half_float(f16_vec); + out1 = std::get<0>(floats); + out2 = std::get<1>(floats); +} + template <> Vectorized inline operator-( const Vectorized& a, @@ -656,6 +672,53 @@ Vectorized inline fnmsub( return -a * b - c; #endif } + +#else + +#define CONVERT_NON_VECTORIZED_INIT(type, name) \ + inline std::tuple, Vectorized> \ + convert_##name##_float(const Vectorized& a) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + __at_align__ type arr2[K]; \ + a.store(arr2); \ + convert(arr2, arr, K); \ + return std::make_tuple( \ + Vectorized::loadu(arr), \ + Vectorized::loadu(arr + Vectorized::size())); \ + } \ + inline Vectorized convert_float_##name( \ + const Vectorized& a, const Vectorized& b) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + __at_align__ type arr2[K]; \ + a.store(arr); \ + b.store(arr + Vectorized::size()); \ + convert(arr, arr2, K); \ + return Vectorized::loadu(arr2); \ + } + +#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out) { \ + __at_align__ float values[Vectorized::size()]; \ + for (const auto k : c10::irange(Vectorized::size())) { \ + values[k] = data[k]; \ + } \ + out = Vectorized::loadu(values); \ + } \ + \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out1, Vectorized& out2) { \ + load_fp32_from_##name(data, out1); \ + data += Vectorized::size(); \ + load_fp32_from_##name(data, out2); \ + } + +CONVERT_NON_VECTORIZED_INIT(Half, half) + +LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16) + #endif // !defined(C10_MOBILE) && defined(__aarch64__) } // namespace CPU_CAPABILITY diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h index 50c3cc31a6c4..229d569fa3cd 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256.h @@ -6,18 +6,12 @@ #include #include -#if !( \ - defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || \ - defined(CPU_CAPABILITY_ZVECTOR)) -#if defined(CPU_CAPABILITY_SVE256) -#include -#else -// clang-format off +#if !(defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR)) #include +#include #include #include #include -#endif #if !defined(CPU_CAPABILITY_SVE256) || !defined(__ARM_FEATURE_BF16) #include #endif @@ -31,8 +25,6 @@ // clang-format off #include #include -#include -// clang-format on #endif #include diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h index 1306270de714..21bcb8839413 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h @@ -10,7 +10,7 @@ namespace at::vec { // See Note [CPU_CAPABILITY namespace] inline namespace CPU_CAPABILITY { -#if defined(CPU_CAPABILITY_AVX2) +#if defined(CPU_CAPABILITY_AVX2) && !defined(__aarch64__) template <> struct is_vec_specialized_for : std::bool_constant {}; @@ -268,9 +268,7 @@ LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16) #else // defined(CPU_CAPABILITY_AVX2) -#if !( \ 
- defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ - !defined(CPU_CAPABILITY_SVE256)) +#if !(defined(__aarch64__)) CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16) #endif diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_half.h b/aten/src/ATen/cpu/vec/vec256/vec256_half.h index 3022d265b398..c0eb32d44e42 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_half.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_half.h @@ -268,9 +268,7 @@ LOAD_FP32_VECTORIZED_INIT(Half, fp16) #else // defined(CPU_CAPABILITY_AVX2) -#if !( \ - defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ - !defined(CPU_CAPABILITY_SVE256)) +#if !defined(__aarch64__) CONVERT_NON_VECTORIZED_INIT(Half, half) #endif diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h index dafe444163eb..4b9d41c301a1 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h @@ -5,6 +5,13 @@ #include #include + +#ifdef __aarch64__ +#if defined(CPU_CAPABILITY_SVE128) || !defined(CPU_CAPABILITY_SVE) +#include +#endif +#endif + #include #include @@ -915,7 +922,7 @@ Vectorized inline maximum( return a.maximum(b); } -#elif !defined(CPU_CAPABILITY_SVE256) +#else // NOTE: These are low-performance implementations that we fall back on // if we are not building with AVX2. This may not be an issue, because @@ -1372,12 +1379,18 @@ Vectorized inline maximum( return a.maximum(b); } -#endif // if defined(CPU_CAPABILITY_AVX2) - -#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) +#if defined(__aarch64__) && \ + (defined(CPU_CAPABILITY_SVE128) || !defined(CPU_CAPABILITY_SVE)) std::pair, Vectorized> inline convert_int8_to_float( at::vec::Vectorized src) { + +#ifdef CPU_CAPABILITY_SVE + svint8_t x = src; + auto s8x8 = vget_low_s8(svget_neonq(x)); +#else auto s8x8 = vld1_s8(src.operator const int8_t*()); +#endif + auto s16x8 = vmovl_s8(s8x8); auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8)); @@ -1402,7 +1415,14 @@ std::pair, Vectorized> inline convert_int8_to_float( Vectorized inline convert_int8_half_register_to_float( at::vec::Vectorized src) { + +#ifdef CPU_CAPABILITY_SVE + svint8_t x = src; + auto s8x8 = vget_low_s8(svget_neonq(x)); +#else auto s8x8 = vld1_s8(src.operator const int8_t*()); +#endif + auto s16x8 = vmovl_s8(s8x8); auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8)); @@ -1420,5 +1440,8 @@ Vectorized inline convert_int8_half_register_to_float( } #endif + +#endif // if defined(CPU_CAPABILITY_AVX2) + } // namespace CPU_CAPABILITY } // namespace at::vec diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index bfecfa3f933a..792f3a950018 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -57,7 +57,7 @@ Windows llvm will not have this definition. #endif // These macros helped us unify vec_base.h -#ifdef CPU_CAPABILITY_AVX512 +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_SVE512) #if defined(__GNUC__) #define __at_align__ __attribute__((aligned(64))) #elif defined(_WIN32) @@ -67,9 +67,18 @@ Windows llvm will not have this definition. #endif #define VECTOR_WIDTH 64 #define int_vector __m512i -#elif defined(__aarch64__) && \ - !defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512 -// SVE code expects 256-vectors; leave that set for SVE? 
+#elif defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_SVE256) +#if defined(__GNUC__) +#define __at_align__ __attribute__((aligned(32))) +#elif defined(_WIN32) +#define __at_align__ __declspec(align(32)) +#else +#define __at_align__ +#endif +#define VECTOR_WIDTH 32 +#define int_vector __m256i +#elif defined(CPU_CAPABILITY_SVE128) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_DEFAULT) +// Define alignment and vector width for SVE128/Default (e.g., NEON) #if defined(__GNUC__) #define __at_align__ __attribute__((aligned(16))) #elif defined(_WIN32) @@ -78,17 +87,17 @@ Windows llvm will not have this definition. #define __at_align__ #endif #define VECTOR_WIDTH 16 -#else // CPU_CAPABILITY_AVX512 +#else +// Fallback: define default alignment and vector width #if defined(__GNUC__) -#define __at_align__ __attribute__((aligned(32))) +#define __at_align__ __attribute__((aligned(16))) // Fallback to 16-byte alignment #elif defined(_WIN32) -#define __at_align__ __declspec(align(32)) +#define __at_align__ __declspec(align(16)) #else #define __at_align__ #endif -#define VECTOR_WIDTH 32 -#define int_vector __m256i -#endif // CPU_CAPABILITY_AVX512 +#define VECTOR_WIDTH 16 +#endif namespace at::vec { // See Note [CPU_CAPABILITY namespace] @@ -174,6 +183,7 @@ template template #endif struct Vectorized { + private: __at_align__ T values[VECTOR_WIDTH / sizeof(T)]; @@ -194,7 +204,7 @@ struct Vectorized { template < typename... Args, typename = std::enable_if_t<(sizeof...(Args) == size())>> - Vectorized(Args... vals) : values{vals...} {} + Vectorized(Args... vals) : values{static_cast(vals)...} {} Vectorized(const T (&arr)[kSize]) { std::memcpy(values, arr, sizeof(values)); } diff --git a/aten/src/ATen/cpu/vec/sve/vec_common_sve.h b/aten/src/ATen/cpu/vec/vec_common_aarch64.h similarity index 69% rename from aten/src/ATen/cpu/vec/sve/vec_common_sve.h rename to aten/src/ATen/cpu/vec/vec_common_aarch64.h index 69ed5d061bd8..6043ef663a9d 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_common_sve.h +++ b/aten/src/ATen/cpu/vec/vec_common_aarch64.h @@ -8,13 +8,44 @@ #include #include -#if defined(CPU_CAPABILITY_SVE) -#include +#ifdef CPU_CAPABILITY_SVE128 + +#include + +#include + +#include + +#include + #include +#include + +#elif defined(CPU_CAPABILITY_SVE) // SVE bigger than 128 + #include + +#include + +#include #include -#include -#endif + +#else // NEON + +#include + +#include + +#include + +#include + +#endif // defined(CPU_CAPABILITY_SVE128) + +// ATen/cpu/vec/sve/vec_qint.h is broken +#include + +#include namespace at::vec { // Note [CPU_CAPABILITY namespace] @@ -27,6 +58,33 @@ namespace at::vec { // accessed as `at::vec`. 
inline namespace CPU_CAPABILITY { +inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) { + stream << val.val_; + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) { + stream << static_cast(val.val_); + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) { + stream << static_cast(val.val_); + return stream; +} +template +std::ostream& operator<<(std::ostream& stream, const Vectorized& vec) { + T buf[Vectorized::size()]; + vec.store(buf); + stream << "vec["; + for (int i = 0; i != Vectorized::size(); i++) { + if (i != 0) { + stream << ", "; + } + stream << buf[i]; + } + stream << ']'; + return stream; +} + #if defined(CPU_CAPABILITY_SVE) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -48,12 +106,6 @@ DEFINE_SVE_CAST(int32_t, s32, float, f32) DEFINE_SVE_CAST(int16_t, s16, float, f32) DEFINE_SVE_CAST(float, f32, double, f64) -#ifdef __ARM_FEATURE_BF16 -DEFINE_SVE_CAST(int64_t, s64, c10::BFloat16, bf16) -DEFINE_SVE_CAST(int32_t, s32, c10::BFloat16, bf16) -DEFINE_SVE_CAST(int16_t, s16, c10::BFloat16, bf16) -#endif // __ARM_FEATURE_BF16 - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template @@ -61,9 +113,8 @@ std::enable_if_t< scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized< double>> inline gather(const double* base_addr, const Vectorized& vindex_) { - svint64_t vindex = - svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); - return svld1_gather_s64index_f64(ptrue, base_addr, vindex); + svint64_t offsets = svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)); + return svld1_gather_s64offset_f64(ptrue, base_addr, offsets); } template @@ -71,9 +122,8 @@ std::enable_if_t< scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized< float>> inline gather(const float* base_addr, const Vectorized& vindex_) { - svint32_t vindex = - svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); - return svld1_gather_s32index_f32(ptrue, base_addr, vindex); + svint32_t offsets = svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)); + return svld1_gather_s32offset_f32(ptrue, base_addr, offsets); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -85,12 +135,12 @@ std:: const double* base_addr, const Vectorized& vindex_, const Vectorized& mask_) { - svbool_t mask = + svbool_t valid_mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), ALL_S64_TRUE_MASK); - svint64_t vindex = - svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); - return svsel_f64( - mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src); + svint64_t offsets = svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)); + svfloat64_t gathered = + svld1_gather_s64offset_f64(valid_mask, base_addr, offsets); + return svsel_f64(valid_mask, gathered, src); } template @@ -100,12 +150,12 @@ std:: const float* base_addr, const Vectorized& vindex_, const Vectorized& mask_) { - svbool_t mask = + svbool_t valid_mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), ALL_S32_TRUE_MASK); - svint32_t vindex = - svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); - return svsel_f32( - mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src); + svint32_t offsets = svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)); + svfloat32_t gathered = + svld1_gather_s32offset_f32(valid_mask, base_addr, offsets); + return svsel_f32(valid_mask, gathered, src); 
} // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -173,9 +223,11 @@ std::pair< // group cols crossing lanes: // return {a0, b0, a1, b1, a2, b2, a3, b3} // {a4, b4, a5, b5, a6, b6, a7, b7} - return std::make_pair( - Vectorized(svzip1_bf16(a, b)), - Vectorized(svzip2_bf16(a, b))); + svbfloat16_t aReg = a; + svbfloat16_t bReg = b; + Vectorized c = svzip1_bf16(aReg, bReg); + Vectorized d = svzip2_bf16(aReg, bReg); + return std::make_pair(c, d); } #endif // __ARM_FEATURE_BF16 @@ -224,12 +276,27 @@ std::pair< // swap lanes: // return {a0, a1, a2, a3, a4, a5, a6, a7} // {b0, b1, b2, b3, b4, b5, b6, b7} - return std::make_pair( - Vectorized(svuzp1_bf16((svbfloat16_t)a, (svbfloat16_t)b)), - Vectorized(svuzp2_bf16((svbfloat16_t)a, (svbfloat16_t)b))); + svbfloat16_t aReg = a; + svbfloat16_t bReg = b; + Vectorized c = svuzp1_bf16(aReg, bReg); + Vectorized d = svuzp2_bf16(aReg, bReg); + return std::make_pair(c, d); } #endif // __ARM_FEATURE_BF16 +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#define DEFINE_FLIP_FUNC(type, sve_func) \ + inline Vectorized flip(const Vectorized& v) { \ + return Vectorized(sve_func(v)); \ + } +// Use the macro to define the flip functions +DEFINE_FLIP_FUNC(float, svrev_f32) +DEFINE_FLIP_FUNC(double, svrev_f64) +DEFINE_FLIP_FUNC(int64_t, svrev_s64) +DEFINE_FLIP_FUNC(int32_t, svrev_s32) +DEFINE_FLIP_FUNC(int16_t, svrev_s16) +DEFINE_FLIP_FUNC(int8_t, svrev_s8) + #endif // defined(CPU_CAPABILITY_SVE) } // namespace CPU_CAPABILITY diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index 54fb610722d6..72bfb6037e2e 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -1152,108 +1152,108 @@ void unpack_pivots_cpu_kernel(TensorIterator& iter, const int64_t dim_size, cons } } // anonymous namespace -REGISTER_ARCH_DISPATCH(cholesky_stub, DEFAULT, &cholesky_kernel) -REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel) -REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel) -REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel) -REGISTER_ZVECTOR_DISPATCH(cholesky_stub, &cholesky_kernel) -REGISTER_SVE256_DISPATCH(cholesky_stub, &cholesky_kernel) - -REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl) -REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl) -REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl) -REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl) -REGISTER_ZVECTOR_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl) -REGISTER_SVE256_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl) - -REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel) -REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel) -REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel) -REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel) -REGISTER_ZVECTOR_DISPATCH(linalg_eig_stub, &linalg_eig_kernel) -REGISTER_SVE256_DISPATCH(linalg_eig_stub, &linalg_eig_kernel) - -REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel) -REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel) -REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel) -REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel) -REGISTER_ZVECTOR_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel) -REGISTER_SVE256_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel) - 
-REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel) -REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel) -REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel) -REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel) -REGISTER_ZVECTOR_DISPATCH(geqrf_stub, &geqrf_kernel) -REGISTER_SVE256_DISPATCH(geqrf_stub, &geqrf_kernel) - -REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl) -REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl) -REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl) -REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl) -REGISTER_ZVECTOR_DISPATCH(orgqr_stub, &orgqr_kernel_impl) -REGISTER_SVE256_DISPATCH(orgqr_stub, &orgqr_kernel_impl) - -REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel) -REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel) -REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel) -REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel) -REGISTER_ZVECTOR_DISPATCH(ormqr_stub, &ormqr_kernel) -REGISTER_SVE256_DISPATCH(ormqr_stub, &ormqr_kernel) - -REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel) -REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel) -REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel) -REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel) -REGISTER_ZVECTOR_DISPATCH(lstsq_stub, &lstsq_kernel) -REGISTER_SVE256_DISPATCH(lstsq_stub, &lstsq_kernel) - -REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel) -REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel) -REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel) -REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel) -REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel) -REGISTER_SVE256_DISPATCH(triangular_solve_stub, &triangular_solve_kernel) - -REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel) -REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel) -REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel) -REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel) -REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel) -REGISTER_SVE256_DISPATCH(lu_factor_stub, &lu_factor_kernel) - -REGISTER_ARCH_DISPATCH(ldl_factor_stub, DEFAULT, &ldl_factor_kernel) -REGISTER_AVX512_DISPATCH(ldl_factor_stub, &ldl_factor_kernel) -REGISTER_AVX2_DISPATCH(ldl_factor_stub, &ldl_factor_kernel) -REGISTER_VSX_DISPATCH(ldl_factor_stub, &ldl_factor_kernel) -REGISTER_ZVECTOR_DISPATCH(ldl_factor_stub, &ldl_factor_kernel) -REGISTER_SVE256_DISPATCH(ldl_factor_stub, &ldl_factor_kernel) - -REGISTER_ARCH_DISPATCH(ldl_solve_stub, DEFAULT, &ldl_solve_kernel) -REGISTER_AVX512_DISPATCH(ldl_solve_stub, &ldl_solve_kernel) -REGISTER_AVX2_DISPATCH(ldl_solve_stub, &ldl_solve_kernel) -REGISTER_VSX_DISPATCH(ldl_solve_stub, &ldl_solve_kernel) -REGISTER_ZVECTOR_DISPATCH(ldl_solve_stub, &ldl_solve_kernel) -REGISTER_SVE256_DISPATCH(ldl_solve_stub, &ldl_solve_kernel) - -REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel) -REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel) -REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel) -REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel) -REGISTER_ZVECTOR_DISPATCH(lu_solve_stub, &lu_solve_kernel) -REGISTER_SVE256_DISPATCH(lu_solve_stub, &lu_solve_kernel) - -REGISTER_ARCH_DISPATCH(svd_stub, DEFAULT, &svd_kernel) -REGISTER_AVX512_DISPATCH(svd_stub, &svd_kernel) -REGISTER_AVX2_DISPATCH(svd_stub, &svd_kernel) -REGISTER_VSX_DISPATCH(svd_stub, &svd_kernel) -REGISTER_ZVECTOR_DISPATCH(svd_stub, &svd_kernel) -REGISTER_SVE256_DISPATCH(svd_stub, 
&svd_kernel) - -REGISTER_ARCH_DISPATCH(unpack_pivots_stub, DEFAULT, &unpack_pivots_cpu_kernel) -REGISTER_AVX512_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel) -REGISTER_AVX2_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel) -REGISTER_VSX_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel) -REGISTER_ZVECTOR_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel) -REGISTER_SVE256_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel) +REGISTER_ARCH_DISPATCH(cholesky_stub, DEFAULT, &cholesky_kernel); +REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel); +REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel); +REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel); +REGISTER_ZVECTOR_DISPATCH(cholesky_stub, &cholesky_kernel); +REGISTER_SVE_DISPATCH(cholesky_stub, &cholesky_kernel); + +REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl); +REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); +REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); +REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); +REGISTER_ZVECTOR_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); +REGISTER_SVE_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); + +REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel); +REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); +REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); +REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); +REGISTER_ZVECTOR_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); +REGISTER_SVE_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); + +REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel); +REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); +REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); +REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); +REGISTER_ZVECTOR_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); +REGISTER_SVE_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); + +REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel); +REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel); +REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel); +REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel); +REGISTER_ZVECTOR_DISPATCH(geqrf_stub, &geqrf_kernel); +REGISTER_SVE_DISPATCH(geqrf_stub, &geqrf_kernel); + +REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl); +REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl); +REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl); +REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl); +REGISTER_ZVECTOR_DISPATCH(orgqr_stub, &orgqr_kernel_impl); +REGISTER_SVE_DISPATCH(orgqr_stub, &orgqr_kernel_impl); + +REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel); +REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel); +REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel); +REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel); +REGISTER_ZVECTOR_DISPATCH(ormqr_stub, &ormqr_kernel); +REGISTER_SVE_DISPATCH(ormqr_stub, &ormqr_kernel); + +REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel); +REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel); +REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel); +REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel); +REGISTER_ZVECTOR_DISPATCH(lstsq_stub, &lstsq_kernel); +REGISTER_SVE_DISPATCH(lstsq_stub, &lstsq_kernel); + +REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel); 
+REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); +REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); +REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); +REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); +REGISTER_SVE_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); + +REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel); +REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel); +REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel); +REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel); +REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel); +REGISTER_SVE_DISPATCH(lu_factor_stub, &lu_factor_kernel); + +REGISTER_ARCH_DISPATCH(ldl_factor_stub, DEFAULT, &ldl_factor_kernel); +REGISTER_AVX512_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); +REGISTER_AVX2_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); +REGISTER_VSX_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); +REGISTER_ZVECTOR_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); +REGISTER_SVE_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); + +REGISTER_ARCH_DISPATCH(ldl_solve_stub, DEFAULT, &ldl_solve_kernel); +REGISTER_AVX512_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); +REGISTER_AVX2_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); +REGISTER_VSX_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); +REGISTER_ZVECTOR_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); +REGISTER_SVE_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); + +REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel); +REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel); +REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel); +REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel); +REGISTER_ZVECTOR_DISPATCH(lu_solve_stub, &lu_solve_kernel); +REGISTER_SVE_DISPATCH(lu_solve_stub, &lu_solve_kernel); + +REGISTER_ARCH_DISPATCH(svd_stub, DEFAULT, &svd_kernel); +REGISTER_AVX512_DISPATCH(svd_stub, &svd_kernel); +REGISTER_AVX2_DISPATCH(svd_stub, &svd_kernel); +REGISTER_VSX_DISPATCH(svd_stub, &svd_kernel); +REGISTER_ZVECTOR_DISPATCH(svd_stub, &svd_kernel); +REGISTER_SVE_DISPATCH(svd_stub, &svd_kernel); + +REGISTER_ARCH_DISPATCH(unpack_pivots_stub, DEFAULT, &unpack_pivots_cpu_kernel); +REGISTER_AVX512_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); +REGISTER_AVX2_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); +REGISTER_VSX_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); +REGISTER_ZVECTOR_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); +REGISTER_SVE_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index 12ad01149321..0092a601c818 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -32,26 +32,27 @@ static CPUCapability compute_cpu_capability() { #if defined(HAVE_VSX_CPU_DEFINITION) if (envar == "vsx") { return CPUCapability::VSX; - } + } #elif defined(HAVE_ZVECTOR_CPU_DEFINITION) if (envar == "zvector") { return CPUCapability::ZVECTOR; } #elif defined(HAVE_SVE_CPU_DEFINITION) int sve_vl = cpuinfo_get_max_arm_sve_length(); //Returns maximum SVE VL supported by your HW. 
-#ifdef HAVE_SVE256_CPU_DEFINITION - if (envar == "sve256") { - if (sve_vl == 256) { -#ifdef HAVE_ARM_BF16_CPU_DEFINITION + if (envar == "sve") { + if (sve_vl == 512) { + return CPUCapability::SVE512; + } else if (sve_vl == 256) { if (cpuinfo_has_arm_bf16()) { return CPUCapability::SVE256; } -#endif + } else if (sve_vl == 128) { + return CPUCapability::SVE128; + } else { + TORCH_WARN("SVE capability not available on hardware. Falling back to DEFAULT"); + return CPUCapability::DEFAULT; } - TORCH_WARN("SVE256 capability not available on hardware. Falling back to DEFAULT"); - return CPUCapability::DEFAULT; } -#endif #else #ifdef HAVE_AVX512_CPU_DEFINITION if (envar == "avx512") { @@ -105,6 +106,11 @@ static CPUCapability compute_cpu_capability() { // Return the default CPU capability. return CPUCapability::DEFAULT; } + #ifdef HAVE_SVE512_CPU_DEFINITION + if (sve_vl == 512) { // Check for SVE512 + return CPUCapability::SVE512; + } + #endif #ifdef HAVE_SVE256_CPU_DEFINITION if (sve_vl == 256) { // Check for SVE256 #ifdef HAVE_ARM_BF16_CPU_DEFINITION @@ -113,6 +119,11 @@ static CPUCapability compute_cpu_capability() { #endif } #endif + #ifdef HAVE_SVE128_CPU_DEFINITION + if (sve_vl == 128) { // Check for SVE128 + return CPUCapability::SVE128; + } + #endif // Return the default CPU capability. return CPUCapability::DEFAULT; } @@ -144,9 +155,15 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + , void *SVE512 +#endif #ifdef HAVE_SVE256_CPU_DEFINITION , void *SVE256 #endif +#ifdef HAVE_SVE128_CPU_DEFINITION + , void *SVE128 +#endif ) { constexpr auto supported_devices = c10::array_of( c10::DeviceType::CPU, @@ -182,8 +199,14 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( #ifdef HAVE_ZVECTOR_CPU_DEFINITION , ZVECTOR #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + , SVE512 +#endif #ifdef HAVE_SVE256_CPU_DEFINITION , SVE256 +#endif +#ifdef HAVE_SVE128_CPU_DEFINITION + , SVE128 #endif ); if (!std::holds_alternative(result)) { @@ -239,9 +262,15 @@ void* DispatchStubImpl::get_call_ptr( #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + , void *SVE512 +#endif #ifdef HAVE_SVE256_CPU_DEFINITION , void *SVE256 #endif +#ifdef HAVE_SVE128_CPU_DEFINITION + , void *SVE128 +#endif ) { auto result = try_get_call_ptr( @@ -263,9 +292,17 @@ void* DispatchStubImpl::get_call_ptr( , ZVECTOR #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + , + SVE512 +#endif #ifdef HAVE_SVE256_CPU_DEFINITION , SVE256 +#endif +#ifdef HAVE_SVE128_CPU_DEFINITION + , + SVE128 #endif ); if (std::holds_alternative(result)) { @@ -298,8 +335,14 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl( #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + , void *SVE512 +#endif #ifdef HAVE_SVE256_CPU_DEFINITION , void *SVE256 +#endif +#ifdef HAVE_SVE128_CPU_DEFINITION + , void *SVE128 #endif ){ @@ -333,6 +376,16 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl( return ZVECTOR != nullptr ? DispatchResult(ZVECTOR) : ErrorType::MissingDeviceKernel; } #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::SVE512)) { + if (C10_UNLIKELY(!SVE512)) { + // dispatch to DEFAULT, since the SVE kernel is missing + return DEFAULT != nullptr ? 
DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel; + } else { + return DispatchResult(SVE512); + } + } +#endif #ifdef HAVE_SVE256_CPU_DEFINITION if (capability >= static_cast(CPUCapability::SVE256)) { if (C10_UNLIKELY(!SVE256)) { @@ -342,6 +395,16 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl( return DispatchResult(SVE256); } } +#endif +#ifdef HAVE_SVE128_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::SVE128)) { + if (C10_UNLIKELY(!SVE128)) { + // dispatch to DEFAULT, since the SVE kernel is missing + return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel; + } else { + return DispatchResult(SVE128); + } + } #endif return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel; } @@ -360,9 +423,15 @@ void* DispatchStubImpl::choose_cpu_impl( #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + , void *SVE512 +#endif #ifdef HAVE_SVE256_CPU_DEFINITION , void *SVE256 #endif +#ifdef HAVE_SVE128_CPU_DEFINITION + , void *SVE128 +#endif ) { auto capability = static_cast(get_cpu_capability()); (void)capability; @@ -398,6 +467,17 @@ void* DispatchStubImpl::choose_cpu_impl( return ZVECTOR; } #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::SVE512)) { + if (C10_UNLIKELY(!SVE512)) { + // dispatch to DEFAULT, since the SVE kernel is missing + TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel"); + return DEFAULT; + } else { + return SVE512; + } + } +#endif #ifdef HAVE_SVE256_CPU_DEFINITION if (capability >= static_cast(CPUCapability::SVE256)) { if (C10_UNLIKELY(!SVE256)) { @@ -408,6 +488,17 @@ void* DispatchStubImpl::choose_cpu_impl( return SVE256; } } +#endif +#ifdef HAVE_SVE128_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::SVE128)) { + if (C10_UNLIKELY(!SVE128)) { + // dispatch to DEFAULT, since the SVE kernel is missing + TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel"); + return DEFAULT; + } else { + return SVE128; + } + } #endif TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel"); return DEFAULT; diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index 4c1e403973d1..e7be1c4bc9c7 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -64,8 +64,10 @@ enum class CPUCapability { VSX = 1, #elif defined(HAVE_ZVECTOR_CPU_DEFINITION) ZVECTOR = 1, -#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION) - SVE256 = 1, +#elif (defined(HAVE_SVE128_CPU_DEFINITION) || defined(HAVE_SVE256_CPU_DEFINITION) || defined(HAVE_SVE512_CPU_DEFINITION)) && defined(HAVE_ARM_BF16_CPU_DEFINITION) + SVE128 = 1, + SVE256 = 2, + SVE512 = 3, #else AVX2 = 1, AVX512 = 2, @@ -115,8 +117,14 @@ struct TORCH_API DispatchStubImpl { #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + , void *SVE512 +#endif #ifdef HAVE_SVE256_CPU_DEFINITION , void *SVE256 +#endif +#ifdef HAVE_SVE128_CPU_DEFINITION + , void *SVE128 #endif ); @@ -136,8 +144,14 @@ struct TORCH_API DispatchStubImpl { #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + , void *SVE512 +#endif #ifdef HAVE_SVE256_CPU_DEFINITION , void *SVE256 +#endif +#ifdef HAVE_SVE128_CPU_DEFINITION + , void *SVE128 #endif ); @@ -157,8 +171,14 @@ struct TORCH_API DispatchStubImpl { #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + 
, void *SVE512 +#endif #ifdef HAVE_SVE256_CPU_DEFINITION , void *SVE256 +#endif +#ifdef HAVE_SVE128_CPU_DEFINITION + , void *SVE128 #endif ); @@ -181,8 +201,14 @@ struct TORCH_API DispatchStubImpl { #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + , void *SVE512 +#endif #ifdef HAVE_SVE256_CPU_DEFINITION , void *SVE256 +#endif +#ifdef HAVE_SVE128_CPU_DEFINITION + , void *SVE128 #endif ); @@ -238,8 +264,14 @@ struct DispatchStub { #ifdef HAVE_ZVECTOR_CPU_DEFINITION , reinterpret_cast(ZVECTOR) #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + , reinterpret_cast(SVE512) +#endif #ifdef HAVE_SVE256_CPU_DEFINITION , reinterpret_cast(SVE256) +#endif +#ifdef HAVE_SVE128_CPU_DEFINITION + , reinterpret_cast(SVE128) #endif ) ); @@ -299,8 +331,14 @@ struct DispatchStub { #ifdef HAVE_ZVECTOR_CPU_DEFINITION , reinterpret_cast(ZVECTOR) #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + , reinterpret_cast(SVE512) +#endif #ifdef HAVE_SVE256_CPU_DEFINITION , reinterpret_cast(SVE256) +#endif +#ifdef HAVE_SVE128_CPU_DEFINITION + , reinterpret_cast(SVE128) #endif ); if (std::holds_alternative(result)){ @@ -322,9 +360,15 @@ struct DispatchStub { #ifdef HAVE_ZVECTOR_CPU_DEFINITION static TORCH_API FnPtr ZVECTOR; #endif +#ifdef HAVE_SVE512_CPU_DEFINITION + static TORCH_API FnPtr SVE512; +#endif #ifdef HAVE_SVE256_CPU_DEFINITION static TORCH_API FnPtr SVE256; #endif +#ifdef HAVE_SVE128_CPU_DEFINITION + static TORCH_API FnPtr SVE128; +#endif private: DispatchStubImpl impl; }; @@ -426,12 +470,24 @@ struct RegisterPRIVATEUSE1Dispatch { #define REGISTER_ZVECTOR_DISPATCH(name, fn) #endif +#ifdef HAVE_SVE512_CPU_DEFINITION +#define REGISTER_SVE512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE512, fn) +#else +#define REGISTER_SVE512_DISPATCH(name, fn) +#endif + #ifdef HAVE_SVE256_CPU_DEFINITION #define REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE256, fn) #else #define REGISTER_SVE256_DISPATCH(name, fn) #endif +#ifdef HAVE_SVE128_CPU_DEFINITION +#define REGISTER_SVE128_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE128, fn) +#else +#define REGISTER_SVE128_DISPATCH(name, fn) +#endif + // Macro to register the same kernel for all CPU arch types. This is useful // if a kernel does not benefit from being recompiled across different arch types. #define REGISTER_ALL_CPU_DISPATCH(name, fn) \ @@ -440,7 +496,14 @@ struct RegisterPRIVATEUSE1Dispatch { REGISTER_AVX2_DISPATCH(name, fn) \ REGISTER_VSX_DISPATCH(name, fn) \ REGISTER_ZVECTOR_DISPATCH(name, fn) \ - REGISTER_SVE256_DISPATCH(name, fn) + REGISTER_SVE512_DISPATCH(name, fn) \ + REGISTER_SVE256_DISPATCH(name, fn) \ + REGISTER_SVE128_DISPATCH(name, fn) + +#define REGISTER_SVE_DISPATCH(name, fn) \ + REGISTER_SVE128_DISPATCH(name, fn) \ + REGISTER_SVE256_DISPATCH(name, fn) \ + REGISTER_SVE512_DISPATCH(name, fn) #define REGISTER_NO_CPU_DISPATCH(name) \ REGISTER_ALL_CPU_DISPATCH(name, nullptr) @@ -481,14 +544,18 @@ struct RegisterPRIVATEUSE1Dispatch { #elif defined(CPU_CAPABILITY) // REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches. // ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others. +// ALSO_REGISTER_SVE512_DISPATCH should be used for ensuring SVE512 dispatch, among others. // ALSO_REGISTER_SVE256_DISPATCH should be used for ensuring SVE256 dispatch, among others. +// ALSO_REGISTER_SVE128_DISPATCH should be used for ensuring SVE128 dispatch, among others. 
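// Illustrative usage only -- the stub and kernel names below are hypothetical and
// not part of this patch:
//   ALSO_REGISTER_SVE512_DISPATCH(my_op_stub, &my_op_sve_kernel)
// follows the same pattern as ALSO_REGISTER_AVX512_DISPATCH: inside a translation
// unit compiled with a concrete CPU_CAPABILITY it expands to
// REGISTER_ARCH_DISPATCH(my_op_stub, CPU_CAPABILITY, &my_op_sve_kernel).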
#ifdef CPU_CAPABILITY_AVX512 #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr)) #else #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) #endif #define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) +#define ALSO_REGISTER_SVE512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) #define ALSO_REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) +#define ALSO_REGISTER_SVE128_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) #endif } // namespace at::native diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp index 2b61bcec6a82..eb34f8a742cb 100644 --- a/aten/src/ATen/native/SegmentReduce.cpp +++ b/aten/src/ATen/native/SegmentReduce.cpp @@ -461,23 +461,23 @@ Tensor segment_reduce_kernel( REGISTER_ARCH_DISPATCH( _segment_reduce_lengths_stub, DEFAULT, - &_segment_reduce_lengths_cpu_kernel) -REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel) -REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel) -REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel) -REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel) -REGISTER_SVE256_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel) + &_segment_reduce_lengths_cpu_kernel); +REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); +REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); +REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); +REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); +REGISTER_SVE_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); // offsets dispatches REGISTER_ARCH_DISPATCH( _segment_reduce_offsets_stub, DEFAULT, - &_segment_reduce_offsets_cpu_kernel) -REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel) -REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel) -REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel) -REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel) -REGISTER_SVE256_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel) + &_segment_reduce_offsets_cpu_kernel); +REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); +REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); +REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); +REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); +REGISTER_SVE_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); // Currently some computation is being duplicated across forward and backward. 
// TODO: Cache indices in forward pass to reuse in backward @@ -547,8 +547,8 @@ REGISTER_VSX_DISPATCH( &_segment_reduce_cpu_lengths_backward_kernel) REGISTER_ZVECTOR_DISPATCH( _segment_reduce_lengths_backward_stub, - &_segment_reduce_cpu_lengths_backward_kernel) -REGISTER_SVE256_DISPATCH( + &_segment_reduce_cpu_lengths_backward_kernel); +REGISTER_SVE_DISPATCH( _segment_reduce_lengths_backward_stub, &_segment_reduce_cpu_lengths_backward_kernel) @@ -567,8 +567,8 @@ REGISTER_VSX_DISPATCH( &_segment_reduce_cpu_offsets_backward_kernel) REGISTER_ZVECTOR_DISPATCH( _segment_reduce_offsets_backward_stub, - &_segment_reduce_cpu_offsets_backward_kernel) -REGISTER_SVE256_DISPATCH( + &_segment_reduce_cpu_offsets_backward_kernel); +REGISTER_SVE_DISPATCH( _segment_reduce_offsets_backward_stub, &_segment_reduce_cpu_offsets_backward_kernel) diff --git a/aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp b/aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp index 7126c1f7b5c3..7bfba4d5ec34 100644 --- a/aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp +++ b/aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp index ab3b16c395a3..1afa777ee8f3 100644 --- a/aten/src/ATen/native/cpu/BlasKernel.cpp +++ b/aten/src/ATen/native/cpu/BlasKernel.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -552,7 +553,6 @@ void cpublas_copy_impl(at::ScalarType type, int64_t n, const void *_x, int64_t i }} // namespace cpublas::(anonymous) - REGISTER_DISPATCH(cpublas::gemm_stub, &cpublas::cpublas_gemm_impl) REGISTER_DISPATCH(cpublas::gemm_no_downcast_stub, &cpublas::cpublas_gemm_no_downcast_impl) REGISTER_DISPATCH(cpublas::axpy_stub, &cpublas::cpublas_axpy_impl) diff --git a/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp b/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp index ed5658f5f0f5..80feabe96630 100644 --- a/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp +++ b/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp @@ -74,6 +74,7 @@ float reduce(vec::VectorizedN& x) { } }); const auto [t0, t1] = vec::convert_half_float(x[0]); + return vec::vec_reduce_all( std::plus>(), t0 + t1); @@ -126,6 +127,7 @@ static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, #endif // !defined(__aarch64__) || defined( __ARM_FEATURE_FP16_SCALAR_ARITHMETIC) float reduce(vec::Vectorized x) { + return vec::vec_reduce_all( std::plus>(), x); @@ -212,7 +214,7 @@ std::pair, vec::Vectorized> fmadd( const vec::Vectorized& b, const vec::Vectorized& acc_low, const vec::Vectorized& acc_high) { -#if defined(__ARM_FEATURE_FP16_FML) && !defined(CPU_CAPABILITY_SVE) +#if defined(__aarch64__) && ((defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)) || (defined(CPU_CAPABILITY_SVE128))) return std::make_pair(vfmlalq_low_f16(acc_low, a, b), vfmlalq_high_f16(acc_high, a, b)); #else const auto [a_float_low, a_float_high] = convert_half_float(a); @@ -233,7 +235,7 @@ std::pair, vec::Vectorized> fmadd( // Return a + b_low * c_low + b_high * c_high vec::Vectorized fmadd(vec::Vectorized a, vec::Vectorized b, vec::Vectorized c) { -#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE) +#if defined(__aarch64__) && ((defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)) || (defined(CPU_CAPABILITY_SVE128))) // NOTE: this 
instruction is an optional instruction in ARM v8.2 and // v8.3, but mandatory in v8.4 per // https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLAL--FMLAL2--vector---Floating-point-fused-Multiply-Add-Long-to-accumulator--vector--?lang=en diff --git a/aten/src/ATen/native/mkl/SpectralOps.cpp b/aten/src/ATen/native/mkl/SpectralOps.cpp index 4aa53c5e794b..9fb4e838f841 100644 --- a/aten/src/ATen/native/mkl/SpectralOps.cpp +++ b/aten/src/ATen/native/mkl/SpectralOps.cpp @@ -165,7 +165,7 @@ REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_co REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) REGISTER_ZVECTOR_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) REGISTER_VSX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) -REGISTER_SVE256_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) +REGISTER_SVE_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) // _out variants can be shared between PocketFFT and MKL Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization, diff --git a/aten/src/ATen/native/sparse/FlattenIndicesKernel.cpp b/aten/src/ATen/native/sparse/FlattenIndicesKernel.cpp index 35e3ebaa9f8b..b9b85b30e1cb 100644 --- a/aten/src/ATen/native/sparse/FlattenIndicesKernel.cpp +++ b/aten/src/ATen/native/sparse/FlattenIndicesKernel.cpp @@ -22,11 +22,11 @@ Tensor flatten_indices_cpu_kernel(const Tensor& indices, IntArrayRef size) { } -REGISTER_ARCH_DISPATCH(flatten_indices_stub, DEFAULT, &flatten_indices_cpu_kernel) -REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel) -REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel) -REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel) -REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel) -REGISTER_SVE256_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel) +REGISTER_ARCH_DISPATCH(flatten_indices_stub, DEFAULT, &flatten_indices_cpu_kernel); +REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); +REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); +REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); +REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); +REGISTER_SVE_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp index cf854a84e7da..f8250e59b989 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp @@ -156,24 +156,24 @@ void sparse_mask_projection_out_cpu_kernel( } -REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel) -REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel) -REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel) -REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel) -REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel) -REGISTER_SVE256_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel) - 
-REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel) -REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel) -REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel) -REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel) -REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel) -REGISTER_SVE256_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel) - -REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel) -REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel) -REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel) -REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel) -REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel) -REGISTER_SVE256_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel) +REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel); +REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel); +REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel); +REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel); +REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel); +REGISTER_SVE_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel); + +REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel); +REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); +REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); +REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); +REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); +REGISTER_SVE_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); + +REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel); +REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); +REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); +REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); +REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); +REGISTER_SVE_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); } diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 7aad4309924d..b6fdf60e5555 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -443,12 +443,12 @@ int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Ten return static_cast(backend); } -REGISTER_ARCH_DISPATCH(_fused_sdp_choice_stub, DEFAULT, &_fused_sdp_choice_cpp) -REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) 
-REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) -REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) -REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) -REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) +REGISTER_ARCH_DISPATCH(_fused_sdp_choice_stub, DEFAULT, &_fused_sdp_choice_cpp); +REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); +REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); +REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); +REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); +REGISTER_SVE_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta) int64_t _fused_sdp_choice_meta( diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index b7b756f74ba1..b93ff1f179b6 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -1017,6 +1017,57 @@ namespace { RESOLVE_OVERLOAD(filter_fmadd)); } #endif +#if defined(CPU_CAPABILITY_SVE) + // Test for counting the number of active lanes using svcntw + TYPED_TEST(BitwiseFloatsAdditional, SVE_CountActiveLanes) { + CACHE_ALIGN int actual_vals[1]; + + // Get the number of active lanes in the vector register + int lane_count = svcntw(); // svcntw returns the active lane count + actual_vals[0] = lane_count; + + // Validate that the number of active lanes is a valid number + EXPECT_EQ(actual_vals[0], lane_count) << "svcntw should return a valid number of active lanes"; + } + // Test for loading and storing data using SVE intrinsics + TYPED_TEST(BitwiseFloatsAdditional, SVE_LoadStore) { + using VT = float; + int num_elements = svcntw(); // Get number of lanes dynamically + CACHE_ALIGN VT input[num_elements], output[num_elements]; + + // Initialize input array with data + for (int i = 0; i < num_elements; i++) { + input[i] = static_cast(i); // Fill input with values from 0 to num_elements-1 + output[i] = 0.0f; // Initialize output array to 0 + } + + svbool_t pg = svptrue_b32(); // Predicate enabling all lanes + svfloat32_t vec = svld1(pg, input); // Load data from input array + svst1(pg, output, vec); // Store the data into output array + + // Verify that the data loaded and stored correctly + for (int i = 0; i < num_elements; i++) { + EXPECT_EQ(input[i], output[i]) << "Mismatch at index " << i; + } + } + // Test for manipulating predicate lanes and counting them + TYPED_TEST(BitwiseFloatsAdditional, SVE_PredicateLaneManipulation) { + CACHE_ALIGN int actual_vals[1]; + + svbool_t pg_all = svptrue_b32(); // Create a predicate for all lanes enabled + // Create a predicate for a subset of lanes, here using svwhilelt + // This is used to select lanes where the index is less than the number of active lanes + svbool_t pg_alt = svwhilelt_b32(0, static_cast(svcntw())); // Cast to int + + int active_lanes_all = svcntw(); // Count the active lanes using svcntw + int active_lanes_alt = svcntp_b32(pg_alt, pg_all); // Count active lanes in predicate pg_alt + + actual_vals[0] = active_lanes_alt; // Store the result of active lane count in actual_vals + + // Verify that the active lane count in both predicates is the same + EXPECT_EQ(active_lanes_all, actual_vals[0]) << "Mismatch in active lane count."; + } +#endif // CPU_CAPABILITY_SVE template typename std::enable_if_t<(mask < 0 || mask> 255), void> // 
NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) @@ -2269,7 +2320,6 @@ namespace { #undef TEST_MASK_LOAD #undef TEST_MASK_LOAD_N } -#if !defined(CPU_CAPABILITY_SVE) TYPED_TEST(VecMaskTests, MaskedCheck) { using VT = ValueType; using vec = TypeParam; @@ -2293,8 +2343,6 @@ namespace { #undef TEST_MASK_CHECK_N } -#endif -#if !defined(CPU_CAPABILITY_SVE) TYPED_TEST(VecMaskTests, ToFrom) { using vec = TypeParam; using VT = ValueType; @@ -2320,8 +2368,6 @@ namespace { << "Failure Details:\nTest Seed to reproduce: " << seed; } } -#endif -#if !defined(CPU_CAPABILITY_SVE) TYPED_TEST(VecMaskTests, Cast) { using vec = TypeParam; using src_t = ValueType; @@ -2366,7 +2412,6 @@ namespace { #undef TEST_MASK_CAST #undef TEST_MASK_CAST_N } -#endif #else #error GTEST does not have TYPED_TEST #endif diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index f7206cc34097..c6850f62cea5 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -17,7 +17,7 @@ #include #include -#if defined(CPU_CAPABILITY_AVX512) +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_SVE) #define CACHE_LINE 64 #else #define CACHE_LINE 32 @@ -32,7 +32,7 @@ #define CACHE_ALIGN __declspec(align(CACHE_LINE)) #define not_inline __declspec(noinline) #else -CACHE_ALIGN #define +#define CACHE_ALIGN #define not_inline #endif #if defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER) diff --git a/caffe2/perfkernels/batch_box_cox_Sve128.cc b/caffe2/perfkernels/batch_box_cox_Sve128.cc new file mode 100644 index 000000000000..cc6b8fffdb40 --- /dev/null +++ b/caffe2/perfkernels/batch_box_cox_Sve128.cc @@ -0,0 +1,203 @@ +#if defined(CPU_CAPABILITY_SVE128) && defined(CAFFE2_PERF_WITH_SVE128) +#include +#include +#include + +// Log and exp approximations inspired from ACL implementation + +inline float32x4_t vtaylor_polyq_for_log_f32(float32x4_t x) +{ + const float32x4_t log_tab_1 = vdupq_n_f32(-2.29561495781f); + const float32x4_t log_tab_2 = vdupq_n_f32(-2.47071170807f); + const float32x4_t log_tab_3 = vdupq_n_f32(-5.68692588806f); + const float32x4_t log_tab_4 = vdupq_n_f32(-0.165253549814f); + const float32x4_t log_tab_5 = vdupq_n_f32(5.17591238022f); + const float32x4_t log_tab_6 = vdupq_n_f32(0.844007015228f); + const float32x4_t log_tab_7 = vdupq_n_f32(4.58445882797f); + const float32x4_t log_tab_8 = vdupq_n_f32(0.0141278216615f); + + float32x4_t A = vmlaq_f32(log_tab_1, log_tab_5, x); + float32x4_t B = vmlaq_f32(log_tab_3, log_tab_7, x); + float32x4_t C = vmlaq_f32(log_tab_2, log_tab_6, x); + float32x4_t D = vmlaq_f32(log_tab_4, log_tab_8, x); + float32x4_t x2 = vmulq_f32(x, x); + float32x4_t x4 = vmulq_f32(x2, x2); + float32x4_t res = vmlaq_f32(vmlaq_f32(A, B, x2), vmlaq_f32(C, D, x2), x4); + return res; +} + +inline float32x4_t vlogq_f32(float32x4_t x) +{ + const float32x4_t CONST_LN2 = vdupq_n_f32(0.6931471805f); // ln(2) + + // Extract exponent + int32x4_t m = svget_neonq(svsub_n_s32_x(svptrue_b8(), svset_neonq(svundef_s32(), vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_f32(x), 23))), 127)); + float32x4_t val = vreinterpretq_f32_s32(vsubq_s32(vreinterpretq_s32_f32(x), vshlq_n_s32(m, 23))); + + // Polynomial Approximation + float32x4_t poly = vtaylor_polyq_for_log_f32(val); + + // Reconstruct + poly = vmlaq_f32(poly, vcvtq_f32_s32(m), CONST_LN2); + + return poly; +} + +inline float32x4_t vexpq_f32(float32x4_t x) +{ + const auto c1 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3f7ffff6))); + const auto c2 = 
vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3efffedb))); + const auto c3 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3e2aaf33))); + const auto c4 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3d2b9f17))); + const auto c5 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3c072010))); + + const auto shift = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x4b00007f))); // 2^23 + 127 = 0x1.0000fep23f + const auto inv_ln2 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3fb8aa3b))); // 1 / ln(2) = 0x1.715476p+0f + const auto neg_ln2_hi = + vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0xbf317200))); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f + const auto neg_ln2_lo = + vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0xb5bfbe8e))); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f + + const auto inf = svdup_n_f32(std::numeric_limits::infinity()); + const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5) + const auto zero = svdup_n_f32(0.f); + const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125) + + // Range reduction: + // e^x = 2^n * e^r + // where: + // n = floor(x / ln(2)) + // r = x - n * ln(2) + // + // By adding x / ln(2) with 2^23 + 127 (shift): + // * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127 forces decimal part + // of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n) + 127 will occupy + // the whole fraction part of z in FP32 format. + // Subtracting 2^23 + 127 (shift) from z will result in the integer part of x / ln(2) + // (i.e. n) because the decimal part has been pushed out and lost. + // * The addition of 127 makes the FP32 fraction part of z ready to be used as the exponent + // in FP32 format. Left shifting z by 23 bits will result in 2^n. + const auto z = vfmaq_f32(shift, x, inv_ln2); + const auto n = z - shift; + const auto scale = vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n + + // The calculation of n * ln(2) is done using 2 steps to achieve accuracy beyond FP32. + // This outperforms longer Taylor series (3-4 tabs) both in term of accuracy and performance. + const auto r_hi = vfmaq_f32(x, n, neg_ln2_hi); + const auto r = vfmaq_f32(r_hi, n, neg_ln2_lo); + + // Compute the truncated Taylor series of e^r. + // poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5) + const auto r2 = r * r; + + const auto p1 = c1 * r; + const auto p23 = vfmaq_f32(c2, c3, r); + const auto p45 = vfmaq_f32(c4, c5, r); + const auto p2345 = vfmaq_f32(p23, p45, r2); + const auto p12345 = vfmaq_f32(p1, p2345, r2); + + auto poly = svset_neonq(svundef_f32(), vfmaq_f32(scale, p12345, scale)); + + // Handle underflow and overflow. 
+ poly = svsel_f32(svcmplt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), min_input), zero, poly); + poly = svsel_f32(svcmpgt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), max_input), inf, poly); + + return svget_neonq(poly); +} + +// ln(x) = log2(x) * ln(2) +// pow(x, n) = exp(n * ln(x)) +inline float32x4_t compute_batch_box_cox_vec_sve128_float(svfloat32_t lambda1_v, svfloat32_t lambda2_v, svfloat32_t data_v, svfloat32_t k_eps) { + data_v = svset_neonq(svundef_f32(), vaddq_f32(svget_neonq(data_v), svget_neonq(lambda2_v))); + svbool_t predNZ = svcmpne_n_f32(svptrue_b8(), lambda1_v, 0.0); + svbool_t predNan = svcmpuo_f32(svptrue_b8(), data_v, data_v); + lambda1_v = svset_neonq(svundef_f32(), vrecpeq_f32(svget_neonq(lambda1_v))); + predNan = svnot_b_z(svptrue_b8(), predNan); + data_v = svmax_f32_m(predNan, data_v, k_eps); + svfloat32_t lnData = svset_neonq(svundef_f32(), vlogq_f32(svget_neonq(data_v))); + if (__builtin_expect(svptest_any(predNZ, predNZ), 1)) { + float32x4_t pow = vmulq_f32(svget_neonq(lnData), svget_neonq(lambda1_v)); + pow = vexpq_f32(pow); + float32x4_t fms = svget_neonq(lambda1_v); + fms = vfmsq_f32(fms, pow, fms); + lnData = svsel_f32(predNZ, svset_neonq(svundef_f32(), fms), lnData); + } + return svget_neonq(lnData); +} + +template +void compute_batch_box_cox_vec_sve128( + std::size_t N, + std::size_t D, + const T* data_ptr, + const T* __restrict lambda1_ptr, + const T* __restrict lambda2_ptr, + T* output_ptr); + +template <> +void compute_batch_box_cox_vec_sve128( + std::size_t N, + std::size_t D, + const float* data_ptr, + const float* __restrict lambda1_ptr, + const float* __restrict lambda2_ptr, + float* output_ptr) { + svfloat32_t k_eps = svdup_n_f32(static_cast(1e-6)); + + std::size_t remainder = D % 4; + std::size_t loopBound = D - remainder; + svbool_t remainderPred = svwhilelt_b32_u64(0, remainder); + + for (std::size_t i = 0; i < N; i++) { + for (std::size_t j = 0; __builtin_expect(j != loopBound, 1); j+=4, data_ptr+=4, output_ptr+=4) { + svfloat32_t lambda1_v = svset_neonq(svundef_f32(), vld1q_f32(lambda1_ptr + j)); + svfloat32_t lambda2_v = svset_neonq(svundef_f32(), vld1q_f32(lambda2_ptr + j)); + svfloat32_t data_v = svset_neonq(svundef_f32(), vld1q_f32(data_ptr)); + float32x4_t result = compute_batch_box_cox_vec_sve128_float(lambda1_v, lambda2_v, data_v, k_eps); + vst1q_f32(output_ptr, result); + } + if (__builtin_expect(remainder > 0, 1)) { + svfloat32_t lambda1_v = svld1_f32(remainderPred, lambda1_ptr + loopBound); + svfloat32_t lambda2_v = svld1_f32(remainderPred, lambda2_ptr + loopBound); + svfloat32_t data_v = svld1_f32(remainderPred, data_ptr); + float32x4_t result = compute_batch_box_cox_vec_sve128_float(lambda1_v, lambda2_v, data_v, k_eps); + svst1_f32(remainderPred, output_ptr, svset_neonq(svundef_f32(), result)); + data_ptr += remainder; + output_ptr += remainder; + } + } +} + +namespace caffe2::details { + +template +void compute_batch_box_cox__sve128( + std::size_t N, + std::size_t D, + const T* self_data, + const T* __restrict lambda1_data, + const T* __restrict lambda2_data, + T* output_data) { + compute_batch_box_cox_vec_sve128( + N, + D, + self_data, + lambda1_data, + lambda2_data, + output_data); +} + +// Vectorized version specializations for float and double +template +void compute_batch_box_cox__sve128( + std::size_t N, + std::size_t D, + const float* self_data, + const float* __restrict lambda1_data, + const float* __restrict lambda2_data, + float* output_data); + +} // namespace caffe2::detail + + +#endif // CAFFE2_PERF_USE_MKL 
&& CPU_CAPABILITY_SVE128 diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index e4973c849a18..e31e97851a6a 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -1,4 +1,4 @@ -# This ill-named file does a number of things: + # This ill-named file does a number of things: # - Installs Caffe2 header files (this has nothing to do with code generation) # - Configures caffe2/core/macros.h # - Creates an ATen target for its generated C++ files and adds it @@ -393,14 +393,22 @@ if(INTERN_BUILD_ATEN_OPS) LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${CXX_ZVECTOR_FLAGS}") endif(CXX_ZVECTOR_FOUND) - if(CXX_SVE_FOUND AND CXX_SVE256_FOUND AND CXX_ARM_BF16_FOUND) - list(APPEND CPU_CAPABILITY_NAMES "SVE256") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION -DHAVE_ARM_BF16_CPU_DEFINITION") - if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang") - list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256") - else() - list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256") - endif() + if(CXX_SVE_FOUND AND CXX_ARM_BF16_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION") + + # Define all vector lengths since SVE is detected + list(APPEND CPU_CAPABILITY_NAMES "SVE128" "SVE256" "SVE512") + + foreach(VLEN IN ITEMS 128 256 512) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE${VLEN}_CPU_DEFINITION") + if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang") + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=${VLEN}") + else() + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=${VLEN}") + endif() + endforeach() + + message(STATUS "SVE support enabled for 128, 256, and 512 vector lengths.") endif() list(LENGTH CPU_CAPABILITY_NAMES NUM_CPU_CAPABILITY_NAMES) diff --git a/cmake/Modules/FindARM.cmake b/cmake/Modules/FindARM.cmake index 903025c5c2cf..da7c51b0864c 100644 --- a/cmake/Modules/FindARM.cmake +++ b/cmake/Modules/FindARM.cmake @@ -89,7 +89,7 @@ if(NOT CORTEXA9_FOUND) endif(NOT CORTEXA9_FOUND) mark_as_advanced(NEON_FOUND) -#SVE support is available is only for Linux OS. +# SVE support is available only for Linux OS. 
IF(CMAKE_SYSTEM_NAME MATCHES "Linux") # Include necessary modules for checking C and C++ source compilations INCLUDE(CheckCSourceCompiles) @@ -117,7 +117,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux") ") # Macro to check for SVE instruction support - MACRO(CHECK_COMPILES lang type flags code) + MACRO(CHECK_COMPILES lang flags code) # Save the current state of required flags SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) @@ -126,44 +126,37 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux") # Check if the source code compiles with the given flags for the specified language (C or C++) IF(lang STREQUAL "CXX") - CHECK_CXX_SOURCE_COMPILES("${code}" ${lang}_HAS_${type}) + CHECK_CXX_SOURCE_COMPILES("${code}" ${lang}_SVE_FOUND) ELSE() - CHECK_C_SOURCE_COMPILES("${code}" ${lang}_HAS_${type}) - ENDIF() - - # If the compilation test is successful, set appropriate variables indicating support - IF(${lang}_HAS_${type}) - set(${lang}_SVE_FOUND TRUE CACHE BOOL "SVE available on host") - SET(${lang}_${type}_FOUND TRUE CACHE BOOL "${lang} ${type} support") - SET(${lang}_${type}_FLAGS "${flags}" CACHE STRING "${lang} ${type} flags") + CHECK_C_SOURCE_COMPILES("${code}" ${lang}_SVE_FOUND) ENDIF() # Restore the original state of required flags SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) - # If the compilation test fails, indicate that the support is not found - IF(NOT ${lang}_${type}_FOUND) - SET(${lang}_${type}_FOUND FALSE CACHE BOOL "${lang} ${type} support") - SET(${lang}_${type}_FLAGS "" CACHE STRING "${lang} ${type} flags") + # Set the result variables + IF(${lang}_SVE_FOUND) + set(${lang}_SVE_FOUND TRUE CACHE BOOL "SVE available on host") + SET(${lang}_SVE_FLAGS "${flags}" CACHE STRING "${lang} SVE flags") + ELSE() + set(${lang}_SVE_FOUND FALSE CACHE BOOL "SVE not available on host") + SET(${lang}_SVE_FLAGS "" CACHE STRING "${lang} SVE flags") ENDIF() - # Mark the variables as advanced to hide them in the default CMake GUI - MARK_AS_ADVANCED(${lang}_${type}_FOUND ${lang}_${type}_FLAGS) + # Mark as advanced to hide the variables in the default CMake GUI + MARK_AS_ADVANCED(${lang}_SVE_FOUND ${lang}_SVE_FLAGS) ENDMACRO() - # Check for SVE256 vector length - CHECK_COMPILES(CXX "SVE256" "-march=armv8.2-a+sve -msve-vector-bits=256" "${SVE_CODE}") + # Check for SVE support using the base flag `-march=armv8-a+sve` + CHECK_COMPILES(CXX "-march=armv8.2-a+sve" "${SVE_CODE}") CHECK_COMPILES(CXX "ARM_BF16" "-march=armv8.2-a+sve+bf16 -msve-vector-bits=256" "${ARM_BF16_CODE}") - # If SVE256 support is not found, set CXX_SVE_FOUND to FALSE and notify the user - if(NOT CXX_SVE256_FOUND) - set(CXX_SVE_FOUND FALSE CACHE BOOL "SVE not available on host") - message(STATUS "No SVE processor on this machine.") - else() - # If SVE256 support is found, set CXX_SVE_FOUND to TRUE and notify the user - set(CXX_SVE_FOUND TRUE CACHE BOOL "SVE available on host") - message(STATUS "SVE support detected.") - endif() + # Assume all vector lengths are supported if SVE is detected + IF(CXX_SVE_FOUND) + message(STATUS "SVE support detected. 
The current toolchain can generate SVE instructions") + ELSE() + message(STATUS "Current toolchain could not be used to generate SVE instructions.") + ENDIF() # Mark the SVE support variable as advanced mark_as_advanced(CXX_SVE_FOUND) diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 13d946863425..22529a54c3cd 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -15,7 +15,9 @@ VecAVX512, VecISA, VecNEON, + VecSVE128, VecSVE256, + VecSVE512, ) from ..utils import IndentedBuffer, parallel_num_threads from ..virtualized import V @@ -222,7 +224,6 @@ def codegen_allocate_weight_buffer( def is_woq_int4(self): return False - @dataclasses.dataclass class CppMicroGemmConfig: input_dtype: torch.dtype @@ -424,6 +425,14 @@ def do_not_use_with_small_m_for_int8_woq(config, m, n, k, alpha, num_threads, ** output_dtype=torch.float, compute_dtype=torch.float, ), + *generate_gemm_config( + VecSVE128, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.float, + input2_dtype=torch.float, + output_dtype=torch.float, + compute_dtype=torch.float, + ), *generate_gemm_config( VecSVE256, [(4, 24, 1), (4, 16, 1), (8, 8, 1)], @@ -432,6 +441,14 @@ def do_not_use_with_small_m_for_int8_woq(config, m, n, k, alpha, num_threads, ** output_dtype=torch.float, compute_dtype=torch.float, ), + *generate_gemm_config( + VecSVE512, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.float, + input2_dtype=torch.float, + output_dtype=torch.float, + compute_dtype=torch.float, + ), ) class CppMicroGemmFP32Vec(CppMicroGemm): """ @@ -1019,6 +1036,7 @@ class CppMicroGemmAMX(CppMicroGemm): // we cache K * {{block_n}} elements of dequantized B {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}} const auto buf_size = K * {{block_n}}; + auto load_dequantized_B = [&](int base_idx) { // Load a tile of B & cache it in L1D. 
{{input2_t}}* base_addr = const_cast<{{input2_t}}*>(B) + base_idx; @@ -1355,7 +1373,6 @@ def get_b_layout(self): assert self.input_dtype == torch.half and torch.cpu._is_amx_fp16_supported() return LayoutType.VNNI2 - def check_woq_int4_extra(config, m, n, k, alpha, num_threads, **kwargs): if alpha != 1: return False diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index b077c4da9c28..b3106332c931 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -160,7 +160,7 @@ def __bool__impl(self, vec_isa_ok) -> bool: @dataclasses.dataclass class VecNEON(VecISA): _bit_width = 128 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h - _macro = ["CPU_CAPABILITY_NEON", "AT_BUILD_ARM_VEC256_WITH_SLEEF"] + _macro = ["CPU_CAPABILITY_NEON", "AT_BUILD_ARM_VECSVE_WITH_SLEEF"] _arch_flags = "" # Unused _dtype_nelements = {torch.float: 4, torch.bfloat16: 8, torch.float16: 8} @@ -172,14 +172,33 @@ def __str__(self) -> str: __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] +class VecSVE128(VecISA): + # SVE with 128-bit width + _bit_width = 128 + _macro = [ + "CPU_CAPABILITY_SVE", + "CPU_CAPABILITY_SVE128", + "AT_BUILD_ARM_VECSVE_WITH_SLEEF", + ] + _arch_flags = "-march=armv9-a+sve2 -msve-vector-bits=128" + _dtype_nelements = {torch.float: 4, torch.bfloat16: 8, torch.float16: 8} + + def __str__(self) -> str: + if config.is_fbcode(): + return "sve128" + return "asimd" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + @dataclasses.dataclass class VecSVE256(VecISA): - # this function can be repurposed for SVE with variable vec length + # SVE with 256-bit width _bit_width = 256 _macro = [ "CPU_CAPABILITY_SVE", "CPU_CAPABILITY_SVE256", - "AT_BUILD_ARM_VEC256_WITH_SLEEF", + "AT_BUILD_ARM_VECSVE_WITH_SLEEF", "__ARM_FEATURE_BF16", ] _arch_flags = "-march=armv8-a+sve+bf16 -msve-vector-bits=256" @@ -187,9 +206,25 @@ class VecSVE256(VecISA): _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} def __str__(self) -> str: - if config.is_fbcode(): - return "neon" - return "asimd" + return "sve256" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +@dataclasses.dataclass +class VecSVE512(VecISA): + # SVE with 512-bit width + _bit_width = 512 + _macro = [ + "CPU_CAPABILITY_SVE", + "CPU_CAPABILITY_SVE512", + "AT_BUILD_ARM_VECSVE_WITH_SLEEF", + ] + _arch_flags = "-march=armv8-a+sve -msve-vector-bits=512" + _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32} + + def __str__(self) -> str: + return "sve512" __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] @@ -344,7 +379,9 @@ def _check_and_append_supported_isa( VecAVX512(), VecAVX2(), VecNEON(), + VecSVE128(), VecSVE256(), + VecSVE512(), ] @@ -406,8 +443,12 @@ def valid_vec_isa_list() -> list[VecISA]: elif arch == "ppc64le": isa_list.append(VecVSX()) elif arch == "aarch64": - if torch.backends.cpu.get_cpu_capability() == "SVE256": + if torch.backends.cpu.get_cpu_capability() == "SVE128": + isa_list.append(VecSVE128()) + elif torch.backends.cpu.get_cpu_capability() == "SVE256": isa_list.append(VecSVE256()) + elif torch.backends.cpu.get_cpu_capability() == "SVE512": + isa_list.append(VecSVE512()) else: isa_list.append(VecNEON()) diff --git a/torch/backends/cpu/__init__.py b/torch/backends/cpu/__init__.py index 82dc52cd4904..9c67a6640dc7 100644 --- a/torch/backends/cpu/__init__.py +++ b/torch/backends/cpu/__init__.py @@ -16,6 +16,8 @@ def 
get_cpu_capability() -> str: - "NO AVX" - "AVX2" - "AVX512" + - "SVE128" - "SVE256" + - "SVE512" """ return torch._C._get_cpu_capability() diff --git a/torch/csrc/inductor/cpp_prefix.h b/torch/csrc/inductor/cpp_prefix.h index f98da60a1049..e5ade0282793 100644 --- a/torch/csrc/inductor/cpp_prefix.h +++ b/torch/csrc/inductor/cpp_prefix.h @@ -34,7 +34,8 @@ #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || \ defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || \ - defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE256) + defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE128) || \ + defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE512) #define INDUCTOR_USE_VECTOR_TYPES() 1 #else #define INDUCTOR_USE_VECTOR_TYPES() 0 diff --git a/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp b/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp index e8d7170fdf1c..6861d2f546f1 100644 --- a/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp +++ b/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp b/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp index f919639f48de..589b36fde7b8 100644 --- a/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp +++ b/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp @@ -1,5 +1,6 @@ // @generated // @lint-ignore-every CLANGTIDY HOWTOEVEN +#include #include #include #include @@ -7,7 +8,6 @@ #include #include #include -#include #include #include #include
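As a sanity check on the selection order these hunks add to DispatchStubImpl::try_choose_cpu_impl and choose_cpu_impl, the standalone sketch below models the same priority: SVE512, then SVE256, then SVE128, then DEFAULT, with a direct fall-back to DEFAULT when the kernel for the detected width was not compiled in. It is an illustration only; the Capability enum, the choose() helper, and the kernel names are invented for this example and are not the DispatchStub API.

#include <cstdio>

// Mirrors the relative ordering added to CPUCapability in DispatchStub.h
// (SVE128 = 1, SVE256 = 2, SVE512 = 3 on SVE builds); names are hypothetical.
enum class Capability { DEFAULT = 0, SVE128 = 1, SVE256 = 2, SVE512 = 3 };

using Kernel = void (*)();

// Same shape as the new blocks in choose_cpu_impl: widths are tried from widest
// to narrowest, and a missing kernel for the detected width falls back to
// DEFAULT rather than to a narrower SVE kernel.
Kernel choose(Capability cpu, Kernel sve512, Kernel sve256, Kernel sve128, Kernel def) {
  if (cpu >= Capability::SVE512) return sve512 ? sve512 : def;
  if (cpu >= Capability::SVE256) return sve256 ? sve256 : def;
  if (cpu >= Capability::SVE128) return sve128 ? sve128 : def;
  return def;
}

void k512() { std::puts("SVE512 kernel"); }
void k256() { std::puts("SVE256 kernel"); }
void k128() { std::puts("SVE128 kernel"); }
void kdef() { std::puts("DEFAULT kernel"); }

int main() {
  choose(Capability::SVE256, k512, k256, k128, kdef)();    // prints "SVE256 kernel"
  choose(Capability::SVE512, nullptr, k256, k128, kdef)(); // prints "DEFAULT kernel"
  choose(Capability::DEFAULT, k512, k256, k128, kdef)();   // prints "DEFAULT kernel"
  return 0;
}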