#if defined(CPU_CAPABILITY_SVE128) && defined(CAFFE2_PERF_WITH_SVE128)
#include <arm_neon.h>
#include <arm_sve.h>
#include <arm_neon_sve_bridge.h>

#include <cmath>   // std::log, std::pow, std::fmax, std::isnan (reference sketches)
#include <cstddef> // std::size_t
#include <limits>  // std::numeric_limits

// Log and exp approximations inspired by the Arm Compute Library (ACL)
// implementation.

// Degree-7 polynomial approximation of ln(x) on [1, 2), evaluated with an
// Estrin-style scheme to expose instruction-level parallelism.
inline float32x4_t vtaylor_polyq_for_log_f32(float32x4_t x)
{
  const float32x4_t log_tab_1 = vdupq_n_f32(-2.29561495781f);
  const float32x4_t log_tab_2 = vdupq_n_f32(-2.47071170807f);
  const float32x4_t log_tab_3 = vdupq_n_f32(-5.68692588806f);
  const float32x4_t log_tab_4 = vdupq_n_f32(-0.165253549814f);
  const float32x4_t log_tab_5 = vdupq_n_f32(5.17591238022f);
  const float32x4_t log_tab_6 = vdupq_n_f32(0.844007015228f);
  const float32x4_t log_tab_7 = vdupq_n_f32(4.58445882797f);
  const float32x4_t log_tab_8 = vdupq_n_f32(0.0141278216615f);

  // vmlaq_f32(a, b, c) computes a + b * c.
  float32x4_t A = vmlaq_f32(log_tab_1, log_tab_5, x);
  float32x4_t B = vmlaq_f32(log_tab_3, log_tab_7, x);
  float32x4_t C = vmlaq_f32(log_tab_2, log_tab_6, x);
  float32x4_t D = vmlaq_f32(log_tab_4, log_tab_8, x);
  float32x4_t x2 = vmulq_f32(x, x);
  float32x4_t x4 = vmulq_f32(x2, x2);
  // res = A + B*x^2 + (C + D*x^2)*x^4
  float32x4_t res = vmlaq_f32(vmlaq_f32(A, B, x2), vmlaq_f32(C, D, x2), x4);
  return res;
}
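
// Scalar equivalent in Horner form (an illustrative sketch for cross-checking
// the Estrin evaluation above, not part of the kernel); c[] holds the
// log_tab_* constants reordered by ascending power. At x = 1 the sum is ~0,
// matching ln(1).
inline float taylor_poly_for_log_ref(float x)
{
  const float c[8] = {-2.29561495781f,   5.17591238022f, -5.68692588806f,
                       4.58445882797f,  -2.47071170807f,  0.844007015228f,
                      -0.165253549814f,  0.0141278216615f};
  float res = c[7];
  for (int i = 6; i >= 0; --i) {
    res = res * x + c[i];
  }
  return res;
}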

inline float32x4_t vlogq_f32(float32x4_t x)
{
  const float32x4_t CONST_LN2 = vdupq_n_f32(0.6931471805f); // ln(2)

  // Extract the unbiased exponent: shift the IEEE-754 bits right by 23 and
  // subtract the bias (127).
  int32x4_t m = svget_neonq(svsub_n_s32_x(
      svptrue_b8(),
      svset_neonq(
          svundef_s32(),
          vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_f32(x), 23))),
      127));
  // Normalize the mantissa into [1, 2) by removing m from the exponent field.
  float32x4_t val = vreinterpretq_f32_s32(
      vsubq_s32(vreinterpretq_s32_f32(x), vshlq_n_s32(m, 23)));

  // Polynomial approximation of ln(val)
  float32x4_t poly = vtaylor_polyq_for_log_f32(val);

  // Reconstruct: ln(x) = ln(val) + m * ln(2)
  poly = vmlaq_f32(poly, vcvtq_f32_s32(m), CONST_LN2);

  return poly;
}
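
// Illustrative accuracy probe (hypothetical helper, not part of the kernel):
// broadcast a scalar through vlogq_f32 and compare lane 0 against std::log.
// Handy when validating the polynomial above on a new toolchain.
inline float vlogq_f32_lane0_error(float v)
{
  float approx = vgetq_lane_f32(vlogq_f32(vdupq_n_f32(v)), 0);
  return approx - std::log(v);
}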

inline float32x4_t vexpq_f32(float32x4_t x)
{
  const auto c1 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3f7ffff6)));
  const auto c2 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3efffedb)));
  const auto c3 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3e2aaf33)));
  const auto c4 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3d2b9f17)));
  const auto c5 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3c072010)));

  const auto shift = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x4b00007f))); // 2^23 + 127 = 0x1.0000fep23f
  const auto inv_ln2 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3fb8aa3b))); // 1 / ln(2) = 0x1.715476p+0f
  const auto neg_ln2_hi =
      vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0xbf317200))); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
  const auto neg_ln2_lo =
      vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0xb5bfbe8e))); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f

  const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
  const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
  const auto zero = svdup_n_f32(0.f);
  const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)

  // Range reduction:
  //   e^x = 2^n * e^r
  // where:
  //   n = floor(x / ln(2))
  //   r = x - n * ln(2)
  //
  // By adding x / ln(2) to 2^23 + 127 (shift):
  //   * As the FP32 fraction part only has 23 bits, the addition of 2^23 + 127 forces the decimal
  //     part of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n) + 127 will
  //     occupy the whole fraction part of z in FP32 format.
  //     Subtracting 2^23 + 127 (shift) from z yields the integer part of x / ln(2)
  //     (i.e. n) because the decimal part has been pushed out and lost.
  //   * The addition of 127 makes the FP32 fraction part of z ready to be used as the exponent
  //     in FP32 format. Left shifting z by 23 bits will result in 2^n.
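  // Worked example (illustrative, values rounded): for x = 1.0f,
  // x / ln(2) ~= 1.4427, so z = 2^23 + 127 + 1.4427 rounds to 2^23 + 128
  // (the FP32 spacing at 2^23 is 1.0). Then n = z - shift = 1.0, and shifting
  // z's bits left by 23 puts 127 + 1 in the exponent field, i.e. scale = 2.
  // Finally r = 1.0 - 1.0 * ln(2) ~= 0.3069 and e^1 = 2 * e^0.3069 ~= 2.7183.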
  const auto z = vfmaq_f32(shift, x, inv_ln2);
  const auto n = z - shift;
  const auto scale = vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n

  // The calculation of n * ln(2) is done in 2 steps to achieve accuracy beyond FP32.
  // This outperforms a longer Taylor series (3-4 terms) in both accuracy and performance.
  const auto r_hi = vfmaq_f32(x, n, neg_ln2_hi);
  const auto r = vfmaq_f32(r_hi, n, neg_ln2_lo);

  // Compute the truncated Taylor series of e^r.
  //   poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
  const auto r2 = r * r;

  const auto p1 = c1 * r;
  const auto p23 = vfmaq_f32(c2, c3, r);
  const auto p45 = vfmaq_f32(c4, c5, r);
  const auto p2345 = vfmaq_f32(p23, p45, r2);
  const auto p12345 = vfmaq_f32(p1, p2345, r2);

  auto poly = svset_neonq(svundef_f32(), vfmaq_f32(scale, p12345, scale));

  // Handle underflow and overflow.
  poly = svsel_f32(svcmplt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), min_input), zero, poly);
  poly = svsel_f32(svcmpgt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), max_input), inf, poly);

  return svget_neonq(poly);
}
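
// Illustrative composition (hypothetical helper, not part of the original
// kernel): the Box-Cox code below relies on the identity
// pow(x, n) = exp(n * ln(x)), which the two approximations above implement
// directly.
inline float32x4_t vpowq_approx_f32(float32x4_t x, float32x4_t n)
{
  return vexpq_f32(vmulq_f32(n, vlogq_f32(x)));
}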

// ln(x) = log2(x) * ln(2)
// pow(x, n) = exp(n * ln(x))
inline float32x4_t compute_batch_box_cox_vec_sve128_float(
    svfloat32_t lambda1_v,
    svfloat32_t lambda2_v,
    svfloat32_t data_v,
    svfloat32_t k_eps) {
  // sum = data + lambda2
  data_v = svset_neonq(
      svundef_f32(), vaddq_f32(svget_neonq(data_v), svget_neonq(lambda2_v)));
  // Lanes with lambda1 != 0 take the power branch; the others take the log branch.
  svbool_t predNZ = svcmpne_n_f32(svptrue_b8(), lambda1_v, 0.0);
  // NaN lanes must propagate unchanged, so detect them before clamping.
  svbool_t predNan = svcmpuo_f32(svptrue_b8(), data_v, data_v);
  // Keep lambda1 for the exponent; build a Newton-refined reciprocal of
  // lambda1 for the final division.
  float32x4_t lambda1 = svget_neonq(lambda1_v);
  float32x4_t recip = vrecpeq_f32(lambda1);
  recip = vmulq_f32(recip, vrecpsq_f32(lambda1, recip)); // one Newton-Raphson step
  predNan = svnot_b_z(svptrue_b8(), predNan);
  // Clamp non-NaN lanes to k_eps so the logarithm sees a positive input.
  data_v = svmax_f32_m(predNan, data_v, k_eps);
  svfloat32_t lnData = svset_neonq(svundef_f32(), vlogq_f32(svget_neonq(data_v)));
  if (__builtin_expect(svptest_any(predNZ, predNZ), 1)) {
    // pow = (data + lambda2)^lambda1 = exp(lambda1 * ln(data + lambda2))
    float32x4_t pow = vmulq_f32(svget_neonq(lnData), lambda1);
    pow = vexpq_f32(pow);
    // Box-Cox: (pow - 1) / lambda1 = pow * recip - recip
    float32x4_t fms = vfmaq_f32(vnegq_f32(recip), pow, recip);
    lnData = svsel_f32(predNZ, svset_neonq(svundef_f32(), fms), lnData);
  }
  return svget_neonq(lnData);
}
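
// Scalar reference for one lane (an illustrative sketch for testing, not part
// of the kernel; the 1e-6f mirrors k_eps below). NaN sums propagate as in the
// vector path; results match the kernel up to the accuracy of the log/exp
// approximations and the refined reciprocal.
inline float box_cox_ref_float(float data, float lambda1, float lambda2) {
  float sum = data + lambda2;
  if (!std::isnan(sum)) {
    sum = std::fmax(sum, 1e-6f); // k_eps
  }
  return lambda1 != 0.0f ? (std::pow(sum, lambda1) - 1.0f) / lambda1
                         : std::log(sum);
}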

template <typename T>
void compute_batch_box_cox_vec_sve128(
    std::size_t N,
    std::size_t D,
    const T* data_ptr,
    const T* __restrict lambda1_ptr,
    const T* __restrict lambda2_ptr,
    T* output_ptr);

template <>
void compute_batch_box_cox_vec_sve128(
    std::size_t N,
    std::size_t D,
    const float* data_ptr,
    const float* __restrict lambda1_ptr,
    const float* __restrict lambda2_ptr,
    float* output_ptr) {
  svfloat32_t k_eps = svdup_n_f32(static_cast<float>(1e-6));

  std::size_t remainder = D % 4;
  std::size_t loopBound = D - remainder;
  svbool_t remainderPred = svwhilelt_b32_u64(0, remainder);

  for (std::size_t i = 0; i < N; i++) {
    // Full 128-bit (4-lane) blocks of each row.
    for (std::size_t j = 0; __builtin_expect(j != loopBound, 1); j += 4, data_ptr += 4, output_ptr += 4) {
      svfloat32_t lambda1_v = svset_neonq(svundef_f32(), vld1q_f32(lambda1_ptr + j));
      svfloat32_t lambda2_v = svset_neonq(svundef_f32(), vld1q_f32(lambda2_ptr + j));
      svfloat32_t data_v = svset_neonq(svundef_f32(), vld1q_f32(data_ptr));
      float32x4_t result = compute_batch_box_cox_vec_sve128_float(lambda1_v, lambda2_v, data_v, k_eps);
      vst1q_f32(output_ptr, result);
    }
    // Predicated tail for the last D % 4 columns.
    if (__builtin_expect(remainder > 0, 1)) {
      svfloat32_t lambda1_v = svld1_f32(remainderPred, lambda1_ptr + loopBound);
      svfloat32_t lambda2_v = svld1_f32(remainderPred, lambda2_ptr + loopBound);
      svfloat32_t data_v = svld1_f32(remainderPred, data_ptr);
      float32x4_t result = compute_batch_box_cox_vec_sve128_float(lambda1_v, lambda2_v, data_v, k_eps);
      svst1_f32(remainderPred, output_ptr, svset_neonq(svundef_f32(), result));
      data_ptr += remainder;
      output_ptr += remainder;
    }
  }
}
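
// Usage sketch (illustrative, hypothetical buffers): data and output are
// N x D row-major; lambda1/lambda2 hold one value per column.
//
//   std::vector<float> data(N * D), out(N * D);
//   std::vector<float> lambda1(D), lambda2(D);
//   compute_batch_box_cox_vec_sve128<float>(
//       N, D, data.data(), lambda1.data(), lambda2.data(), out.data());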

namespace caffe2::details {

template <typename T>
void compute_batch_box_cox__sve128(
    std::size_t N,
    std::size_t D,
    const T* self_data,
    const T* __restrict lambda1_data,
    const T* __restrict lambda2_data,
    T* output_data) {
  compute_batch_box_cox_vec_sve128<T>(
      N,
      D,
      self_data,
      lambda1_data,
      lambda2_data,
      output_data);
}

// Vectorized version specialization for float.
template
void compute_batch_box_cox__sve128<float>(
    std::size_t N,
    std::size_t D,
    const float* self_data,
    const float* __restrict lambda1_data,
    const float* __restrict lambda2_data,
    float* output_data);

} // namespace caffe2::details

#endif // CPU_CAPABILITY_SVE128 && CAFFE2_PERF_WITH_SVE128