
Commit 1e958e5

Using acc_t for log_softmax

ghstack-source-id: dfc94f0
Pull Request resolved: #143896

1 parent: 99dd846

2 files changed, +26 -10 lines
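
Background for the change: `at::opmath_type<scalar_t>` is ATen's accumulation ("op math") type, which resolves to `float` for reduced-precision dtypes such as `at::BFloat16` and `at::Half`. The sketch below (illustrative Python, not part of the commit; the 4096-element size is arbitrary) shows why accumulating a softmax denominator in bfloat16 itself loses mass, which is what switching the intermediate buffers to `acc_t` avoids:

import torch

# bf16 keeps ~8 mantissa bits, so once the running sum reaches a few
# hundred, adding terms of size ~0.05 rounds to a no-op; a float32
# accumulator (acc_t for bf16) keeps ~24 bits.
torch.manual_seed(0)
x = torch.randn(4096)
terms = (x - x.max()).exp()          # the softmax denominator's summands

acc_bf16 = torch.tensor(0.0, dtype=torch.bfloat16)
acc_fp32 = torch.tensor(0.0, dtype=torch.float32)
for t in terms:
    acc_bf16 = acc_bf16 + t.to(torch.bfloat16)   # rounds after every add
    acc_fp32 = acc_fp32 + t                      # fp32 accumulator (acc_t)
print(acc_bf16.item(), acc_fp32.item())          # the bf16 sum falls well short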

aten/src/ATen/native/cpu/SoftMaxKernel.cpp (12 additions & 10 deletions)

@@ -37,6 +37,8 @@ inline void _vec_log_softmax_lastdim(
     int64_t outer_size,
     int64_t dim_size) {
   using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
+  using acc_t = at::opmath_type<scalar_t>;
+  using AccVec = vec::Vectorized<vec::vec_scalar_t<acc_t>>;
   // Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the
   // size of L1D cache on many processors. Some processors have 48 KB L1D cache
   // nowadays, so maybe in the future, we can leverage the knowledge of a
@@ -54,8 +56,8 @@ inline void _vec_log_softmax_lastdim(
   parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
     // MSVC requires such a declaration of dynamic arrays
     // Source: https://stackoverflow.com/a/33423538
-    auto tmp_sum_scalar = std::make_unique<scalar_t[]>(CHUNK_SIZE);
-    auto max_input_arr = std::make_unique<scalar_t[]>(CHUNK_SIZE);
+    auto tmp_sum_scalar = std::make_unique<acc_t[]>(CHUNK_SIZE);
+    auto max_input_arr = std::make_unique<acc_t[]>(CHUNK_SIZE);
     for (int64_t ii = begin; ii < end; ii += CHUNK_SIZE) {
       int64_t loop_end = CHUNK_SIZE;
       if (ii + CHUNK_SIZE > end)
@@ -71,35 +73,35 @@ inline void _vec_log_softmax_lastdim(
       for (const auto j : c10::irange(loop_end)) {
         int64_t i = ii + j;
         const scalar_t* input_data = input_data_base + i * dim_size;
-        scalar_t max_input = max_input_arr[j];
+        acc_t max_input = max_input_arr[j];
         tmp_sum_scalar[j] = vec::map_reduce_all<scalar_t>(
-            [max_input](Vec x) { return (x - Vec(max_input)).exp(); },
-            [](Vec x, Vec y) { return x + y; },
+            [max_input](AccVec x) { return (x - AccVec(max_input)).exp(); },
+            [](AccVec x, AccVec y) { return x + y; },
             input_data,
             dim_size);
       }
       // See [Note AVX-SSE transitions] for why this should call the
       // vectorized version (aside from perf improvements).
       vec::map(
-          [](Vec x) { return x.log(); },
+          [](AccVec x) { return x.log(); },
           tmp_sum_scalar.get(),
           tmp_sum_scalar.get(),
           loop_end);
       for (const auto j : c10::irange(loop_end)) {
         int64_t i = ii + j;
         const scalar_t* input_data = input_data_base + i * dim_size;
         scalar_t* output_data = output_data_base + i * dim_size;
-        scalar_t tmp_sum = tmp_sum_scalar[j];
-        scalar_t max_input = max_input_arr[j];
+        acc_t tmp_sum = tmp_sum_scalar[j];
+        acc_t max_input = max_input_arr[j];

         // It's necessary to keep the order of the operations below.
         // In some cases that input is large digits and the difference
         // is small, if we compute `max_input` plus `tmp_sum` before,
         // there would be a numerical problem. See an example in
         // https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
         vec::map(
-            [tmp_sum, max_input](Vec x) {
-              return x - Vec(max_input) - Vec(tmp_sum);
+            [tmp_sum, max_input](AccVec x) {
+              return x - AccVec(max_input) - AccVec(tmp_sum);
             },
             output_data,
             input_data,
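
One subtlety the diff preserves: the final map computes `x - max_input - tmp_sum` strictly left to right. A hedged worked example of the failure mode the context comment warns about (made-up values, fp32 for clarity, in the spirit of the linked issue #11752):

import torch

# At 1e7 the fp32 ulp is 1.0, so max_x + log_sum rounds the ~0.31 shift
# away entirely; (x - max_x) - log_sum forms the small, exact difference
# x - max_x first and only then subtracts the shift.
x = torch.tensor([1e7, 1e7 + 1.0], dtype=torch.float32)
max_x = x.max()
log_sum = (x - max_x).exp().sum().log()   # log-sum-exp shift, ~0.3133

good = x - max_x - log_sum                # -> [-1.3133, -0.3133]
bad = x - (max_x + log_sum)               # -> [-1.0, 0.0]: log_sum vanished
print(good, bad)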

test/test_nn.py (14 additions & 0 deletions)

@@ -13126,6 +13126,20 @@ def test_channel_shuffle(self, device):
         y = y.contiguous(memory_format=torch.contiguous_format)
         self.assertEqual(y, y_ref)

+    @onlyCUDA
+    def test_log_softmax_acc(self):
+        # log_softmax will use high precision as accumulate type for CPU,
+        # which is same as CUDA behavior. Fix https://github.com/pytorch/pytorch/issues/140222
+        input = torch.randn(2, 2, 2, 2, 2, dtype=torch.bfloat16)
+        input_fp32 = input.clone().to(torch.float32)
+        res = F.log_softmax(input, dim=1)
+        res_fp32 = F.log_softmax(input_fp32, dim=1)
+        self.assertEqual(res, res_fp32.to(torch.bfloat16))
+
+        input_cuda = input.clone().cuda()
+        res_cuda = F.log_softmax(input_cuda, dim=1)
+        self.assertEqual(res, res_cuda)
+

 class TestFunctionalPickle(TestCase):
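
For reference, a standalone sketch of what `test_log_softmax_acc` exercises (hypothetical script form, not part of the commit). The real test carries `@onlyCUDA` because its final assertion compares against a CUDA result; the fp32-reference half runs on CPU alone:

import torch
import torch.nn.functional as F

x = torch.randn(2, 2, 2, 2, 2, dtype=torch.bfloat16)
res_cpu = F.log_softmax(x, dim=1)

# With fp32 accumulation, the bf16 CPU result should match an fp32
# reference rounded back to bf16 (within default bf16 tolerances).
ref = F.log_softmax(x.to(torch.float32), dim=1).to(torch.bfloat16)
torch.testing.assert_close(res_cpu, ref)

# Cross-device parity, guarded since this sketch may run without CUDA.
if torch.cuda.is_available():
    res_cuda = F.log_softmax(x.cuda(), dim=1).cpu()
    torch.testing.assert_close(res_cpu, res_cuda)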
