@@ -37,6 +37,8 @@ inline void _vec_log_softmax_lastdim(
     int64_t outer_size,
     int64_t dim_size) {
   using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
+  using acc_t = at::opmath_type<scalar_t>;
+  using AccVec = vec::Vectorized<vec::vec_scalar_t<acc_t>>;
   // Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the
   // size of L1D cache on many processors. Some processors have 48 KB L1D cache
   // nowadays, so maybe in the future, we can leverage the knowledge of a
@@ -54,8 +56,8 @@ inline void _vec_log_softmax_lastdim(
   parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
     // MSVC requires such a declaration of dynamic arrays
     // Source: https://stackoverflow.com/a/33423538
-    auto tmp_sum_scalar = std::make_unique<scalar_t[]>(CHUNK_SIZE);
-    auto max_input_arr = std::make_unique<scalar_t[]>(CHUNK_SIZE);
+    auto tmp_sum_scalar = std::make_unique<acc_t[]>(CHUNK_SIZE);
+    auto max_input_arr = std::make_unique<acc_t[]>(CHUNK_SIZE);
     for (int64_t ii = begin; ii < end; ii += CHUNK_SIZE) {
       int64_t loop_end = CHUNK_SIZE;
       if (ii + CHUNK_SIZE > end)
@@ -71,35 +73,35 @@ inline void _vec_log_softmax_lastdim(
       for (const auto j : c10::irange(loop_end)) {
         int64_t i = ii + j;
         const scalar_t* input_data = input_data_base + i * dim_size;
-        scalar_t max_input = max_input_arr[j];
+        acc_t max_input = max_input_arr[j];
         tmp_sum_scalar[j] = vec::map_reduce_all<scalar_t>(
-            [max_input](Vec x) { return (x - Vec(max_input)).exp(); },
-            [](Vec x, Vec y) { return x + y; },
+            [max_input](AccVec x) { return (x - AccVec(max_input)).exp(); },
+            [](AccVec x, AccVec y) { return x + y; },
             input_data,
             dim_size);
       }
       // See [Note AVX-SSE transitions] for why this should call the
       // vectorized version (aside from perf improvements).
       vec::map(
-          [](Vec x) { return x.log(); },
+          [](AccVec x) { return x.log(); },
           tmp_sum_scalar.get(),
           tmp_sum_scalar.get(),
           loop_end);
       for (const auto j : c10::irange(loop_end)) {
         int64_t i = ii + j;
         const scalar_t* input_data = input_data_base + i * dim_size;
         scalar_t* output_data = output_data_base + i * dim_size;
-        scalar_t tmp_sum = tmp_sum_scalar[j];
-        scalar_t max_input = max_input_arr[j];
+        acc_t tmp_sum = tmp_sum_scalar[j];
+        acc_t max_input = max_input_arr[j];
 
         // It's necessary to keep the order of the operations below.
         // In some cases the input values are large and the differences
         // between them are small; if we added `max_input` to `tmp_sum`
         // first, there would be a numerical problem. See an example in
         // https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
         vec::map(
-            [tmp_sum, max_input](Vec x) {
-              return x - Vec(max_input) - Vec(tmp_sum);
+            [tmp_sum, max_input](AccVec x) {
+              return x - AccVec(max_input) - AccVec(tmp_sum);
             },
             output_data,
             input_data,
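
For context (not part of the commit): the diff switches the per-row max and sum-of-exponentials buffers to at::opmath_type<scalar_t> (typically float when scalar_t is a reduced-precision type such as BFloat16 or Half), while keeping the output as x - max - log(sum) rather than x - (max + log(sum)). Below is a minimal standalone sketch of those two ideas; it uses double as a stand-in for the wider accumulation type, and log_softmax_row is a hypothetical helper, not the kernel's API.

// Standalone sketch (assumptions: plain scalar code; double stands in for the
// wider accumulation type acc_t; log_softmax_row is a hypothetical helper).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<float> log_softmax_row(const std::vector<float>& x) {
  // Reduce in a wider type, mirroring the acc_t buffers in the kernel.
  double max_input = *std::max_element(x.begin(), x.end());
  double tmp_sum = 0.0;
  for (float v : x)
    tmp_sum += std::exp(static_cast<double>(v) - max_input);
  tmp_sum = std::log(tmp_sum);

  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    // Keep the order of operations: subtract max_input first, then tmp_sum,
    // rather than pre-adding them (matches the comment in the kernel above).
    out[i] = static_cast<float>(static_cast<double>(x[i]) - max_input - tmp_sum);
  return out;
}

int main() {
  auto out = log_softmax_row({1000.0f, 1000.0f, 1000.0f});
  std::printf("%f %f %f\n", out[0], out[1], out[2]);  // ~ -1.098612 each
}

The example prints about -1.098612 (i.e. -log 3) for each element; the separate subtractions and the wider accumulator are what keep that value accurate when the inputs are large and close together.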