Skip to content

Commit 275a703

Browse files
committed
[PyTorch] Use 128-bit vectors for ARM64
Pull Request resolved: #137426 The correct vector length for ARM64 is 128 bits (16 bytes). We were previously using double this, apparently just because that would be the same length as AVX2. ghstack-source-id: 249869364 Differential Revision: [D63984039](https://our.internmc.facebook.com/intern/diff/D63984039/)
1 parent 56bffca commit 275a703

19 files changed

+844
-1313
lines changed

aten/src/ATen/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ if(NOT BUILD_LITE_INTERPRETER)
5454
endif()
5555
EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS})
5656

57-
file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/sve/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h")
57+
file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec128/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/sve/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h")
5858
file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp" "functorch/*.cpp")
5959
file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh" "cuda/tunable/*.cuh" "cuda/tunable/*.h")
6060
file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp" "cuda/tunable/*.cpp")

aten/src/ATen/cpu/vec/functional_base.h

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,24 +85,19 @@ struct VecReduceAllSIMD<float, Op> {
8585
using Vec = Vectorized<float>;
8686
Vec v = acc_vec;
8787

88-
// 128-bit shuffle: [a1, a2, a3, a4, a5, a6, a7, a8] -> [a5, a6, a7, a8, a1, a2, a3, a4]
89-
Vec v1 = {v.get_high(), v.get_low()};
90-
// [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] ('+' stands for the reduction function. Note that the last 4 elements are not required)
91-
v = vec_fun(v, v1);
92-
9388
// 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, a4+a8, a1+a5, a2+a6, -, -, -, -]
94-
float32x4_t v1_1 = vextq_f32(v.get_low(), v.get_low(), 2);
95-
v1 = {v1_1, v1_1};
89+
float32x4_t v1_1 = vextq_f32(v, v, 2);
90+
Vec v1 = v1_1;
9691
// [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -]
9792
v = vec_fun(v, v1);
9893

9994
// 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, -]
100-
v1_1 = vrev64q_f32(v.get_low());
101-
v1 = {v1_1, v1_1};
95+
v1_1 = vrev64q_f32(v);
96+
v1 = v1_1;
10297
// [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
10398
v = vec_fun(v, v1);
10499

105-
return v.get_low()[0];
100+
return v[0];
106101
}
107102
};
108103
#endif // defined(__aarch64__)

aten/src/ATen/cpu/vec/vec.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#if defined(CPU_CAPABILITY_AVX512)
44
#include <ATen/cpu/vec/vec512/vec512.h>
55
#else
6+
#include <ATen/cpu/vec/vec128/vec128.h>
67
#include <ATen/cpu/vec/vec256/vec256.h>
78
#endif
89

aten/src/ATen/cpu/vec/vec128/vec128.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#pragma once
2+
// ARM NEON uses 128-bit vector registers.
3+
4+
#include <ATen/cpu/vec/intrinsics.h>
5+
6+
#if !defined(CPU_CAPABILITY_SVE)
7+
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
8+
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
9+
#endif

0 commit comments

Comments
 (0)