-
Notifications
You must be signed in to change notification settings - Fork 24k
/
Copy pathqlinear.h
51 lines (46 loc) · 1.64 KB
/
qlinear.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#pragma once
#include <ATen/Tensor.h>
#include <ATen/Config.h>
namespace at::native {
// Named scope for two static entry points that take plain Tensor arguments.
// The class is `final` and, as far as this header shows, declares only static
// member functions, so nothing here requires an instance.
// NOTE(review): judging by the names, these run a oneDNN-backed quantized
// linear op with fused post-ops. The definitions are not in this header;
// confirm there.
class QLinearOnednn final {
public:
// Variant taking a single post-op description (name, args, algorithm).
// Declaration only; every Tensor/optional/List argument is taken by value.
// The string_view arguments are non-owning views, so the caller's character
// data must stay alive for the duration of the call.
//
// Parameters (semantics inferred from names unless an inline comment says so):
//   act, act_scale, act_zero_point
//       input activation plus, presumably, its quantization parameters, passed
//       as Tensors rather than scalars.
//   onednn_weight, weight_scales, weight_zero_points
//       weight plus, presumably, its quantization parameters. TODO confirm
//       whether these are per-tensor or per-channel.
//   bias
//       optional; may be std::nullopt.
//   output_scale, output_zero_point
//       scalar (double / int64_t), presumably the quantization parameters of
//       the result.
//   output_dtype
//       optional; the behaviour when empty is not visible here. TODO confirm.
//   post_op_name, post_op_args, post_op_algorithm
//       presumably identify and parameterize the fused post-op. The accepted
//       names are not listed in this header.
// Returns a Tensor; its dtype and layout are not determinable from here.
C10_API static Tensor run_pointwise_tensor(
Tensor act, // int8 CPU tensor, not QTensor
Tensor act_scale,
Tensor act_zero_point,
Tensor onednn_weight, // int8 tensor from MkldnnCPU
Tensor weight_scales,
Tensor weight_zero_points,
std::optional<Tensor> bias,
double output_scale,
int64_t output_zero_point,
std::optional<c10::ScalarType> output_dtype,
std::string_view post_op_name,
c10::List<std::optional<at::Scalar>> post_op_args,
std::string_view post_op_algorithm);
// Variant that additionally accepts a second input (`other`) for a binary
// post-op, followed by a unary post-op description. Declaration only.
// The parameters shared with run_pointwise_tensor have the same types and
// order up to `weight_zero_points`. The extras are:
//   other
//       optional extra input for the binary post-op (see inline comment).
//   other_scale, other_zero_point
//       scalar, presumably the quantization parameters of `other`.
//       TODO confirm.
//   binary_post_op, binary_alpha
//       name of the binary post-op and a double that presumably scales one of
//       its operands. TODO confirm which one.
//   unary_post_op, unary_post_op_args, unary_post_op_algorithm
//       presumably describe a unary post-op applied in addition to the binary
//       one. The ordering between the two is not visible here.
// NOTE(review): `other` is declared as std::optional<at::Tensor>, while `bias`
// uses the unqualified std::optional<Tensor>. Inside namespace at these should
// name the same type.
C10_API static Tensor run_pointwise_binary_tensor(
Tensor act, // int8 CPU tensor, not QTensor
Tensor act_scale,
Tensor act_zero_point,
Tensor onednn_weight, // int8 tensor from MkldnnCPU
Tensor weight_scales,
Tensor weight_zero_points,
std::optional<at::Tensor> other, // extra input for binary post-op
std::optional<Tensor> bias,
double output_scale,
int64_t output_zero_point,
std::optional<c10::ScalarType> output_dtype,
double other_scale,
int64_t other_zero_point,
std::string_view binary_post_op, // e.g. "none", "sum", "add"
double binary_alpha,
std::string_view unary_post_op, // e.g. "none", "relu"
c10::List<std::optional<at::Scalar>> unary_post_op_args,
std::string_view unary_post_op_algorithm);
};
// Free function in at::native, exported with C10_API; declaration only.
// All four arguments are borrowed by const reference and a Tensor is returned
// by value.
// NOTE(review): the name suggests a CPU matrix multiply of A against an
// int4-packed weight B, with the group size passed as a Tensor (`qGroupSize`)
// rather than an integer, and scales/zero-points combined in `qScaleAndZeros`.
// Expected shapes, dtypes and packing format are not visible in this header;
// confirm against the definition.
C10_API Tensor _weight_int4pack_mm_cpu_tensor(
const Tensor& A,
const Tensor& B,
const Tensor& qGroupSize,
const Tensor& qScaleAndZeros);
} // namespace at::native