@@ -1,7 +1,6 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/Config.h>
 #include <ATen/Context.h>
-#include <ATen/Dispatch.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/native/mkldnn/Matmul.h>
 
@@ -428,56 +427,74 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
   }
 }
 
-template <typename T>
-bool use_mkldnn_typed_matmul(
+bool use_mkldnn_bf16_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
     const Tensor& result) {
-  bool dtype_check = false;
-  if constexpr (std::is_same_v<T, c10::BFloat16>) {
 #if defined(__aarch64__)
-    if (mkldnn_bf16_device_check_arm()) {
-      // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
-      // Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
-      // inputs, allow it for float as well
-      dtype_check = use_mkldnn_bf16_matmul() &&
-          ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16));
-    }
-#else
-    dtype_check = dtype_check && use_mkldnn_bf16_matmul() &&
-        (mat1.scalar_type() == kBFloat16);
+  if (mkldnn_bf16_device_check_arm()) {
+    // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
+    // Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
+    // inputs, allow it for float as well
+    return (
+        use_mkldnn_bf16_matmul() &&
+        (mat1.scalar_type() == mat2.scalar_type()) &&
+        (!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
+        ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
+        mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
+  } else
 #endif
-  } else if constexpr (std::is_same_v<T, c10::Half>) {
-    dtype_check = dtype_check && use_mkldnn_fp16_matmul() &&
-        (mat1.scalar_type() == kHalf);
-  } else if constexpr (std::is_same_v<T, float>) {
-    dtype_check = dtype_check &&
-        (use_mkldnn_bf32_matmul() || use_mkldnn_tf32_matmul()) &&
-        (mat1.scalar_type() == kFloat);
+  {
+    return (
+        use_mkldnn_bf16_matmul() && mat1.scalar_type() == kBFloat16 &&
+        mat2.scalar_type() == kBFloat16 &&
+        (!result.defined() || result.scalar_type() == kBFloat16) &&
+        mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
   }
-  if (!dtype_check) {
-    return false;
-  }
-  bool size_check =
-      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2);
-  dtype_check = (mat1.scalar_type() == mat2.scalar_type()) &&
-      (!result.defined() || result.scalar_type() == mat1.scalar_type());
-  return dtype_check && size_check;
+}
+
+bool use_mkldnn_fp16_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& result) {
+  return (
+      use_mkldnn_fp16_matmul() && mat1.scalar_type() == kHalf &&
+      mat2.scalar_type() == kHalf &&
+      (!result.defined() || result.scalar_type() == kHalf) &&
+      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
+}
+
+bool use_mkldnn_bf32_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& result) {
+  return (
+      use_mkldnn_bf32_matmul() && mat1.scalar_type() == kFloat &&
+      mat2.scalar_type() == kFloat &&
+      (!result.defined() || result.scalar_type() == kFloat) &&
+      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
+}
+
+bool use_mkldnn_tf32_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& result) {
+  return (
+      use_mkldnn_tf32_matmul() && mat1.scalar_type() == kFloat &&
+      mat2.scalar_type() == kFloat &&
+      (!result.defined() || result.scalar_type() == kFloat) &&
+      mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
 }
 
 bool use_mkldnn_matmul(
     const Tensor& mat1,
     const Tensor& mat2,
     const Tensor& result) {
-  auto mat1_type = mat1.scalar_type();
-  if (mat1_type != kBFloat16 || mat1_type != kHalf || mat1_type != kFloat) {
-    return false;
-  }
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-      kBFloat16, kHalf, mat1.scalar_type(), "use_mkldnn_matmul", [&] {
-        return use_mkldnn_typed_matmul<scalar_t>(mat1, mat2, result);
-      });
-  return false;
+  return (
+      use_mkldnn_bf16_matmul(mat1, mat2, result) ||
+      use_mkldnn_fp16_matmul(mat1, mat2, result) ||
+      use_mkldnn_bf32_matmul(mat1, mat2, result) ||
+      use_mkldnn_tf32_matmul(mat1, mat2, result));
 }
 
 static void _mkldnn_matmul_i8i8i32_with_primitive(
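For orientation (not part of this diff): callers consult the combined check before taking the oneDNN fast path. A minimal sketch follows, assuming the mkldnn_matmul(mat1, mat2, result, beta, alpha) entry point declared in ATen/native/mkldnn/Matmul.h; the wrapper name gemm_with_mkldnn_fallback is hypothetical.

// Sketch only: gemm_with_mkldnn_fallback is a hypothetical helper, not part of
// this change. It shows how use_mkldnn_matmul() would gate the oneDNN path.
#include <ATen/core/Tensor.h>
#include <ATen/native/mkldnn/Matmul.h>

static bool gemm_with_mkldnn_fallback(
    const at::Tensor& mat1,
    const at::Tensor& mat2,
    const at::Tensor& result,
    float beta,
    float alpha) {
  if (at::native::use_mkldnn_matmul(mat1, mat2, result)) {
    // bf16/fp16 inputs, or fp32 inputs with bf32/tf32 (or the aarch64
    // fastmath mode) enabled, are handled by the oneDNN kernel.
    at::native::mkldnn_matmul(mat1, mat2, result, beta, alpha);
    return true;
  }
  // Otherwise the caller falls back to the default BLAS GEMM path.
  return false;
}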