// ChannelShuffle.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized_native.h>
#include <ATen/ops/channel_shuffle_native.h>
#endif
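// Channel shuffle divides the C channels of an NCHW tensor into `groups`
// groups and interleaves them: conceptually reshape to (N, groups, C/groups,
// H, W), swap the two channel axes, and flatten back to (N, C, H, W). This
// file provides the quantized (quint8) CPU kernel, backed by QNNPACK when
// available.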
namespace at::native {
#ifdef USE_PYTORCH_QNNPACK
namespace {
Tensor quantized_channel_shuffle_impl(
    const Tensor& self,
    int64_t groups) {
  TORCH_CHECK(
      groups > 0,
      "Number of groups to divide channels in must be positive.",
      " Value of groups:", groups);
  TORCH_CHECK(
      self.dim() == 4,
      "channel_shuffle expects 4D input, but got input with sizes ",
      self.sizes());
  TORCH_CHECK(
      self.scalar_type() == kQUInt8,
      "Quantized channel shuffle works only on ",
      toString(c10::kQUInt8),
      " but got ", self.scalar_type());
  const Tensor self_nhwc = self.contiguous(MemoryFormat::ChannelsLast);
  Tensor qy = at::native::empty_affine_quantized(
      self_nhwc.sizes(),
      kQUInt8,
      std::nullopt /* layout */,
      kCPU,
      std::nullopt /* pin_memory */,
      self_nhwc.q_scale(),
      self_nhwc.q_zero_point(),
      MemoryFormat::ChannelsLast);

  // Degenerate case of just copying.
  if (groups == 1) {
    qy.copy_(self_nhwc);
    return qy.contiguous(self.suggest_memory_format());
  }
  int64_t channels = self.size(1);
  TORCH_CHECK(channels > 0,
      "Number of channels must be positive, got:", channels);
  TORCH_CHECK((channels % groups) == 0,
      "Number of channels must be divisible by groups. Got ",
      channels, " channels and ", groups, " groups.");
  initQNNPACK();

  pytorch_qnnp_operator_t qnnpack_operator{nullptr};

  const pytorch_qnnp_status createStatus = pytorch_qnnp_create_channel_shuffle_nc_x8(
      groups /* groups */,
      channels / groups /* group channels */,
      0 /* flags */,
      &qnnpack_operator);
  TORCH_INTERNAL_ASSERT(
      createStatus == pytorch_qnnp_status_success,
      "failed to create QNNPACK ChannelShuffle operator");

  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
      qnnpack_uniq_ptr(qnnpack_operator);
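  // In NHWC layout every (n, h, w) position holds `channels` contiguous
  // values, so the op sees numel() / channels rows of length `channels`,
  // with an element stride of `channels` between consecutive rows.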
  const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_channel_shuffle_nc_x8(
      qnnpack_uniq_ptr.get(),
      self_nhwc.numel() / channels /* batch size */,
      (uint8_t*)self_nhwc.data_ptr<c10::quint8>() /* self data */,
      channels /* self stride */,
      (uint8_t*)qy.data_ptr<c10::quint8>() /* qy data */,
      channels /* qy stride */);
  TORCH_INTERNAL_ASSERT(
      setupStatus == pytorch_qnnp_status_success,
      "failed to setup QNNPACK ChannelShuffle operator");

  pthreadpool_t threadpool = caffe2::pthreadpool_();
  const pytorch_qnnp_status runStatus =
      pytorch_qnnp_run_operator(qnnpack_operator, threadpool);
  TORCH_INTERNAL_ASSERT(
      runStatus == pytorch_qnnp_status_success,
      "failed to run QNNPACK ChannelShuffle operator");

  return qy.contiguous(self.suggest_memory_format());
}
} // namespace
#endif
// at::native functions for the native_functions.yaml
Tensor channel_shuffle_quantized_cpu(
    const Tensor& self,
    int64_t groups) {
#ifdef USE_PYTORCH_QNNPACK
  return quantized_channel_shuffle_impl(self, groups);
#endif

  // If QNNPACK is not available then fall back to the
  // non-quantized path.
  return at::native::channel_shuffle(self, groups);
}
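// A minimal, illustrative call path (a sketch; the shapes and quantization
// parameters below are made up for the example, not part of this file):
//
//   at::Tensor x = at::rand({1, 4, 2, 2});
//   at::Tensor qx = at::quantize_per_tensor(
//       x, /*scale=*/0.1, /*zero_point=*/0, at::kQUInt8);
//   // Dispatches to channel_shuffle_quantized_cpu for quint8 CPU tensors.
//   at::Tensor qy = at::channel_shuffle(qx, /*groups=*/2);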
// Keep the registry in the anonymous namespace.
namespace {
class QChannelShuffle final : public c10::OperatorKernel {
 public:
  Tensor operator()(Tensor qx, int64_t groups) {
    return channel_shuffle_quantized_cpu(qx, groups);
  }
};
} // namespace
} // namespace at::native