
Commit 33a1996

yanbing-j authored and pytorchmergebot committed
Fix perf downgrade by reverting template use in use_mkldnn_matmul (#159024)

This PR fixes the performance downgrade by reverting the template use introduced in `use_mkldnn_matmul` by #157520.

Fixes #159031 and #159551.

Pull Request resolved: #159024
Approved by: https://github.com/mingfeima
1 parent ee62177 commit 33a1996
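
For context on the downgrade: as the removed lines in the diff below show, the previous `use_mkldnn_matmul` entry point could never return true. Its early-exit condition `mat1_type != kBFloat16 || mat1_type != kHalf || mat1_type != kFloat` holds for every dtype, and even past that, the value produced by the `AT_DISPATCH_FLOATING_TYPES_AND2` lambda was discarded before the unconditional `return false;`. The short standalone sketch below (hypothetical names, not the ATen code) illustrates the first point:

// Standalone sketch (hypothetical ScalarType enum, not ATen): for any single
// dtype value, at least two of the three inequalities are true, so the
// "unsupported dtype" early exit fires for every input.
#include <iostream>

enum class ScalarType { Float, BFloat16, Half, Int };

bool rejected_by_old_check(ScalarType t) {
  // Mirrors the removed condition:
  //   mat1_type != kBFloat16 || mat1_type != kHalf || mat1_type != kFloat
  return t != ScalarType::BFloat16 || t != ScalarType::Half ||
      t != ScalarType::Float;
}

int main() {
  std::cout << std::boolalpha
            << rejected_by_old_check(ScalarType::BFloat16) << ' '  // true
            << rejected_by_old_check(ScalarType::Float) << '\n';   // true
  // Every dtype is rejected before the oneDNN checks are reached, which is
  // consistent with the reported performance downgrade.
}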

File tree

1 file changed: 56 additions, 39 deletions


aten/src/ATen/native/mkldnn/Matmul.cpp

Lines changed: 56 additions & 39 deletions
@@ -1,7 +1,6 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/Config.h>
 #include <ATen/Context.h>
-#include <ATen/Dispatch.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/native/mkldnn/Matmul.h>
 
@@ -428,56 +427,74 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
   }
 }
 
-template <typename T>
-bool use_mkldnn_typed_matmul(
+bool use_mkldnn_bf16_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
     const Tensor& result) {
-  bool dtype_check = false;
-  if constexpr (std::is_same_v<T, c10::BFloat16>) {
 #if defined(__aarch64__)
-    if (mkldnn_bf16_device_check_arm()) {
-      // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
-      // Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
-      // inputs, allow it for float as well
-      dtype_check = use_mkldnn_bf16_matmul() &&
-          ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16));
-    }
-#else
-    dtype_check = dtype_check && use_mkldnn_bf16_matmul() &&
-        (mat1.scalar_type() == kBFloat16);
+  if (mkldnn_bf16_device_check_arm()) {
+    // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
+    // Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
+    // inputs, allow it for float as well
+    return (
+        use_mkldnn_bf16_matmul() &&
+        (mat1.scalar_type() == mat2.scalar_type()) &&
+        (!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
+        ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
+        mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
+  } else
 #endif
-  } else if constexpr (std::is_same_v<T, c10::Half>) {
-    dtype_check = dtype_check && use_mkldnn_fp16_matmul() &&
-        (mat1.scalar_type() == kHalf);
-  } else if constexpr (std::is_same_v<T, float>) {
-    dtype_check = dtype_check &&
-        (use_mkldnn_bf32_matmul() || use_mkldnn_tf32_matmul()) &&
-        (mat1.scalar_type() == kFloat);
+  {
+    return (
+        use_mkldnn_bf16_matmul() && mat1.scalar_type() == kBFloat16 &&
+        mat2.scalar_type() == kBFloat16 &&
+        (!result.defined() || result.scalar_type() == kBFloat16) &&
+        mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
   }
-  if (!dtype_check) {
-    return false;
-  }
-  bool size_check =
-      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2);
-  dtype_check = (mat1.scalar_type() == mat2.scalar_type()) &&
-      (!result.defined() || result.scalar_type() == mat1.scalar_type());
-  return dtype_check && size_check;
+}
+
+bool use_mkldnn_fp16_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& result) {
+  return (
+      use_mkldnn_fp16_matmul() && mat1.scalar_type() == kHalf &&
+      mat2.scalar_type() == kHalf &&
+      (!result.defined() || result.scalar_type() == kHalf) &&
+      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
+}
+
+bool use_mkldnn_bf32_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& result) {
+  return (
+      use_mkldnn_bf32_matmul() && mat1.scalar_type() == kFloat &&
+      mat2.scalar_type() == kFloat &&
+      (!result.defined() || result.scalar_type() == kFloat) &&
+      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
+}
+
+bool use_mkldnn_tf32_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& result) {
+  return (
+      use_mkldnn_tf32_matmul() && mat1.scalar_type() == kFloat &&
+      mat2.scalar_type() == kFloat &&
+      (!result.defined() || result.scalar_type() == kFloat) &&
+      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
 }
 
 bool use_mkldnn_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
     const Tensor& result) {
-  auto mat1_type = mat1.scalar_type();
-  if (mat1_type != kBFloat16 || mat1_type != kHalf || mat1_type != kFloat) {
-    return false;
-  }
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-      kBFloat16, kHalf, mat1.scalar_type(), "use_mkldnn_matmul", [&] {
-        return use_mkldnn_typed_matmul<scalar_t>(mat1, mat2, result);
-      });
-  return false;
+  return (
+      use_mkldnn_bf16_matmul(mat1, mat2, result) ||
+      use_mkldnn_fp16_matmul(mat1, mat2, result) ||
+      use_mkldnn_bf32_matmul(mat1, mat2, result) ||
+      use_mkldnn_tf32_matmul(mat1, mat2, result));
 }
 
 static void _mkldnn_matmul_i8i8i32_with_primitive(
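
The restored helpers are simple gate predicates, and `use_mkldnn_matmul` once again short-circuits over them. A minimal standalone mock of that structure (hypothetical FakeTensor/Dtype names; the oneDNN availability checks and the bf32/tf32 gates are omitted for brevity), just to make the gating logic explicit outside of ATen:

// Standalone mock (hypothetical types and stubs, not the ATen code) of the
// restored structure: each per-dtype gate requires matching input and result
// dtypes, non-empty inputs and a checksize()-style size test, and the top-level
// predicate ORs the gates together.
#include <iostream>
#include <optional>

enum class Dtype { Float, BFloat16, Half };

struct FakeTensor {
  Dtype dtype;
  long long numel;
};

// Stand-in for the size check done by checksize() in Matmul.cpp.
bool checksize_ok(const FakeTensor&, const FakeTensor&) { return true; }

// Mirrors the shape of use_mkldnn_bf16_matmul on non-aarch64 builds.
bool bf16_gate(const FakeTensor& mat1, const FakeTensor& mat2,
               const std::optional<FakeTensor>& result) {
  return mat1.dtype == Dtype::BFloat16 && mat2.dtype == Dtype::BFloat16 &&
      (!result.has_value() || result->dtype == Dtype::BFloat16) &&
      mat1.numel != 0 && mat2.numel != 0 && checksize_ok(mat1, mat2);
}

// Mirrors the shape of use_mkldnn_fp16_matmul.
bool fp16_gate(const FakeTensor& mat1, const FakeTensor& mat2,
               const std::optional<FakeTensor>& result) {
  return mat1.dtype == Dtype::Half && mat2.dtype == Dtype::Half &&
      (!result.has_value() || result->dtype == Dtype::Half) &&
      mat1.numel != 0 && mat2.numel != 0 && checksize_ok(mat1, mat2);
}

// Mirrors the restored use_mkldnn_matmul: the first gate that accepts wins.
bool matmul_gate(const FakeTensor& mat1, const FakeTensor& mat2,
                 const std::optional<FakeTensor>& result) {
  return bf16_gate(mat1, mat2, result) || fp16_gate(mat1, mat2, result);
}

int main() {
  FakeTensor a{Dtype::BFloat16, 4096}, b{Dtype::BFloat16, 4096};
  FakeTensor mixed{Dtype::Float, 4096};
  std::cout << std::boolalpha
            << matmul_gate(a, b, std::nullopt) << ' '        // true
            << matmul_gate(a, mixed, std::nullopt) << '\n';  // false: dtype mismatch
}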
