-
Notifications
You must be signed in to change notification settings - Fork 24k
/
Copy pathqconv.cpp
2094 lines (1935 loc) · 78.6 KB
/
qconv.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <algorithm>
#include <cmath>
#include <string>
#include <vector>
#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/Context.h>
#include <ATen/Parallel.h>
#include <ATen/TensorOperators.h>
#include <ATen/SmallVector.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/XnnpackUtils.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/quantized/ConvUtils.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <ATen/native/quantized/cpu/qconv.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <torch/library.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/_empty_affine_quantized_native.h>
#include <ATen/ops/_empty_per_channel_affine_quantized_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/quantize_per_channel_native.h>
#include <ATen/ops/quantize_per_tensor_native.h>
#include <ATen/ops/zeros.h>
#endif
#include <c10/util/irange.h>
namespace {
// Sanity bound on computed output dimensions: output sizes are rejected
// when they reach this value.
constexpr int64_t kReasonableMaxDim = 1'000'000;
} // namespace
/// Validates the ranks of a quantized convolution's activation and its
/// hyper-parameter lists against the spatial dimensionality kSpatialDim.
///
/// @param act_dims            rank of the activation (expected kSpatialDim + 2)
/// @param stride_dims         number of stride entries (expected kSpatialDim)
/// @param padding_dims        number of padding entries (expected kSpatialDim)
/// @param output_padding_dims number of output-padding entries; only checked
///                            when `transpose` is true
/// @param dilation_dims       number of dilation entries (expected kSpatialDim)
/// @param func_name           operator name used as the error-message prefix;
///                            taken by const reference since it is only read
/// @param transpose           whether the convolution is transposed
/// @return true when every check passes; a failed check raises via TORCH_CHECK.
template <int kSpatialDim = 2>
static bool ConvDimChecks(
    int64_t act_dims,
    int64_t stride_dims,
    int64_t padding_dims,
    int64_t output_padding_dims,
    int64_t dilation_dims,
    const std::string& func_name,
    bool transpose = false) {
  TORCH_CHECK(
      act_dims == kSpatialDim + 2,
      func_name,
      kSpatialDim,
      "d(): Expected activation tensor to have ",
      kSpatialDim + 2,
      " dimensions, got ",
      act_dims);
  TORCH_CHECK(
      stride_dims == kSpatialDim,
      func_name,
      kSpatialDim,
      "d(): Expected stride tensor to have ",
      kSpatialDim,
      " dimensions, got ",
      stride_dims);
  TORCH_CHECK(
      padding_dims == kSpatialDim,
      func_name,
      kSpatialDim,
      "d(): Expected padding tensor to have ",
      kSpatialDim,
      " dimensions, got ",
      padding_dims);
  // Output padding only exists for transposed convolutions.
  TORCH_CHECK(
      !transpose || (output_padding_dims == kSpatialDim),
      func_name,
      kSpatialDim,
      "d(): Expected output padding tensor to have ",
      kSpatialDim,
      " dimensions, got ",
      output_padding_dims);
  TORCH_CHECK(
      dilation_dims == kSpatialDim,
      func_name,
      kSpatialDim,
      "d(): Expected dilation tensor to have ",
      kSpatialDim,
      " dimensions, got ",
      dilation_dims);
  return true;
}
// Size of a transposed-convolution output along a single spatial axis:
//   (input - 1) * stride - 2 * input_padding
//       + dilation * (kernel - 1) + output_padding + 1
// The result may be zero or negative for degenerate parameters; callers
// are responsible for validating it.
static inline int64_t compute_deconv_shape(int64_t input,
                                           int64_t kernel,
                                           int64_t stride,
                                           int64_t input_padding,
                                           int64_t output_padding,
                                           int64_t dilation) {
  const int64_t upsampled = (input - 1) * stride;
  const int64_t kernel_extent = dilation * (kernel - 1);
  return upsampled - 2 * input_padding + kernel_extent + output_padding + 1;
}
/// Computes the NC[D]HW output shape of a quantized transposed convolution.
///
/// @param N              batch size (copied to output_shape[0])
/// @param M              output channel count (copied to output_shape[1])
/// @param input_shape    spatial input extents; the first kSpatialDim entries are read
/// @param kernel         spatial kernel extents; the first kSpatialDim entries are read
/// @param stride         per-axis stride
/// @param input_padding  per-axis input padding
/// @param output_padding per-axis output padding
/// @param dilation       per-axis dilation
/// @return {N, M, out_0, ..., out_{kSpatialDim-1}}, each out_i computed by
///         compute_deconv_shape.
/// Raises (TORCH_CHECK) if any spatial output size is <= 0 or
/// >= kReasonableMaxDim.
template <int64_t kSpatialDim>
static at::SmallVector<int64_t, kSpatialDim + 2> MakeDeConvOutputShape(
    int64_t N, int64_t M,
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& kernel,
    const torch::List<int64_t>& stride,
    const torch::List<int64_t>& input_padding,
    const torch::List<int64_t>& output_padding,
    const torch::List<int64_t>& dilation) {
  at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
  output_shape.resize(kSpatialDim + 2);
  output_shape[0] = N; // Batch size
  output_shape[1] = M; // Output channels
  for (const auto idx : c10::irange(kSpatialDim)) {
    output_shape[idx + 2] = compute_deconv_shape(input_shape[idx],
                                                 kernel[idx],
                                                 stride[idx],
                                                 input_padding[idx],
                                                 output_padding[idx],
                                                 dilation[idx]);
    // Each check is terminated explicitly rather than relying on the
    // TORCH_CHECK macro happening to expand to a brace-delimited statement.
    TORCH_CHECK(output_shape[idx + 2] > 0,
                "Output dimension is zero for ", idx, " axis;"
                " kernel: ", kernel[idx],
                ", stride: ", stride[idx],
                ", input padding: ", input_padding[idx],
                ", output padding: ", output_padding[idx],
                ", dilation: ", dilation[idx]);
    TORCH_CHECK(output_shape[idx + 2] < kReasonableMaxDim,
                "Output dimension is beyond reasonable maximum for ", idx,
                " axis;"
                " kernel: ", kernel[idx],
                ", stride: ", stride[idx],
                ", input padding: ", input_padding[idx],
                ", output padding: ", output_padding[idx],
                ", dilation: ", dilation[idx]);
  }
  return output_shape;
}
#ifdef USE_FBGEMM
// Builds the NC[D]HW output shape for the FBGEMM path from the batch size N,
// the output channel count M and the per-axis spatial output sizes. Only the
// 2-D and 3-D specializations below are defined.
template <int kSpatialDim = 2>
at::SmallVector<int64_t, kSpatialDim + 2> MakeConvOutputShape(
    int N,
    int M,
    const std::array<int, kSpatialDim>& output_image_shape);

