Skip to content

Commit e1e6417

Browse files
Sqvidpytorchmergebot
authored andcommitted
Add SVE implementation of embedding_lookup_idx (#133995)
Adds an accelerated version of the embedding_lookup_idx perfkernels. This is done via a python codegen file similarly to `caffe2/perfkernels/hp_emblookup_codegen.py` Pull Request resolved: #133995 Approved by: https://github.com/malfet, https://github.com/huydhn
1 parent b09d6f3 commit e1e6417

File tree

8 files changed

+7265
-9
lines changed

8 files changed

+7265
-9
lines changed

caffe2/perfkernels/CMakeLists.txt

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@ endif()
1010
file(GLOB common_srcs *.cc)
1111
file(GLOB avx_srcs *_avx.cc)
1212
file(GLOB avx2_srcs *_avx2.cc)
13-
# exclude avx and avx2 srcs from common_srcs
13+
file(GLOB avx512_srcs *_avx512.cc)
14+
file(GLOB sve_srcs *_sve.cc)
15+
# exclude avx, avx2, avx512, and sve srcs from common_srcs
1416
exclude(common_srcs "${common_srcs}" ${avx_srcs})
1517
exclude(common_srcs "${common_srcs}" ${avx2_srcs})
18+
exclude(common_srcs "${common_srcs}" ${avx512_srcs})
19+
exclude(common_srcs "${common_srcs}" ${sve_srcs})
1620

1721
# We will always build common srcs.
1822
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${common_srcs})
@@ -42,6 +46,22 @@ if(CXX_AVX2_FOUND)
4246
"Caffe2_perfkernels_avx2_interface")
4347
endif()
4448

49+
# We will only build the SVE perfkernel files if the compiler supports SVE
50+
# extensions.
51+
if(CXX_SVE_FOUND)
52+
add_library(Caffe2_perfkernels_sve STATIC ${sve_srcs})
53+
target_link_libraries(Caffe2_perfkernels_sve PRIVATE c10)
54+
install(TARGETS Caffe2_perfkernels_sve
55+
ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}")
56+
57+
target_compile_options(Caffe2_perfkernels_sve PRIVATE "-march=armv8-a+sve")
58+
59+
caffe2_interface_library(
60+
Caffe2_perfkernels_sve Caffe2_perfkernels_sve_interface)
61+
list(APPEND
62+
Caffe2_DEPENDENCY_WHOLE_LINK_LIBS "Caffe2_perfkernels_sve_interface")
63+
endif()
64+
4565
# TODO(jiayq): currently, we only implement the very base files for the
4666
# perfkernels. This is because to implement avx and avx2 files, we actually
4767
# need to set up different compilation units and this is a bit more involving

caffe2/perfkernels/common.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,8 @@ In foo.cc, do:
6161
// we use cpuinfo to identify cpu support and run the proper functions.
6262

6363
#pragma once
64-
65-
#if defined(CAFFE2_PERF_WITH_AVX512) || defined(CAFFE2_PERF_WITH_AVX2) \
66-
|| defined(CAFFE2_PERF_WITH_AVX)
64+
#if defined(CAFFE2_PERF_WITH_SVE) || defined(CAFFE2_PERF_WITH_AVX512) || \
65+
defined(CAFFE2_PERF_WITH_AVX2) || defined(CAFFE2_PERF_WITH_AVX)
6766
#include <cpuinfo.h>
6867
#endif
6968

@@ -72,6 +71,18 @@ In foo.cc, do:
7271

7372
#define BASE_DO(funcname, ...) return funcname##__base(__VA_ARGS__);
7473

74+
#ifdef CAFFE2_PERF_WITH_SVE
75+
#define SVE_DO(funcname, ...) \
76+
{ \
77+
static const bool isDo = cpuinfo_initialize() && cpuinfo_has_arm_sve(); \
78+
if (isDo) { \
79+
return funcname##__sve(__VA_ARGS__); \
80+
} \
81+
}
82+
#else // CAFFE2_PERF_WITH_SVE
83+
#define SVE_DO(funcname, ...)
84+
#endif // CAFFE2_PERF_WITH_SVE
85+
7586
#ifdef CAFFE2_PERF_WITH_AVX512
7687
#define AVX512_DO(funcname, ...) \
7788
{ \

caffe2/perfkernels/common_sve.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// This file is here merely to check that the flags are not mixed up: for
2+
// example, if your compiler did not specify -march=armv8-a+sve, you should not
3+
// provide the CAFFE2_PERF_WITH_SVE macro.
4+
5+
#include "caffe2/core/common.h"
6+
7+
#ifdef CAFFE2_PERF_WITH_SVE
8+
#ifndef __ARM_FEATURE_SVE
9+
#error( \
10+
"You found a build system error: CAFFE2_PERF_WITH_SVE is defined" \
11+
"but __ARM_FEATURE_SVE is not defined (via e.g. -march=armv8-a+sve).");
12+
#endif // __ARM_FEATURE_SVE
13+
#endif // CAFFE2_PERF_WITH_SVE
14+
15+
#ifdef __ARM_FEATURE_SVE
16+
#ifndef CAFFE2_PERF_WITH_SVE
17+
#error( \
18+
"You found a build system error: __SVE__ is defined \
19+
(via e.g. -march=armv8-a+sve) " \
20+
"but CAFFE2_PERF_WITH_SVE is not defined.");
21+
#endif // CAFFE2_PERF_WITH_SVE
22+
#endif

caffe2/perfkernels/embedding_lookup_idx.cc

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ static bool EmbeddingLookupGenericSlowIdx(
8888
const int64_t data_size, \
8989
const InType* input, \
9090
const IndexType* indices, \
91-
const IndexType* offsets, \
91+
const IndexType* offsets, \
9292
const float* weights, \
9393
const float* scale_bias, \
9494
bool normalize_by_lengths, \
@@ -113,6 +113,9 @@ static bool EmbeddingLookupGenericSlowIdx(
113113
decltype( \
114114
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \
115115
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \
116+
decltype( \
117+
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \
118+
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__sve; \
116119
bool \
117120
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \
118121
const int64_t block_size, \
@@ -121,7 +124,7 @@ static bool EmbeddingLookupGenericSlowIdx(
121124
const int64_t data_size, \
122125
const InType* input, \
123126
const IndexType* indices, \
124-
const IndexType* offsets, \
127+
const IndexType* offsets, \
125128
const float* weights, \
126129
const float* scale_bias, \
127130
bool normalize_by_lengths, \
@@ -131,6 +134,19 @@ static bool EmbeddingLookupGenericSlowIdx(
131134
} else { \
132135
CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); \
133136
} \
137+
SVE_DO( \
138+
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \
139+
block_size, \
140+
output_size, \
141+
index_size, \
142+
data_size, \
143+
input, \
144+
indices, \
145+
offsets, \
146+
weights, \
147+
scale_bias, \
148+
normalize_by_lengths, \
149+
out); \
134150
AVX2_FMA_DO( \
135151
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \
136152
block_size, \
@@ -166,7 +182,7 @@ static bool EmbeddingLookupGenericSlowIdx(
166182
const int64_t data_size, \
167183
const InType* input, \
168184
const IndexType* indices, \
169-
const IndexType* offsets, \
185+
const IndexType* offsets, \
170186
const float* weights, \
171187
const float* scale_bias, \
172188
bool normalize_by_lengths, \

0 commit comments

Comments
 (0)