
Commit 43d3ff9

Nicoshev authored and facebook-github-bot committed
[Caffe2] Add float batch box cox SVE128 implementation (#159778)
Summary: Pull Request resolved: #159778

We are introducing SVE128 perfkernels. As a first translation, we are implementing float32 batch box cox for SVE128.

Test Plan: Sigrid Predictor canary

Rollback Plan:

Differential Revision: D78994871

Privacy Context Container: L1196524
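For reference, a scalar sketch of the usual Box-Cox definition that this kernel vectorizes (illustrative only, not code from this commit; the kernel below additionally clamps non-NaN inputs to a small epsilon and uses a hardware reciprocal estimate for 1 / lambda1):

#include <cmath>

inline float box_cox_reference(float x, float lambda1, float lambda2) {
  x += lambda2;
  return lambda1 == 0.0f ? std::log(x)
                         : (std::pow(x, lambda1) - 1.0f) / lambda1;
}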
1 parent e684e84 commit 43d3ff9

1 file changed: 203 additions, 0 deletions
@@ -0,0 +1,203 @@
#if defined(CPU_CAPABILITY_SVE128) && defined(CAFFE2_PERF_WITH_SVE128)
#include <arm_neon.h>
#include <arm_sve.h>
#include <arm_neon_sve_bridge.h>

#include <cstddef> // std::size_t
#include <limits>  // std::numeric_limits

// Log and exp approximations inspired by the ACL implementation

inline float32x4_t vtaylor_polyq_for_log_f32(float32x4_t x)
{
  const float32x4_t log_tab_1 = vdupq_n_f32(-2.29561495781f);
  const float32x4_t log_tab_2 = vdupq_n_f32(-2.47071170807f);
  const float32x4_t log_tab_3 = vdupq_n_f32(-5.68692588806f);
  const float32x4_t log_tab_4 = vdupq_n_f32(-0.165253549814f);
  const float32x4_t log_tab_5 = vdupq_n_f32(5.17591238022f);
  const float32x4_t log_tab_6 = vdupq_n_f32(0.844007015228f);
  const float32x4_t log_tab_7 = vdupq_n_f32(4.58445882797f);
  const float32x4_t log_tab_8 = vdupq_n_f32(0.0141278216615f);

  float32x4_t A = vmlaq_f32(log_tab_1, log_tab_5, x);
  float32x4_t B = vmlaq_f32(log_tab_3, log_tab_7, x);
  float32x4_t C = vmlaq_f32(log_tab_2, log_tab_6, x);
  float32x4_t D = vmlaq_f32(log_tab_4, log_tab_8, x);
  float32x4_t x2 = vmulq_f32(x, x);
  float32x4_t x4 = vmulq_f32(x2, x2);
  float32x4_t res = vmlaq_f32(vmlaq_f32(A, B, x2), vmlaq_f32(C, D, x2), x4);
  return res;
}
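Note: the degree-7 polynomial above is evaluated Estrin-style rather than with a Horner chain, which shortens the dependency chain across the fused multiply-adds. A scalar transcription of the same grouping (illustrative only, not part of this commit):

inline float taylor_poly_log_scalar(float x) {
  // Same coefficients and grouping as vtaylor_polyq_for_log_f32.
  float A = -2.29561495781f + 5.17591238022f * x;
  float B = -5.68692588806f + 4.58445882797f * x;
  float C = -2.47071170807f + 0.844007015228f * x;
  float D = -0.165253549814f + 0.0141278216615f * x;
  float x2 = x * x;
  float x4 = x2 * x2;
  return (A + B * x2) + (C + D * x2) * x4;
}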
inline float32x4_t vlogq_f32(float32x4_t x)
{
  const float32x4_t CONST_LN2 = vdupq_n_f32(0.6931471805f); // ln(2)

  // Extract exponent
  int32x4_t m = svget_neonq(svsub_n_s32_x(svptrue_b8(), svset_neonq(svundef_s32(), vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_f32(x), 23))), 127));
  float32x4_t val = vreinterpretq_f32_s32(vsubq_s32(vreinterpretq_s32_f32(x), vshlq_n_s32(m, 23)));

  // Polynomial Approximation
  float32x4_t poly = vtaylor_polyq_for_log_f32(val);

  // Reconstruct
  poly = vmlaq_f32(poly, vcvtq_f32_s32(m), CONST_LN2);

  return poly;
}
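Note: vlogq_f32 splits x into 2^m * v with v in [1, 2), evaluates the polynomial on v, and adds back m * ln(2). A scalar bit-twiddling equivalent, reusing the taylor_poly_log_scalar sketch above (illustrative only, not part of this commit):

#include <cstdint>
#include <cstring>

inline float log_scalar(float x) {
  int32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  // Unbiased exponent m, mirroring the shift-right-by-23 / subtract-127 sequence.
  int32_t m = static_cast<int32_t>(static_cast<uint32_t>(bits) >> 23) - 127;
  // Clear the exponent so the remaining value lands in [1, 2).
  bits -= m << 23;
  float v;
  std::memcpy(&v, &bits, sizeof(v));
  return taylor_poly_log_scalar(v) + static_cast<float>(m) * 0.6931471805f; // + m * ln(2)
}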
inline float32x4_t vexpq_f32(float32x4_t x)
{
  const auto c1 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3f7ffff6)));
  const auto c2 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3efffedb)));
  const auto c3 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3e2aaf33)));
  const auto c4 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3d2b9f17)));
  const auto c5 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3c072010)));

  const auto shift = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x4b00007f))); // 2^23 + 127 = 0x1.0000fep23f
  const auto inv_ln2 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3fb8aa3b))); // 1 / ln(2) = 0x1.715476p+0f
  const auto neg_ln2_hi =
      vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0xbf317200))); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
  const auto neg_ln2_lo =
      vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0xb5bfbe8e))); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f

  const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
  const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
  const auto zero = svdup_n_f32(0.f);
  const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)

  // Range reduction:
  //   e^x = 2^n * e^r
  // where:
  //   n = floor(x / ln(2))
  //   r = x - n * ln(2)
  //
  // By adding x / ln(2) with 2^23 + 127 (shift):
  //   * As the FP32 fraction part only has 23 bits, the addition of 2^23 + 127 forces the decimal
  //     part of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n) + 127 will
  //     occupy the whole fraction part of z in FP32 format.
  //     Subtracting 2^23 + 127 (shift) from z will result in the integer part of x / ln(2)
  //     (i.e. n) because the decimal part has been pushed out and lost.
  //   * The addition of 127 makes the FP32 fraction part of z ready to be used as the exponent
  //     in FP32 format. Left shifting z by 23 bits will result in 2^n.
  const auto z = vfmaq_f32(shift, x, inv_ln2);
  const auto n = z - shift;
  const auto scale = vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n

  // The calculation of n * ln(2) is done using 2 steps to achieve accuracy beyond FP32.
  // This outperforms a longer Taylor series (3-4 terms) both in terms of accuracy and performance.
  const auto r_hi = vfmaq_f32(x, n, neg_ln2_hi);
  const auto r = vfmaq_f32(r_hi, n, neg_ln2_lo);

  // Compute the truncated Taylor series of e^r.
  //   poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
  const auto r2 = r * r;

  const auto p1 = c1 * r;
  const auto p23 = vfmaq_f32(c2, c3, r);
  const auto p45 = vfmaq_f32(c4, c5, r);
  const auto p2345 = vfmaq_f32(p23, p45, r2);
  const auto p12345 = vfmaq_f32(p1, p2345, r2);

  auto poly = svset_neonq(svundef_f32(), vfmaq_f32(scale, p12345, scale));

  // Handle underflow and overflow.
  poly = svsel_f32(svcmplt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), min_input), zero, poly);
  poly = svsel_f32(svcmpgt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), max_input), inf, poly);

  return svget_neonq(poly);
}
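Note: the "shift" constant implements a classic trick: adding 2^23 + 127 to x / ln(2) pushes the fractional bits out of the fp32 mantissa, leaving n + 127, the biased exponent of 2^n, in the low bits. A scalar sketch of just this step (illustrative only, not part of this commit):

#include <cmath>
#include <cstdint>
#include <cstring>

inline float exp_scale_of_pow2_n(float x, float& n) {
  const float shift = 0x1.0000fep23f;   // 2^23 + 127
  const float inv_ln2 = 0x1.715476p+0f; // 1 / ln(2)
  float z = std::fmaf(x, inv_ln2, shift);
  n = z - shift;                        // n is x / ln(2), rounded to an integer
  uint32_t zbits;
  std::memcpy(&zbits, &z, sizeof(zbits));
  uint32_t sbits = zbits << 23;         // n + 127 becomes the biased exponent
  float scale;
  std::memcpy(&scale, &sbits, sizeof(scale));
  return scale;                         // scale == 2^n for in-range x
}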
// ln(x) = log2(x) * ln(2)
// pow(x, n) = exp(n * ln(x))
inline float32x4_t compute_batch_box_cox_vec_sve128_float(
    svfloat32_t lambda1_v,
    svfloat32_t lambda2_v,
    svfloat32_t data_v,
    svfloat32_t k_eps) {
  // Shift the data by lambda2.
  data_v = svset_neonq(svundef_f32(), vaddq_f32(svget_neonq(data_v), svget_neonq(lambda2_v)));
  // predNZ: lanes where lambda1 != 0; predNan: lanes where the shifted data is NaN.
  svbool_t predNZ = svcmpne_n_f32(svptrue_b8(), lambda1_v, 0.0);
  svbool_t predNan = svcmpuo_f32(svptrue_b8(), data_v, data_v);
  // Replace lambda1 with its hardware reciprocal estimate.
  lambda1_v = svset_neonq(svundef_f32(), vrecpeq_f32(svget_neonq(lambda1_v)));
  // Clamp non-NaN lanes to at least k_eps before taking the log.
  predNan = svnot_b_z(svptrue_b8(), predNan);
  data_v = svmax_f32_m(predNan, data_v, k_eps);
  svfloat32_t lnData = svset_neonq(svundef_f32(), vlogq_f32(svget_neonq(data_v)));
  if (__builtin_expect(svptest_any(predNZ, predNZ), 1)) {
    float32x4_t pow = vmulq_f32(svget_neonq(lnData), svget_neonq(lambda1_v));
    pow = vexpq_f32(pow);
    float32x4_t fms = svget_neonq(lambda1_v);
    fms = vfmsq_f32(fms, pow, fms);
    // Lanes with lambda1 != 0 take the pow-based result; the rest keep ln(data).
    lnData = svsel_f32(predNZ, svset_neonq(svundef_f32(), fms), lnData);
  }
  return svget_neonq(lnData);
}
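Note: a per-lane scalar transcription of the function above may make the predication easier to follow (illustrative only, not part of this commit; vrecpeq_f32 is a low-precision reciprocal estimate, approximated here with an exact division):

#include <algorithm>
#include <cmath>

inline float box_cox_lane(float data, float lambda1, float lambda2, float eps) {
  data += lambda2;
  float recip = 1.0f / lambda1;      // stands in for vrecpeq_f32
  if (!std::isnan(data)) {
    data = std::max(data, eps);      // svmax_f32_m under the non-NaN predicate
  }
  float ln = std::log(data);         // vlogq_f32
  if (lambda1 != 0.0f) {
    float p = std::exp(ln * recip);  // vexpq_f32 applied to lnData * recip
    return recip - p * recip;        // vfmsq_f32(recip, p, recip)
  }
  return ln;                         // lambda1 == 0 lanes keep ln(data + lambda2)
}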
template <typename T>
void compute_batch_box_cox_vec_sve128(
    std::size_t N,
    std::size_t D,
    const T* data_ptr,
    const T* __restrict lambda1_ptr,
    const T* __restrict lambda2_ptr,
    T* output_ptr);

template <>
void compute_batch_box_cox_vec_sve128(
    std::size_t N,
    std::size_t D,
    const float* data_ptr,
    const float* __restrict lambda1_ptr,
    const float* __restrict lambda2_ptr,
    float* output_ptr) {
  svfloat32_t k_eps = svdup_n_f32(static_cast<float>(1e-6));

  std::size_t remainder = D % 4;
  std::size_t loopBound = D - remainder;
  svbool_t remainderPred = svwhilelt_b32_u64(0, remainder);

  for (std::size_t i = 0; i < N; i++) {
    // Main loop: full 128-bit (4-float) vectors via NEON loads/stores.
    for (std::size_t j = 0; __builtin_expect(j != loopBound, 1); j += 4, data_ptr += 4, output_ptr += 4) {
      svfloat32_t lambda1_v = svset_neonq(svundef_f32(), vld1q_f32(lambda1_ptr + j));
      svfloat32_t lambda2_v = svset_neonq(svundef_f32(), vld1q_f32(lambda2_ptr + j));
      svfloat32_t data_v = svset_neonq(svundef_f32(), vld1q_f32(data_ptr));
      float32x4_t result = compute_batch_box_cox_vec_sve128_float(lambda1_v, lambda2_v, data_v, k_eps);
      vst1q_f32(output_ptr, result);
    }
    // Tail: predicated SVE loads/stores for the last D % 4 elements.
    if (__builtin_expect(remainder > 0, 1)) {
      svfloat32_t lambda1_v = svld1_f32(remainderPred, lambda1_ptr + loopBound);
      svfloat32_t lambda2_v = svld1_f32(remainderPred, lambda2_ptr + loopBound);
      svfloat32_t data_v = svld1_f32(remainderPred, data_ptr);
      float32x4_t result = compute_batch_box_cox_vec_sve128_float(lambda1_v, lambda2_v, data_v, k_eps);
      svst1_f32(remainderPred, output_ptr, svset_neonq(svundef_f32(), result));
      data_ptr += remainder;
      output_ptr += remainder;
    }
  }
}
namespace caffe2::details {

template <typename T>
void compute_batch_box_cox__sve128(
    std::size_t N,
    std::size_t D,
    const T* self_data,
    const T* __restrict lambda1_data,
    const T* __restrict lambda2_data,
    T* output_data) {
  compute_batch_box_cox_vec_sve128<T>(
      N,
      D,
      self_data,
      lambda1_data,
      lambda2_data,
      output_data);
}

// Vectorized version specialization for float
template
void compute_batch_box_cox__sve128<float>(
    std::size_t N,
    std::size_t D,
    const float* self_data,
    const float* __restrict lambda1_data,
    const float* __restrict lambda2_data,
    float* output_data);

} // namespace caffe2::details

#endif // defined(CPU_CAPABILITY_SVE128) && defined(CAFFE2_PERF_WITH_SVE128)
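For context, a minimal calling sketch (hypothetical buffers and sizes, not part of this commit; requires building the translation unit with CPU_CAPABILITY_SVE128 and CAFFE2_PERF_WITH_SVE128 defined):

#include <cstddef>
#include <vector>

int main() {
  const std::size_t N = 2, D = 6; // 2 rows, 6 features; D % 4 != 0 exercises the tail path
  std::vector<float> data(N * D, 2.0f);
  std::vector<float> lambda1(D, 0.5f); // per-feature lambda1 (0 selects the log branch)
  std::vector<float> lambda2(D, 0.0f); // per-feature shift
  std::vector<float> out(N * D);
  caffe2::details::compute_batch_box_cox__sve128<float>(
      N, D, data.data(), lambda1.data(), lambda2.data(), out.data());
  return 0;
}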