// 2-D specialization: {N, M, out_h, out_w}.
template <>
at::SmallVector<int64_t, 4> MakeConvOutputShape<2>(
    int N,
    int M,
    const std::array<int, 2>& output_image_shape) {
  const auto& [out_h, out_w] = output_image_shape;
  return {N, M, out_h, out_w};
}

// 3-D specialization: {N, M, out_d, out_h, out_w}.
template <>
at::SmallVector<int64_t, 5> MakeConvOutputShape<3>(
    int N,
    int M,
    const std::array<int, 3>& output_image_shape) {
  const auto& [out_d, out_h, out_w] = output_image_shape;
  return {N, M, out_d, out_h, out_w};
}
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
// Packs the spatial input extents into a fixed-size array: {H, W} for the
// 2-D specialization (depth ignored) and {D, H, W} for the 3-D one.
template <size_t kSpatialDim>
static std::array<int64_t, kSpatialDim> MakeInputShape(
    int64_t D,
    int64_t H,
    int64_t W);

template <>
std::array<int64_t, 2> MakeInputShape<2>(int64_t /*D*/, int64_t H, int64_t W) {
  return std::array<int64_t, 2>{H, W};
}

template <>
std::array<int64_t, 3> MakeInputShape<3>(int64_t D, int64_t H, int64_t W) {
  return std::array<int64_t, 3>{D, H, W};
}
#endif // USE_PYTORCH_QNNPACK
#ifdef USE_FBGEMM
/// Returns a pointer to this module's float bias data, or nullptr when the
/// module has no bias.
///
/// @param bias_ptr out-parameter that receives the contiguous bias tensor.
///                 The returned pointer refers to that tensor's storage, so
///                 the tensor must outlive every use of the pointer.
/// Raises (TORCH_CHECK) if the bias is not float, is not 1-D, or its length
/// differs from the packed weight's output-channel count.
template <int kSpatialDim>
const float* PackedConvWeight<kSpatialDim>::GetBiasData(at::Tensor* bias_ptr) {
  const float* bias_data = nullptr;
  if (bias.has_value()) {
    *bias_ptr = bias.value();
    // This helper is shared by the 2-D and 3-D instantiations, so the
    // dimension in the message follows kSpatialDim instead of being fixed.
    TORCH_CHECK(
        bias_ptr->dtype() == at::kFloat,
        "[QConv", kSpatialDim, "D] The 'bias' tensor must have 'torch.float' dtype");
    *bias_ptr = bias_ptr->contiguous();
    TORCH_CHECK(bias_ptr->dim() == 1, "bias should be a vector (1D Tensor)");
    const int M = w->outputChannels();
    TORCH_CHECK(bias_ptr->size(0) == M, "bias should have ", M, " elements.");
    bias_data = bias_ptr->data_ptr<float>();
  }
  return bias_data;
}
// Derives the requantization factors for one convolution call:
//   act_times_w_scale[i]       = act_scale * w_scale[i]
//   output_multiplier_float[i] = act_times_w_scale[i] / out_scale
// Per-channel weights yield one entry per output channel of the packed
// weight; per-tensor weights yield a single entry from w_scale[0]. Any other
// quantization scheme raises.
template <int kSpatialDim>
void PackedConvWeight<kSpatialDim>::GetQuantizationParams(
    float act_scale,
    float out_scale,
    std::vector<float>* output_multiplier_float,
    std::vector<float>* act_times_w_scale) {
  if (q_scheme == c10::kPerChannelAffine) {
    const int num_out_channels = w->outputChannels();
    output_multiplier_float->resize(num_out_channels);
    act_times_w_scale->resize(num_out_channels);
    for (const auto ch : c10::irange(num_out_channels)) {
      (*act_times_w_scale)[ch] = act_scale * w_scale[ch];
      (*output_multiplier_float)[ch] = (*act_times_w_scale)[ch] / out_scale;
    }
    return;
  }
  TORCH_CHECK(
      q_scheme == c10::kPerTensorAffine,
      "[QConv", kSpatialDim, "D] Unknown quantization scheme");
  *act_times_w_scale = {act_scale * w_scale[0]};
  *output_multiplier_float = {act_times_w_scale->front() / out_scale};
}
// Quantized convolution without a fused ReLU; forwards to apply_impl<false>.
template <int kSpatialDim>
at::Tensor PackedConvWeight<kSpatialDim>::apply(
    const at::Tensor& input,
    double output_scale,
    int64_t output_zero_point) {
  constexpr bool kFuseRelu = false;
  return this->template apply_impl<kFuseRelu>(
      input, output_scale, output_zero_point);
}
// Quantized convolution with ReLU fused into requantization; forwards to
// apply_impl<true>.
template <int kSpatialDim>
at::Tensor PackedConvWeight<kSpatialDim>::apply_relu(
    const at::Tensor& input,
    double output_scale,
    int64_t output_zero_point) {
  constexpr bool kFuseRelu = true;
  return this->template apply_impl<kFuseRelu>(
      input, output_scale, output_zero_point);
}
// FBGEMM implementation shared by apply() (kReluFused == false) and
// apply_relu() (kReluFused == true). It covers both regular and transposed
// convolution, selected at runtime by transpose().
//
// Parameters:
//   act               - quantized activation. Must be c10::kQUInt8 with rank
//                       kSpatialDim + 2 (both checked below); indexed as
//                       N, C, [D,] H, W.
//   output_scale      - scale of the returned quantized tensor.
//   output_zero_point - zero point of the returned quantized tensor.
// Returns a c10::kQUInt8 tensor allocated channels-last (ChannelsLast for
// 2-D, via MakeEmptyAffineQuantizedChannelsLast3dTensor for 3-D).
// kReluFused is forwarded as the first template argument of
// fbgemm::ReQuantizeOutput (NOTE(review): presumably enables ReLU clamping
// during requantization; confirm in FBGEMM).
// Raises (TORCH_CHECK) on an unsupported CPU, a wrong activation dtype, rank
// mismatches, an input-channel mismatch with the packed weight, or
// non-positive output dimensions.
template <int kSpatialDim>
template <bool kReluFused>
at::Tensor PackedConvWeight<kSpatialDim>::apply_impl(
const at::Tensor& act,
double output_scale,
int64_t output_zero_point) {
// Quantized kernels are all written with NHWC (channels last) layout in
// mind. Ideally, we'd be compatible with conv2d behavior and preserve the
// inputs layout as is (doing necessary upconversions).
//
// However, to be more robust, for now we just force output layout to always
// be NHWC (channels last), thus opportunistically improving perf.
//
// This might change when full memory format support lands
// See https://github.com/pytorch/pytorch/issues/23403
const std::string func_name = transpose() ? "quantized::conv_transpose"
: "quantized::conv";
TORCH_CHECK(
fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
TORCH_CHECK(act.scalar_type() == c10::kQUInt8,
func_name,
"(FBGEMM): Expected activation data type ",
toString(c10::kQUInt8),
" but got ",
toString(act.scalar_type()));
ConvDimChecks<kSpatialDim>(
act.ndimension(), stride().size(), padding().size(),
output_padding().size(), dilation().size(), func_name, transpose());
// Activation extents. D is fixed to 1 for 2-D so the same code path serves
// both ranks.
const int N = act.size(0);
const int C = act.size(1);
const int D = kSpatialDim == 2 ? 1 : act.size(2);
const int H = act.size(kSpatialDim);
const int W = act.size(kSpatialDim + 1);
// Convert to channels-last (2-D) / channels-last-3d (3-D) before taking the
// raw uint8 pointer that is handed to fbgemm.
const at::Tensor act_ndhwc = kSpatialDim == 2
? act.contiguous(c10::MemoryFormat::ChannelsLast)
: at::native::fbgemm_utils::ConvertToChannelsLast3dTensor(act);
const uint8_t* act_data =
reinterpret_cast<uint8_t*>(act_ndhwc.data_ptr<c10::quint8>());
auto* pack_w = w.get();
const int M = pack_w->outputChannels();
// Per-axis hyper-parameters. For 2-D the depth entries take neutral values
// (kernel/stride/dilation 1, padding/output padding 0).
const int kernel_d = kSpatialDim == 2 ? 1 : kernel[0];
const int kernel_h = kernel[kSpatialDim - 2];
const int kernel_w = kernel[kSpatialDim - 1];
const int pad_d = kSpatialDim == 2 ? 0 : padding_[0];
const int pad_h = padding_[kSpatialDim - 2];
const int pad_w = padding_[kSpatialDim - 1];
const int stride_d = kSpatialDim == 2 ? 1 : stride_[0];
const int stride_h = stride_[kSpatialDim - 2];
const int stride_w = stride_[kSpatialDim - 1];
const int dilation_d = kSpatialDim == 2 ? 1 : dilation_[0];
const int dilation_h = dilation_[kSpatialDim - 2];
const int dilation_w = dilation_[kSpatialDim - 1];
const int output_padding_d = kSpatialDim == 2 ? 0 : output_padding_[0];
const int output_padding_h = output_padding_[kSpatialDim - 2];
const int output_padding_w = output_padding_[kSpatialDim - 1];
// The activation's channel count must match the packed weight's input
// channels. The two branches differ only in the error message.
if (kSpatialDim == 2) {
TORCH_CHECK(
C == pack_w->inputChannels(),
"[QConv2D] Given groups=",
groups_,
", weight of size ",
M,
", ",
kernel_h,
", ",
kernel_w,
", ",
pack_w->inputChannels(),
", expected input (NCHW) ",
N,
", ",
C,
", ",
H,
", ",
W,
" to have ",
pack_w->inputChannels(),
" channels, but got ",
C,
" channels instead");
} else {
TORCH_CHECK(
C == pack_w->inputChannels(),
"[QConv3D] Given groups=",
groups_,
", weight of size ",
M,
", ",
kernel_d,
", ",
kernel_h,
", ",
kernel_w,
", ",
pack_w->inputChannels(),
", expected input (NCDHW) ",
N,
", ",
C,
", ",
D,
", ",
H,
", ",
W,
" to have ",
pack_w->inputChannels(),
" channels, but got ",
C,
" channels instead");
}
// Describe the convolution to fbgemm; the vectors carry 2 or 3 spatial
// entries depending on kSpatialDim.
fbgemm::conv_param_t<kSpatialDim> conv_p =
at::native::fbgemm_utils::MakeFbgemmConvParam<kSpatialDim>(
N, // Batch size
C, // Number of input channels
M, // Number of output channels
kSpatialDim == 2 ? std::vector<int>{H, W} : std::vector<int>{D, H, W},
groups_,
kSpatialDim == 2 ? std::vector<int>{kernel_h, kernel_w}
: std::vector<int>{kernel_d, kernel_h, kernel_w},
kSpatialDim == 2 ? std::vector<int>{stride_h, stride_w}
: std::vector<int>{stride_d, stride_h, stride_w},
kSpatialDim == 2 ? std::vector<int>{pad_h, pad_w}
: std::vector<int>{pad_d, pad_h, pad_w},
kSpatialDim == 2
? std::vector<int>{dilation_h, dilation_w}
: std::vector<int>{dilation_d, dilation_h, dilation_w},
kSpatialDim == 2
? std::vector<int>{output_padding_h, output_padding_w}
: std::vector<int>{output_padding_d,
output_padding_h,
output_padding_w},
transpose());
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
const float act_scale = act.q_scale();
const int32_t act_zero_point = act.q_zero_point();
// `bias` owns the contiguous bias tensor that bias_data points into, so it
// must stay alive until fbgemm has consumed bias_data below.
at::Tensor bias;
const float* bias_data = GetBiasData(&bias);
TORCH_CHECK(
w_scale.size() == w_zp.size(),
"Weight scales and zero points vectors should have the same size.");
std::vector<float> output_multiplier_float;
std::vector<float> act_times_w_scale;
GetQuantizationParams(
act_scale, output_scale, &output_multiplier_float, &act_times_w_scale);
at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
if (transpose()) {
output_shape = MakeDeConvOutputShape<kSpatialDim>(
N,
M,
kSpatialDim == 2 ? std::vector<int64_t>{H, W} : std::vector<int64_t>{D, H, W},
kernel,
stride(),
padding(),
output_padding(),
dilation());
// if use direct convolution implementation, compute the col_offsets
// of the weight matrix at model initialization stage.
// We need to know the shape of output matrix
// to compute col_offsets for direct convolution.
// Hence it cannot be called from inside weight packing function
// like other quantized conv implementation
if (pack_w->getPackedWForDirectconv().get() &&
pack_w->getPackedWForDirectconv().get()->is_first_call()) {
pack_w->getPackedWForDirectconv().get()->col_offsets_with_zero_pt_s8acc32_DirectConvT(
conv_p,
w_zp.data(),
col_offsets,
M);
}
} else {
output_shape = MakeConvOutputShape<kSpatialDim>(N, M, conv_p.OUT_DIM);
}
// An empty batch skips the positivity check on the output dimensions.
if (N > 0) {
TORCH_CHECK(
std::all_of(
output_shape.begin(),
output_shape.end(),
[](int64_t i) { return i > 0; }),
"[QConv",
kSpatialDim,
"D] each dimension of output tensor should be greater than 0");
}
at::Tensor output = kSpatialDim == 2
? at::_empty_affine_quantized(
output_shape,
device(c10::kCPU)
.dtype(c10::kQUInt8)
.memory_format(c10::MemoryFormat::ChannelsLast),
output_scale,
output_zero_point,
std::nullopt)
: at::native::fbgemm_utils::MakeEmptyAffineQuantizedChannelsLast3dTensor(
output_shape[0],
output_shape[1],
output_shape[2],
output_shape[3],
output_shape[4],
device(c10::kCPU).dtype(c10::kQUInt8),
output_scale,
output_zero_point);
// int32 scratch tensor with the output's sizes, passed to fbgemmConv.
// NOTE(review): presumably the int32 accumulation buffer; confirm in FBGEMM.
at::Tensor buffer =
at::empty(output.sizes(), output.options().dtype(c10::kInt));
// Each parallel task passes its task_id / num_tasks to fbgemmConv as
// thread_id / num_threads (presumably fbgemm partitions the work by them).
const int num_tasks = at::get_num_threads();
at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
fbgemm::DoNothing<> kNoOpObj{};
for (const auto task_id : c10::irange(begin, end)) {
// The two branches differ only in the requantization granularity
// (TENSOR vs OUT_CHANNEL).
if (q_scheme == c10::kPerTensorAffine) {
fbgemm::ReQuantizeOutput<
kReluFused,
fbgemm::QuantizationGranularity::TENSOR,
float>
output_proc_obj(
kNoOpObj,
output_multiplier_float.data(),
output_zero_point,
act_zero_point,
w_zp.data(),
nullptr, /* row offset buffer */
col_offsets.data(),
bias_data,
M,
groups_,
act_times_w_scale.data());
fbgemm::fbgemmConv<decltype(output_proc_obj), kSpatialDim, int32_t>(
conv_p,
act_data,
*pack_w,
reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
buffer.data_ptr<int32_t>(),
output_proc_obj,
task_id /* thread_id*/,
num_tasks /* num_threads */);
} else if (q_scheme == c10::kPerChannelAffine) {
fbgemm::ReQuantizeOutput<
kReluFused,
fbgemm::QuantizationGranularity::OUT_CHANNEL,
float>
output_proc_obj(
kNoOpObj,
output_multiplier_float.data(),
output_zero_point,
act_zero_point,
w_zp.data(),
nullptr, /* row offset buffer */
col_offsets.data(),
bias_data,
M,
groups_,
act_times_w_scale.data());
fbgemm::fbgemmConv<decltype(output_proc_obj), kSpatialDim, int32_t>(
conv_p,
act_data,
*pack_w,
reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
buffer.data_ptr<int32_t>(),
output_proc_obj,
task_id /* thread_id*/,
num_tasks /* num_threads */);
}
}
});
return output;
}
// Explicit instantiations of the FBGEMM entry points, grouped by spatial
// dimensionality. apply_impl<true> has no explicit line here; it is
// instantiated implicitly through apply_relu.

