-
Notifications
You must be signed in to change notification settings - Fork 24k
/
Copy pathSorting.cpp
66 lines (58 loc) · 1.87 KB
/
Sorting.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/SortingUtils.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/topk_native.h>
#endif
namespace at::native {
// Currently internal-only.
//
// This implementation assumes the quantizer for the input and the out-
// put are the same.
//
// If we want to support this publicly, we need to add
// a requantization step to the kernel.
static std::tuple<Tensor&, Tensor&> quantized_topk_out_cpu(
Tensor& values,
Tensor& indices,
const Tensor& self,
int64_t k,
int64_t dim_,
bool largest,
bool sorted) {
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
TORCH_CHECK(
k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
"selected index k out of range");
_allocate_or_resize_output_with_indices(values, indices, self, dim_, k);
qtopk_stub(kCPU, values, indices, self, k, dim, largest, sorted);
return std::forward_as_tuple(values, indices);
}
std::tuple<Tensor, Tensor> topk_quantized_cpu(
const Tensor& self,
int64_t k,
int64_t dim,
bool largest,
bool sorted) {
auto qscheme = self.qscheme();
TORCH_CHECK(
qscheme == QScheme::PER_TENSOR_AFFINE ||
qscheme == QScheme::PER_TENSOR_SYMMETRIC,
"Top-K is only supported on per-tensor quantization");
Tensor values = at::_empty_affine_quantized(
{0},
self.options(),
self.q_scale(),
self.q_zero_point());
Tensor indices = at::empty({0}, self.options().dtype(kLong));
return quantized_topk_out_cpu(values, indices, self, k, dim, largest, sorted);
}
DEFINE_DISPATCH(qtopk_stub);
} // namespace at::native