qelu.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/ivalue.h>
#include <torch/library.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#endif

namespace at::native {

DEFINE_DISPATCH(qelu_stub);

static Tensor quantized_elu(
    const Tensor& qx,
    double output_scale,
    int64_t output_zero_point,
    const Scalar& alpha,
    const Scalar& scale,
    const Scalar& input_scale) {
  // Allocate the output tensor with the requested quantization parameters,
  // then dispatch to the architecture-specific quantized ELU kernel.
  Tensor qy = at::_empty_affine_quantized(
      qx.sizes(), qx.options(), output_scale, output_zero_point);
  qelu_stub(qx.device().type(), qx, alpha, scale, input_scale, qy);
  return qy;
}

static Tensor quantized_celu(
    const Tensor& qx,
    double output_scale,
    int64_t output_zero_point,
    const Scalar& alpha) {
  TORCH_CHECK(alpha.to<double>() != 0,
      "ZeroDivisionError: alpha cannot be 0 for CELU");
  // CELU(x, alpha) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)),
  // i.e. ELU with scale = 1 applied to x / alpha, so reuse the ELU kernel
  // with input_scale = 1 / alpha.
  double inv_alpha = 1. / alpha.to<double>();
  return quantized_elu(
      qx, output_scale, output_zero_point, alpha, Scalar(1.0), Scalar(inv_alpha));
}

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::elu"), quantized_elu);
  m.impl(TORCH_SELECTIVE_NAME("quantized::celu"), quantized_celu);
}

} // namespace at::native
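
A minimal usage sketch (not part of qelu.cpp): the kernels registered above for the QuantizedCPU dispatch key can be invoked from C++ through the dispatcher. This assumes the quantized::celu schema declared in the quantized op library matches the argument list of quantized_celu above (Tensor, float output_scale, int output_zero_point, Scalar alpha); the tensor shape and quantization parameters are illustrative only.

#include <iostream>
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

int main() {
  // Quantize a float tensor to per-tensor affine quint8.
  at::Tensor x = at::randn({4});
  at::Tensor qx = at::quantize_per_tensor(
      x, /*scale=*/0.1, /*zero_point=*/10, at::kQUInt8);

  // Look up the registered quantized::celu op; for a quantized CPU tensor the
  // QuantizedCPU dispatch key routes the call to quantized_celu in qelu.cpp.
  auto celu = c10::Dispatcher::singleton()
      .findSchemaOrThrow("quantized::celu", "")
      .typed<at::Tensor(const at::Tensor&, double, int64_t, const at::Scalar&)>();
  at::Tensor qy = celu.call(
      qx, /*output_scale=*/0.1, /*output_zero_point=*/10, /*alpha=*/1.0);

  // Compare against the float reference path.
  std::cout << at::dequantize(qy) << "\n" << at::celu(x) << std::endl;
  return 0;
}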