// 2-D
template at::Tensor PackedConvWeight<2>::apply(
    const at::Tensor& act, double output_scale, int64_t output_zero_point);
template at::Tensor PackedConvWeight<2>::apply_relu(
    const at::Tensor& act, double output_scale, int64_t output_zero_point);
template at::Tensor PackedConvWeight<2>::apply_impl<false>(
    const at::Tensor& act, double output_scale, int64_t output_zero_point);

// 3-D
template at::Tensor PackedConvWeight<3>::apply(
    const at::Tensor& act, double output_scale, int64_t output_zero_point);
template at::Tensor PackedConvWeight<3>::apply_relu(
    const at::Tensor& act, double output_scale, int64_t output_zero_point);
template at::Tensor PackedConvWeight<3>::apply_impl<false>(
    const at::Tensor& act, double output_scale, int64_t output_zero_point);
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
#ifdef USE_XNNPACK
/// XNNPACK path for quantized (transposed) 2-D convolution.
///
/// Template parameters:
///   scalar_t   - quantized activation type; `scalar_t::underlying` is the raw
///                integer type passed to XNNPACK.
///   kReluFused - when true, the output is clamped to the ReLU limits from
///                activationLimits; otherwise to the full range of
///                underlying_t.
///
/// The XNNPACK operator (with re-massaged weights and requantized bias) is
/// created on first use and recreated whenever the activation scale differs
/// from the cached `input_scale`. The whole call runs under qnnp_mutex_.
///
/// Raises (TORCH_CHECK) for 3-D convolutions, for per-channel weights with a
/// c10::quint8 activation, on rank mismatches, on non-positive output
/// dimensions (non-empty input), or on any non-success XNNPACK status.
template <int kSpatialDim>
template <typename scalar_t, bool kReluFused>
at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl_xnnp(
    const at::Tensor& act, double output_scale, int64_t output_zero_point) {
  using underlying_t = typename scalar_t::underlying;

  std::lock_guard<std::mutex> lock(qnnp_mutex_);

  const std::string func_name = transpose()
      ? "quantized::conv_transpose (xnnpack)"
      : "quantized::conv (xnnpack)";

  TORCH_CHECK(
      kSpatialDim == 2,
      func_name, ": xnnpack does not currently support 3d convolution.");

  /*
   * NB:
   * [de]conv_prepack prepares weights (values, scale, and zero_points) ahead of
   * time during prepack() call assuming the activation will be uint8_t. But it
   * may not always be the case. A solution may involve making prepack routine
   * aware of the input qdtype. But currently all the pieces are not ready to
   * pass that model level info to the prepack function. So, for now, here in
   * this function we have to massage weights if we learn the input qdtype is
   * not uint8_t. This involves copying and converting uint8_t to int8_t
   * whenever necessary. To add to that, since XNNPACK, as of writing this,
   * doesn't support per_channel weights for quint8_t, the following assert
   * makes sure we don't run into that case. We also take shortcuts when
   * processing weights, which means we have to revisit and fix some
   * weight-massaging logic when we enable the missing feature in XNNPACK.
   *
   * Table below summarizes how the weights are handled,
   *
   * .-------------------------------------------------------------------------.
   * | input_qdtype |              uint8_t            |            int8_t      |
   * | per_channel  |       yes       |       no      |      yes     |    no   |
   * |-------------------------------------------------------------------------|
   * | zero_points  | at::zeros()*    | orig_zp + 128 | at:zeros()** | orig_zp |
   * | scale        |            dtype = float, no changes needed              |
   * | values       |        always processed before passing to XNNPACK        |
   * .-------------------------------------------------------------------------.
   *
   * Notes: * - zero_points for uint8_t + per_channel: no support in xnnpack, need
   * to fix when support is added. ** - zero_points for int8_t: symmetric
   * quantization means XNNPACK will ignore kernel zero point(s).
   */

  // `underlying_t` is the raw integer type (compared against uint8_t further
  // below), so it can never be the c10::quint8 wrapper itself. Compare the
  // wrapper type `scalar_t`; comparing `underlying_t` made this guard dead.
  if constexpr (std::is_same_v<scalar_t, c10::quint8>) {
    TORCH_CHECK(!per_channel(),
      func_name, ": xnnpack does not currently have per_channel support with activation dtype of c10::quint8."
    );
  }

  // More checks
  ConvDimChecks<kSpatialDim>(
      act.ndimension(),
      stride().size(),
      padding().size(),
      output_padding().size(),
      dilation().size(),
      func_name,
      transpose());

  const int64_t N = act.size(0);
  const int64_t H = act.size(2);
  const int64_t W = act.size(3);
  const int64_t D = 1;
  // The output channel count is taken from the length of the float bias.
  const int64_t M = bias.size(0);

  const auto act_nhwc = act.contiguous(c10::MemoryFormat::ChannelsLast);
  const auto act_input_scale = act_nhwc.q_scale();

  auto status = xnn_status_invalid_state;

  // Create an operator iff necessary: none is cached yet, or the activation
  // scale changed (the requantized bias below depends on it).
  if (!xnnp_convolution_op ||
      (!input_scale.has_value() || input_scale.value() != act_input_scale)) {
    xnn_operator_t xnnp_op = nullptr;

    // Update the input scale so we may cache the op
    input_scale = act_input_scale;

    // create an empty tensor for packing the weights
    const at::Tensor weight_contig =
        orig_weight.contiguous(c10::MemoryFormat::ChannelsLast);
    const float* w_scales_data = w_scales.const_data_ptr<float>();
    underlying_t w_zp = 0;
    at::Tensor weight_tensor;

    if (!per_channel()) {
      // For a uint8_t underlying type the zero point is shifted by 128
      // (see the table above).
      w_zp = static_cast<underlying_t>(
          weight_contig.q_zero_point() +
          (std::is_same_v<underlying_t, uint8_t> ? 128 : 0));

      weight_tensor = at::native::empty_affine_quantized(
          weight_contig.sizes(),
          c10::CppTypeToScalarType<scalar_t>::value,
          std::nullopt /* layout */,
          c10::kCPU,
          std::nullopt /* pin_memory */,
          w_scales_data[0],
          w_zp,
          c10::MemoryFormat::ChannelsLast);
    } else { /* per_channel */
      weight_tensor = at::native::empty_per_channel_affine_quantized(
          weight_contig.sizes(),
          w_scales,
          at::zeros(w_scales.sizes(), at::kInt), /* see comment above about w_zp */
          weight_contig.q_per_channel_axis(),
          c10::CppTypeToScalarType<scalar_t>::value,
          std::nullopt /* layout */,
          c10::kCPU,
          std::nullopt /* pin_memory */,
          c10::MemoryFormat::ChannelsLast);
    }

    // copy from the original weight and take care of dtype change if necessary
    at::native::xnnp_utils::q8_copy_int8_weight_and_add_offset<scalar_t>(
        weight_contig, weight_tensor);
    const at::Tensor xnnp_weight =
        at::native::xnnp_utils::convert_conv_weights_to_channel_last_tensor<
            kSpatialDim>(weight_tensor, groups(), transpose());

    auto output_min = kReluFused
        // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        ? activationLimits<underlying_t>(output_scale, output_zero_point, Activation::RELU).first
        : std::numeric_limits<underlying_t>::min();
    auto output_max = kReluFused
        // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        ? activationLimits<underlying_t>(output_scale, output_zero_point, Activation::RELU).second
        : std::numeric_limits<underlying_t>::max();

    // Original bias was float, so we requantize it here.
    at::Tensor qbias = quant_utils::QuantizeBias(per_channel(), bias, weight_contig, act_input_scale);

    status = at::native::xnnp_utils::xnnp_create_convolution2d_nhwc(
        padding()[0],
        padding()[1],
        padding()[0],
        padding()[1],
        kernel_[0],
        kernel_[1],
        stride()[0],
        stride()[1],
        dilation()[0],
        dilation()[1],
        groups(),
        !transpose() ? orig_weight.size(1) : orig_weight.size(0) / groups(),
        !transpose() ? orig_weight.size(0) / groups() : orig_weight.size(1),
        !transpose() ? orig_weight.size(1) * groups() : orig_weight.size(0),
        !transpose() ? orig_weight.size(0) : orig_weight.size(1) * groups(),
        act_nhwc.q_zero_point(),
        act_input_scale,
        w_zp, /* will be ignored for Q[SC]8, see comment
                 above about w_zp*/
        w_scales_data,
        reinterpret_cast<const underlying_t*>(
            xnnp_weight.template data_ptr<scalar_t>()),
        reinterpret_cast<int32_t*>(qbias.template data_ptr<c10::qint32>()),
        output_zero_point,
        output_scale,
        output_min,
        output_max,
        0,
        &xnnp_op,
        per_channel(),
        transpose());

    xnnp_convolution_op = xnnpack_operator(xnnp_op);
    TORCH_CHECK(
        status == xnn_status_success,
        func_name,
        ": xnn create operator failed(",
        status,
        ")");
  }

  at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
  const auto input_shape = MakeInputShape<kSpatialDim>(D, H, W);
  if (transpose()) {
    output_shape = MakeDeConvOutputShape<kSpatialDim>(
        N, M, {H, W}, kernel_, stride(), padding(), output_padding(), dilation());
  } else {
    output_shape = at::native::quantized::MakeConvOutputShape<kSpatialDim>(
        N, M, input_shape, kernel_, stride(), padding(), dilation());
  }

  // Empty inputs skip the positivity check on the output dimensions.
  if (act_nhwc.numel() > 0) {
    TORCH_CHECK(
        std::all_of(
            output_shape.begin(),
            output_shape.end(),
            [](int64_t i) { return i > 0; }),
        func_name, ": ", kSpatialDim, "d (xnnpack): each dimension of output tensor should be greater than 0.");
  }

  // Allocate output Tensor and a buffer for XNNPACK to use
  at::Tensor output = at::native::empty_affine_quantized(
      output_shape,
      c10::CppTypeToScalarType<scalar_t>::value,
      std::nullopt /* layout */,
      c10::kCPU,
      std::nullopt /* pin_memory */,
      output_scale,
      output_zero_point,
      c10::MemoryFormat::ChannelsLast);

  // Reshape the operator
  status = at::native::xnnp_utils::xnnp_reshape_convolution2d_nhwc(
      xnnp_convolution_op.get(),
      N,
      H,
      W,
      caffe2::pthreadpool_(),
      per_channel(),
      transpose(),
      output_padding()[0],
      output_padding()[1]);

  // Report the step that actually failed (this is the reshape, not setup).
  TORCH_CHECK(
      status == xnn_status_success,
      func_name,
      ": xnn reshape operator failed(",
      status,
      ")");

  // Setup the operator
  status = at::native::xnnp_utils::xnnp_setup_convolution2d_nhwc(
      xnnp_convolution_op.get(),
      reinterpret_cast<const underlying_t*>(act_nhwc.template data_ptr<scalar_t>()),
      reinterpret_cast<underlying_t*>(output.template data_ptr<scalar_t>()),
      per_channel(),
      transpose());

  TORCH_CHECK(
      status == xnn_status_success,
      func_name,
      ": xnn setup operator failed(",
      status,
      ")");

  // Run the operator
  status = xnn_run_operator(
      xnnp_convolution_op.get(), /* xnn_operator_t op */
      caffe2::pthreadpool_()); /* pthreadpool_t threadpool */

  TORCH_CHECK(
      status == xnn_status_success,
      func_name,
      ": xnn run operator failed(",
      status,
      ")");

  return output;
}
#endif // USE_XNNPACK
template <int kSpatialDim>
template <bool kReluFused>
at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
const at::Tensor& act,
double output_scale,
int64_t output_zero_point) {
// QNNPack is not thread safe
std::lock_guard<std::mutex> lock(qnnp_mutex_);
const std::string func_name = transpose() ? "quantized::conv_transpose"
: "quantized::conv";
TORCH_CHECK(!(kReluFused && transpose()),
kSpatialDim == 2,
func_name, kSpatialDim,
"d (qnnpack): ConvTranspose cannot be fused with ReLU.");
TORCH_CHECK(act.scalar_type() == c10::kQUInt8,
func_name,
"(qnnpack): Expected activation data type ",
toString(c10::kQUInt8),
" but got ",
toString(act.scalar_type()));
ConvDimChecks<kSpatialDim>(
act.ndimension(), stride().size(), padding().size(),
output_padding().size(), dilation().size(), func_name, transpose());
auto* pack_w = w.get();
// TODO Can be replaced with packB->getOutputChannels() when update pre-pack
// to actually do the packing.
const int out_ch_idx = transpose() ? 1 : 0;
const auto out_ch = bias.size(0);
// inputs are in semantic NCHW format
const int N = act.size(0);
const int C = act.size(1);
const int D = kSpatialDim == 3 ? act.size(2) : 1;
const int H = act.size(kSpatialDim);
const int W = act.size(kSpatialDim + 1);
const int M = out_ch; // output channels
const auto channels_last = kSpatialDim == 2
? c10::MemoryFormat::ChannelsLast
: c10::MemoryFormat::ChannelsLast3d;
const at::Tensor act_ndhwc = act.contiguous(channels_last);
auto output_min = kReluFused
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
? activationLimits<uint8_t>(output_scale, output_zero_point, Activation::RELU)
.first
: std::numeric_limits<uint8_t>::min();
auto output_max = kReluFused
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
? activationLimits<uint8_t>(output_scale, output_zero_point, Activation::RELU)
.second
: std::numeric_limits<uint8_t>::max();
double act_input_scale = act_ndhwc.q_scale();
// Re-quantizing the bias based on input scale and weight scale.
if (!input_scale.has_value() || input_scale.value() != act_input_scale) {
TORCH_CHECK(M == (transpose() ? groups() : 1) * orig_weight.size(out_ch_idx),
"Output channel size of weight and bias must match.");
TORCH_CHECK(C == (transpose() ? 1 : groups()) * orig_weight.size(1 - out_ch_idx),
"Input channel size of weight and bias must match.");
// Get the original weight and adjust it to uint8 from int8
auto weight_contig = orig_weight.contiguous(channels_last);
auto bias_fp32 = bias;
int8_t* w_data =
reinterpret_cast<int8_t*>(weight_contig.template data_ptr<c10::qint8>());
float* weight_scales_data = w_scales.data_ptr<float>();
// We calculate requant scale here as the vector holding the requant scale
// is owned by this module. The pointer is then passed to qnnpack backend.
generate_requantization_scales(
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
w_scales, act_input_scale, output_scale, requantization_scales);
// TODO Kimish, we are allocating affine_quantized regardless of per channel or not.
// This allocation is actually used only for packing weight and thus will be freed.
// Still we should be consistent. Fix this.
at::Tensor qnnp_weight = at::_empty_affine_quantized(
weight_contig.sizes(),
at::device(c10::kCPU).dtype(c10::kQUInt8).memory_format(channels_last),
weight_scales_data[0],
w_zero_points[0],
std::nullopt);
auto* qnnp_w_data = qnnp_weight.template data_ptr<c10::quint8>();
auto wt_numel = weight_contig.numel();
for (const auto i : c10::irange(wt_numel)) {
qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128);
}
// Original bias was float, so we requantize it here.
at::Tensor qbias = quant_utils::QuantizeBias(convolution_op->per_channel, bias_fp32, weight_contig, act_input_scale);
// Update the input scale to not pack again.
input_scale = act_input_scale;
w.reset();
w = std::make_unique<qnnpack::PrePackConvWeights>(
convolution_op.get(),
w_zero_points.data(),
reinterpret_cast<uint8_t*>(qnnp_w_data),
reinterpret_cast<int32_t*>(qbias.template data_ptr<c10::qint32>()));
pack_w = w.get();
if (at::globalContext().releaseWeightsWhenPrepacking()) {
// On mobile, we release the original weight by resetting the intrusive_ptr.
// Calling unpack after this will throw an assertion.
orig_weight.reset();
}
// Set padding buffer to zero point. This can only be done if we want
// to do it only once.
if (zero_buffer_size) {
memset(
convolution_op->zero_buffer,
act_ndhwc.q_zero_point(),
zero_buffer_size);
}
}
TORCH_INTERNAL_ASSERT(pack_w != nullptr, "Packed Weights are NULL");
at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
const auto input_shape = MakeInputShape<kSpatialDim>(D, H, W);
if (transpose()) {
output_shape = MakeDeConvOutputShape<kSpatialDim>(
N,
M,
kSpatialDim == 2 ? std::vector<int64_t>{H, W} : std::vector<int64_t>{D, H, W},
kernel_,
stride(),
padding(),
output_padding(),
dilation());
} else {
output_shape = at::native::quantized::MakeConvOutputShape<kSpatialDim>(
N, M, input_shape, kernel_, stride(), padding(), dilation());
}
if (act_ndhwc.numel() > 0) {
TORCH_CHECK(
std::all_of(
output_shape.begin(),
output_shape.end(),
[](int64_t i) { return i > 0; }),
func_name,
kSpatialDim,
"d (qnnpack): each dimension of output tensor should "
"be greater than 0.")
}
// Allocate output Tensor and a buffer for QNNPACK to use
at::Tensor output = at::native::empty_affine_quantized(
output_shape,
c10::kQUInt8,
std::nullopt /* layout */,
c10::kCPU,
std::nullopt /* pin_memory */,
output_scale,
output_zero_point,
channels_last);
pytorch_qnnp_status run_status{};
if (transpose()) {
run_status = qnnpack::qnnpackDeConv(
convolution_op.get(),
pack_w->getPackedWeights(),
N,
H,
W,
act_ndhwc.q_zero_point(),
reinterpret_cast<uint8_t*>(act_ndhwc.template data_ptr<c10::quint8>()),
w_zero_points.data(),
requantization_scales.data(),
output.q_zero_point(),
output_min,
output_max,
reinterpret_cast<uint8_t*>(output.template data_ptr<c10::quint8>()),
caffe2::pthreadpool_());