From f18af1d82502df724e03c3da6f600825c714e17c Mon Sep 17 00:00:00 2001
From: Aidyn-A
Date: Fri, 6 Jun 2025 22:36:18 +0400
Subject: [PATCH 1/8] sparse add and addmm

---
 CMakeLists.txt                                |   1 +
 aten/src/ATen/CMakeLists.txt                  |   5 +
 aten/src/ATen/Config.h.in                     |   1 +
 aten/src/ATen/Context.cpp                     |   8 +
 aten/src/ATen/Context.h                       |   5 +
 .../src/ATen/native/sparse/SparseBlasImpl.cpp |  19 +-
 .../native/sparse/SparseCsrTensorMath.cpp     |  18 +-
 .../native/sparse/eigen/SparseBlasImpl.cpp    | 432 ++++++++++++++++++
 .../ATen/native/sparse/eigen/SparseBlasImpl.h |  29 ++
 cmake/Dependencies.cmake                      |  10 +
 cmake/Summary.cmake                           |   1 +
 torch/csrc/Module.cpp                         |   2 +
 12 files changed, 523 insertions(+), 8 deletions(-)
 create mode 100644 aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
 create mode 100644 aten/src/ATen/native/sparse/eigen/SparseBlasImpl.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index cc9476bb001ae..5a2ad89d77eb8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -289,6 +289,7 @@ option(USE_PRECOMPILED_HEADERS "Use pre-compiled headers to accelerate build."
 option(USE_PROF "Use profiling" OFF)
 option(USE_PYTORCH_QNNPACK "Use ATen/QNNPACK (quantized 8-bit operators)" ON)
 option(USE_SNPE "Use Qualcomm's SNPE library" OFF)
+option(USE_EIGEN_SPARSE "Use Eigen Sparse Matrices" OFF)
 option(USE_SYSTEM_EIGEN_INSTALL "Use system Eigen instead of the one under third_party" OFF)
 cmake_dependent_option(
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index 5f4997357f826..0f083a582404c 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -96,6 +96,8 @@ file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
 file(GLOB vulkan_cpp "vulkan/*.cpp")
 file(GLOB native_vulkan_cpp "native/vulkan/*.cpp" "native/vulkan/api/*.cpp" "native/vulkan/impl/*.cpp" "native/vulkan/ops/*.cpp")
 
+file(GLOB native_eigen_cpp "native/sparse/eigen/*.cpp")
+
 # Metal
 file(GLOB metal_h "metal/*.h")
 file(GLOB metal_cpp "metal/*.cpp")
@@ -341,6 +343,9 @@ if(USE_VULKAN)
 else()
   set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp})
 endif()
+if(USE_EIGEN_SPARSE)
+  set(all_cpu_cpp ${all_cpu_cpp} ${native_eigen_cpp})
+endif()
 if(USE_MTIA)
   set(ATen_MTIA_SRCS ${ATen_MTIA_SRCS} ${mtia_cpp} ${mtia_h} ${native_mtia_cpp} ${native_mtia_h})
diff --git a/aten/src/ATen/Config.h.in b/aten/src/ATen/Config.h.in
index c22e15a52aa23..0bae6d4af6e5e 100644
--- a/aten/src/ATen/Config.h.in
+++ b/aten/src/ATen/Config.h.in
@@ -20,3 +20,4 @@
 #define AT_BLAS_F2C() @AT_BLAS_F2C@
 #define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@
 #define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@
+#define AT_USE_EIGEN_SPARSE() @AT_USE_EIGEN_SPARSE@
diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index 30c2235131fb6..4d48084b0ab89 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -698,6 +698,14 @@ bool Context::hasLAPACK() {
 #endif
 }
 
+bool Context::hasEigenSparse() {
+#if AT_USE_EIGEN_SPARSE()
+  return true;
+#else
+  return false;
+#endif
+}
+
 at::QEngine Context::qEngine() const {
   static auto _quantized_engine = []() {
     at::QEngine qengine = at::kNoQEngine;
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 2cc12a38a0b6e..5cfa9b23e20aa 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -133,6 +133,7 @@ class TORCH_API Context {
   static bool hasLAPACK();
   static bool hasMKLDNN();
   static bool ckSupported();
+  static bool hasEigenSparse();
   static bool hasMAGMA() {
     return detail::getCUDAHooks().hasMAGMA();
   }
@@ -615,6 +616,10 @@ inline bool hasLAPACK() {
   return globalContext().hasLAPACK();
 }
 
+inline bool hasEigenSparse() {
+  return globalContext().hasEigenSparse();
+}
+
 inline bool hasMAGMA() {
   return globalContext().hasMAGMA();
 }
diff --git a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
index 5a3f5f14dc0a7..c841da8354b5f 100644
--- a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -23,6 +23,9 @@
 #include <ATen/native/mkl/SparseBlasImpl.h>
 #endif
 
+#if AT_USE_EIGEN_SPARSE()
+#include <ATen/native/sparse/eigen/SparseBlasImpl.h>
+#endif
 
 namespace at::native::sparse::impl {
 
@@ -442,13 +445,15 @@ void add_out_sparse_csr(
     const Tensor& mat2,
     const Scalar& alpha,
     const Tensor& result) {
-#if !AT_MKL_ENABLED()
-  TORCH_CHECK(
-      false,
-      "Calling add on a sparse CPU tensor requires compiling PyTorch with MKL. ",
-      "Please use PyTorch built MKL support.");
-#else
+#if AT_USE_MKL_SPARSE()
   sparse::impl::mkl::add_out_sparse_csr(mat1, mat2, alpha, result);
+#elif AT_USE_EIGEN_SPARSE()
+  sparse::impl::eigen::add_out_sparse(mat1, mat2, alpha, result);
+#else
+  TORCH_CHECK(
+      false,
+      "Calling add on a sparse CPU tensor requires compiling PyTorch with MKL. ",
+      "Please use PyTorch built MKL support.");
 #endif
 }
 
@@ -459,7 +464,7 @@ void triangular_solve_out_sparse_csr(
     bool upper,
     bool transpose,
     bool unitriangular) {
-#if !AT_MKL_ENABLED()
+#if !AT_USE_MKL_SPARSE()
   TORCH_CHECK(
       false,
       "Calling triangular_solve on a sparse CPU tensor requires compiling PyTorch with MKL. ",
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index ba94f98551747..6d7044e7dc92a 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -127,6 +127,10 @@
 #include
 #endif
 
+#if AT_USE_EIGEN_SPARSE()
+#include <ATen/native/sparse/eigen/SparseBlasImpl.h>
+#endif
+
 #include
 
 namespace at {
 
@@ -536,7 +540,12 @@ static void addmm_out_sparse_csr_native_cpu(
   auto values = sparse.values();
 
   scalar_t cast_alpha = alpha.to<scalar_t>();
-  r.mul_(beta);
+  // If beta is zero NaN and Inf should not be propagated to the result
+  if (beta.toComplexDouble() == 0.) {
+    r.zero_();
+  } else {
+    r.mul_(beta);
+  }
   AT_DISPATCH_INDEX_TYPES(
       col_indices.scalar_type(), "csr_mm_crow_indices", [&]() {
         auto csr_accessor = csr.accessor<index_t, 1>();
@@ -648,6 +657,13 @@ Tensor& addmm_out_sparse_compressed_cpu(
     return result;
   }
 
+#if AT_USE_EIGEN_SPARSE()
+  if (result.layout() != kStrided && mat1.layout() != kStrided && mat2.layout() != kStrided) {
+    sparse::impl::eigen::addmm_out_sparse(mat1, mat2, result, alpha, beta);
+    return result;
+  }
+#endif
+
 #if !AT_USE_MKL_SPARSE()
   // The custom impl addmm_out_sparse_csr_native_cpu only supports CSR @
   // strided -> strided
diff --git a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
new file mode 100644
index 0000000000000..5cd0ba648ddcb
--- /dev/null
+++ b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
@@ -0,0 +1,432 @@
+#include <ATen/native/sparse/eigen/SparseBlasImpl.h>
+
+#if AT_USE_EIGEN_SPARSE()
+
+#include <iostream>
+
+#include
+#include
+#include
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include
+#endif
+
+#include
+
+#include <Eigen/SparseCore>
+
+namespace at::native::sparse::impl::eigen {
+
+namespace {
+
+void inline sparse_indices_to_result_dtype_inplace(
+    const c10::ScalarType& dtype,
+    const at::Tensor& input) {
+  if (input.layout() == kSparseCsr) {
+    static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
+        ->set_member_tensors(
+            input.crow_indices().to(dtype),
+            input.col_indices().to(dtype),
+            input.values(),
+            input.sizes());
+  } else if (input.layout() == kSparseCsc) {
+    static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
+        ->set_member_tensors(
+            input.ccol_indices().to(dtype),
+            input.row_indices().to(dtype),
+            input.values(),
+            input.sizes());
+  } else {
+    TORCH_CHECK(
+        false,
+        "Eigen: expected tensor be kSparseCsr or kSparseCsc, but got",
+        input.layout());
+  }
+}
+
+void inline sparse_indices_and_values_resize(
+    const at::Tensor& input,
+    int64_t nnz) {
+  if (input.layout() == kSparseCsr) {
+    static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
+        ->set_member_tensors(
+            input.crow_indices(),
+            input.col_indices().resize_({nnz}),
+            input.values().resize_({nnz}),
+            input.sizes());
+  } else if (input.layout() == kSparseCsc) {
+    static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
+        ->set_member_tensors(
+            input.ccol_indices(),
+            input.row_indices().resize_({nnz}),
+            input.values().resize_({nnz}),
+            input.sizes());
+  } else {
+    TORCH_CHECK(
+        false,
+        "Eigen: expected tensor be kSparseCsr or kSparseCsc, but got",
+        input.layout());
+  }
+}
+
+template <typename scalar_t, typename index_t>
+const Eigen::Map<Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>>
+Tensor_to_EigenCsc(const at::Tensor& tensor) {
+  int64_t rows = tensor.size(0);
+  int64_t cols = tensor.size(1);
+  int64_t nnz = tensor._nnz();
+
+  TORCH_CHECK(tensor.values().is_contiguous(), "eigen accepts only contiguous tensors");
+
+  index_t* ccol_indices_ptr = tensor.ccol_indices().data_ptr<index_t>();
+  index_t* row_indices_ptr = tensor.row_indices().data_ptr<index_t>();
+  scalar_t* values_ptr = tensor.values().data_ptr<scalar_t>();
+  Eigen::Map<Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>> map(
+      rows, cols, nnz, ccol_indices_ptr, row_indices_ptr, values_ptr);
+  return map;
+}
+
+template <typename scalar_t, typename index_t>
+const Eigen::Map<Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>>
+Tensor_to_EigenCsr(const at::Tensor& tensor) {
+  int64_t rows = tensor.size(0);
+  int64_t cols = tensor.size(1);
+  int64_t nnz = tensor._nnz();
+
+  TORCH_CHECK(tensor.values().is_contiguous(), "eigen accepts only contiguous tensors");
+
+  index_t* crow_indices_ptr = tensor.crow_indices().data_ptr<index_t>();
+  index_t* col_indices_ptr = tensor.col_indices().data_ptr<index_t>();
+  scalar_t* values_ptr = tensor.values().data_ptr<scalar_t>();
+  Eigen::Map<Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>> map(
+      rows, cols, nnz, crow_indices_ptr, col_indices_ptr, values_ptr);
+  return map;
+}
+
+template <typename scalar_t, typename index_t>
+void EigenCsr_to_Tensor(
+    const at::Tensor& tensor,
+    const Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>& matrix) {
+  TORCH_CHECK(
+      tensor.layout() == kSparseCsr,
+      "EigenCsr_to_Tensor, expected tensor be kSparseCsr, but got",
+      tensor.layout());
+
+  int64_t nnz = matrix.nonZeros();
+  int64_t rows = matrix.outerSize();
+  sparse_indices_and_values_resize(tensor, nnz);
+
+  if (nnz > 0) {
+    std::memcpy(
+        tensor.values().mutable_data_ptr<scalar_t>(),
+        matrix.valuePtr(),
+        nnz * sizeof(scalar_t));
+    std::memcpy(
+        tensor.col_indices().mutable_data_ptr<index_t>(),
+        matrix.innerIndexPtr(),
+        nnz * sizeof(index_t));
+  }
+  if (rows > 0) {
+    std::memcpy(
+        tensor.crow_indices().mutable_data_ptr<index_t>(),
+        matrix.outerIndexPtr(),
+        rows * sizeof(index_t));
+  }
+  tensor.crow_indices().mutable_data_ptr<index_t>()[rows] = nnz;
+}
+
+template <typename scalar_t, typename index_t>
+void EigenCsc_to_Tensor(
+    const at::Tensor& tensor,
+    const Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>& matrix) {
+  TORCH_CHECK(
+      tensor.layout() == kSparseCsc,
+      "EigenCsr_to_Tensor, expected tensor be kSparseCsc, but got",
+      tensor.layout());
+
+  int64_t nnz = matrix.nonZeros();
+  int64_t rows = matrix.outerSize();
+  sparse_indices_and_values_resize(tensor, nnz);
+
+  if (nnz > 0) {
+    std::memcpy(
+        tensor.values().mutable_data_ptr<scalar_t>(),
+        matrix.valuePtr(),
+        nnz * sizeof(scalar_t));
+    std::memcpy(
+        tensor.row_indices().mutable_data_ptr<index_t>(),
+        matrix.innerIndexPtr(),
+        nnz * sizeof(index_t));
+  }
+  if (rows > 0) {
+    std::memcpy(
+        tensor.ccol_indices().mutable_data_ptr<index_t>(),
+        matrix.outerIndexPtr(),
+        rows * sizeof(index_t));
+  }
+  tensor.ccol_indices().mutable_data_ptr<index_t>()[rows] = nnz;
+}
+
+template <typename scalar_t>
+void add_out_sparse_eigen(
+    const at::Tensor& mat1,
+    const at::Tensor& mat2,
+    const at::Scalar& alpha,
+    const at::Tensor& result) {
+  // empty matrices
+  if (mat1._nnz() == 0 && mat2._nnz() == 0) {
+    return;
+  }
+
+  if (mat2._nnz() == 0 || alpha.toComplexDouble() == 0.) {
+    sparse_indices_and_values_resize(result, mat1._nnz());
+    result.copy_(mat1);
+    return;
+  } else if (mat1._nnz() == 0) {
+    sparse_indices_and_values_resize(result, mat2._nnz());
+    result.copy_(mat2);
+    result.values().mul_(alpha);
+    return;
+  }
+
+  c10::ScalarType result_index_dtype;
+
+  if(result.layout() == kSparseCsr) {
+    result_index_dtype = result.col_indices().scalar_type();
+  } else if (result.layout() == kSparseCsc) {
+    result_index_dtype = result.row_indices().scalar_type();
+  }
+
+  sparse_indices_to_result_dtype_inplace(result_index_dtype, mat1);
+  sparse_indices_to_result_dtype_inplace(result_index_dtype, mat2);
+
+  AT_DISPATCH_INDEX_TYPES(
+      result_index_dtype, "eigen_sparse_add", [&]() {
+        scalar_t _alpha = alpha.to<scalar_t>();
+        typedef Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>
+            EigenCscMatrix;
+        typedef Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>
+            EigenCsrMatrix;
+
+        if(result.layout() == kSparseCsr) {
+          const Eigen::Map<EigenCsrMatrix> mat1_eigen =
+              Tensor_to_EigenCsr<scalar_t, index_t>(mat1);
+          const Eigen::Map<EigenCsrMatrix> mat2_eigen =
+              Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
+          const EigenCsrMatrix mat1_mat2_eigen =
+              (mat1_eigen + _alpha * mat2_eigen);
+
+          EigenCsr_to_Tensor<scalar_t, index_t>(result, mat1_mat2_eigen);
+        } else if (mat1.layout() == kSparseCsc) {
+          const Eigen::Map<EigenCscMatrix> mat1_eigen =
+              Tensor_to_EigenCsc<scalar_t, index_t>(mat1);
+          const Eigen::Map<EigenCscMatrix> mat2_eigen =
+              Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
+          const EigenCscMatrix mat1_mat2_eigen =
+              (mat1_eigen + _alpha * mat2_eigen);
+
+          EigenCsc_to_Tensor<scalar_t, index_t>(result, mat1_mat2_eigen);
+        }
+      });
+}
+
+template <typename scalar_t>
+void addmm_out_sparse_eigen(
+    const at::Tensor& mat1,
+    const at::Tensor& mat2,
+    const at::Tensor& result,
+    const at::Scalar& alpha,
+    const at::Scalar& beta) {
+  // empty matrices
+  if (mat1._nnz() == 0 || mat2._nnz() == 0) {
+    return;
+  }
+
+  // If beta is zero NaN and Inf should not be propagated to the result
+  if (beta.toComplexDouble() == 0.) {
+    result.values().zero_();
+  } else {
+    result.values().mul_(beta);
+  }
+
+  c10::ScalarType result_index_dtype;
+
+  if(result.layout() == kSparseCsr) {
+    result_index_dtype = result.col_indices().scalar_type();
+  } else if (result.layout() == kSparseCsc) {
+    result_index_dtype = result.row_indices().scalar_type();
+  }
+
+  sparse_indices_to_result_dtype_inplace(result_index_dtype, mat1);
+  sparse_indices_to_result_dtype_inplace(result_index_dtype, mat2);
+
+  AT_DISPATCH_INDEX_TYPES(
+      result_index_dtype, "eigen_sparse_mm", [&]() {
+        typedef Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>
+            EigenCscMatrix;
+        typedef Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>
+            EigenCsrMatrix;
+
+        if(result.layout() == kSparseCsr) {
+          result.crow_indices().mutable_data_ptr<index_t>()[0] = (index_t) 0;
+          std::cout << result.crow_indices() << std::endl;
+        } else if (result.layout() == kSparseCsc) {
+          result.ccol_indices().mutable_data_ptr<index_t>()[0] = (index_t) 0;
+          std::cout << result.ccol_indices() << std::endl;
+        }
+
+        at::Tensor mat1_mat2 = at::empty_like(result);
+
+        if (mat1_mat2.layout() == kSparseCsr) {
+          if (mat1.layout() == kSparseCsr) {
+            const Eigen::Map<EigenCsrMatrix> mat1_eigen =
+                Tensor_to_EigenCsr<scalar_t, index_t>(mat1);
+            if (mat2.layout() == kSparseCsr) {
+              const Eigen::Map<EigenCsrMatrix> mat2_eigen =
+                  Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
+              const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+            } else if (mat2.layout() == kSparseCsc) {
+              const Eigen::Map<EigenCscMatrix> mat2_eigen =
+                  Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
+              const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+            }
+          } else if (mat1.layout() == kSparseCsc) {
+            const Eigen::Map<EigenCscMatrix> mat1_eigen =
+                Tensor_to_EigenCsc<scalar_t, index_t>(mat1);
+            if (mat2.layout() == kSparseCsc) {
+              const Eigen::Map<EigenCscMatrix> mat2_eigen =
+                  Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
+              const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+            } else if (mat2.layout() == kSparseCsr) {
+              const Eigen::Map<EigenCsrMatrix> mat2_eigen =
+                  Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
+              const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+            }
+          }
+        } else if (mat1_mat2.layout() == kSparseCsc) {
+          if (mat1.layout() == kSparseCsr) {
+            const Eigen::Map<EigenCsrMatrix> mat1_eigen =
+                Tensor_to_EigenCsr<scalar_t, index_t>(mat1);
+            if (mat2.layout() == kSparseCsr) {
+              const Eigen::Map<EigenCsrMatrix> mat2_eigen =
+                  Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
+              const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+            } else if (mat2.layout() == kSparseCsc) {
+              const Eigen::Map<EigenCscMatrix> mat2_eigen =
+                  Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
+              const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+            }
+          } else if (mat1.layout() == kSparseCsc) {
+            const Eigen::Map<EigenCscMatrix> mat1_eigen =
+                Tensor_to_EigenCsc<scalar_t, index_t>(mat1);
+            if (mat2.layout() == kSparseCsc) {
+              const Eigen::Map<EigenCscMatrix> mat2_eigen =
+                  Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
+              const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+            } else if (mat2.layout() == kSparseCsr) {
+              const Eigen::Map<EigenCsrMatrix> mat2_eigen =
+                  Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
+              const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+            }
+          }
+        }
+
+        result.add_(mat1_mat2, alpha.to<scalar_t>());
+      });
+}
+
+} // anonymus namespace
+
+void addmm_out_sparse(
+    const at::Tensor& mat1,
+    const at::Tensor& mat2,
+    const at::Tensor& result,
+    const at::Scalar& alpha,
+    const at::Scalar& beta) {
+  TORCH_CHECK(
+      result.layout() != kStrided && mat1.layout() != kStrided && mat2.layout() != kStrided,
+      "eigen::addmm_out_sparse: computation on CPU is not implemented for ",
+      result.layout(),
+      " + ",
+      mat1.layout(),
+      " @ ",
+      mat2.layout());
+
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+      result.scalar_type(), "addmm_out_sparse_eigen", [&] {
+        addmm_out_sparse_eigen<scalar_t>(mat1, mat2, result, alpha, beta);
+      });
+}
+
+void add_out_sparse(
+    const at::Tensor& mat1,
+    const at::Tensor& mat2,
+    const at::Scalar& alpha,
+    const at::Tensor& result) {
+  TORCH_CHECK(
+      (result.layout() == kSparseCsr && mat1.layout() == kSparseCsr && mat2.layout() == kSparseCsr) ||
+      (result.layout() == kSparseCsc && mat1.layout() == kSparseCsc && mat2.layout() == kSparseCsc),
+      "eigen::add_out_sparse: computation on CPU is not implemented for ",
+      mat1.layout(),
+      " + ",
+      mat2.layout(),
+      " -> ",
+      result.layout());
+
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+      result.scalar_type(), "add_out_sparse_eigen", [&] {
+        add_out_sparse_eigen<scalar_t>(mat1, mat2, alpha, result);
+      });
+}
+
+} // namespace at::native::sparse::impl::eigen
+
+#else
+
+namespace at::native::sparse::impl::eigen {
+
+void addmm_out_sparse(
+    const at::Tensor& mat1,
+    const at::Tensor& mat2,
+    const at::Tensor& result,
+    const at::Scalar& alpha,
+    const at::Scalar& beta) {
+  TORCH_CHECK(
+      false,
+      "eigen::addmm_out_sparse: Eigen was not enabled ",
+      result.layout(),
+      " + ",
+      mat1.layout(),
+      " @ ",
+      mat2.layout());
+}
+
+void add_out_sparse(
+    const at::Tensor& mat1,
+    const at::Tensor& mat2,
+    const at::Scalar& alpha,
+    const at::Tensor& result) {
+  TORCH_CHECK(
+      false,
+      "eigen::add_out_sparse: Eigen was not enabled ",
+      mat1.layout(),
+      " + ",
+      mat2.layout(),
+      " -> ",
+      result.layout());
+}
+
+} // namespace at::native::sparse::impl::eigen
+
+#endif // AT_USE_EIGEN_SPARSE()
diff --git a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.h b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.h
new file mode 100644
index 0000000000000..d8e8dc322bc37
--- /dev/null
+++ b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#include
+
+#if AT_USE_EIGEN_SPARSE()
+#ifndef EIGEN_MPL2_ONLY
+#define EIGEN_MPL2_ONLY
+#endif
+
+#include <Eigen/SparseCore>
+
+namespace at::native::sparse::impl::eigen {
+
+void addmm_out_sparse(
+    const at::Tensor& mat1,
+    const at::Tensor& mat2,
+    const at::Tensor& result,
+    const at::Scalar& alpha,
+    const at::Scalar& beta);
+
+void add_out_sparse(
+    const at::Tensor& mat1,
+    const at::Tensor& mat2,
+    const at::Scalar& alpha,
+    const at::Tensor& result);
+
+} // namespace at::native::sparse::impl::eigen
+
+#endif
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 26d882f2f7f18..944c7821f6676 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -153,6 +153,7 @@ set(AT_MKLDNN_ACL_ENABLED 0)
 set(AT_MKLDNN_ENABLED 0)
 set(AT_MKL_ENABLED 0)
 set(AT_KLEIDIAI_ENABLED 0)
+set(AT_USE_EIGEN_SPARSE 0)
 
 # setting default preferred BLAS options if not already present.
 if(NOT INTERN_BUILD_MOBILE)
   set(BLAS "MKL" CACHE STRING "Selected BLAS library")
@@ -262,6 +263,15 @@ if(BLAS_LIBRARIES AND BLAS_CHECK_F2C)
   include(cmake/BLAS_ABI.cmake)
 endif()
 
+if(USE_EIGEN_SPARSE AND BLAS_INFO STREQUAL "mkl")
+  message(WARNING "Disabling USE_EIGEN_SPARSE because MKL is enabled")
+  set(USE_EIGEN_SPARSE OFF)
+endif()
+
+if(USE_EIGEN_SPARSE)
+  set(AT_USE_EIGEN_SPARSE 1)
+endif()
+
 if(NOT INTERN_BUILD_MOBILE)
   set(AT_MKL_SEQUENTIAL 0)
   set(USE_BLAS 1)
diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake
index 63e501bcb5aba..745d9ea058687 100644
--- a/cmake/Summary.cmake
+++ b/cmake/Summary.cmake
@@ -135,6 +135,7 @@ function(caffe2_print_configuration_summary)
   endif()
   message(STATUS "  BUILD_NVFUSER         : ${BUILD_NVFUSER}")
   message(STATUS "  USE_EIGEN_FOR_BLAS    : ${CAFFE2_USE_EIGEN_FOR_BLAS}")
+  message(STATUS "  USE_EIGEN_FOR_SPARSE  : ${USE_EIGEN_SPARSE}")
   message(STATUS "  USE_FBGEMM            : ${USE_FBGEMM}")
   message(STATUS "  USE_KINETO            : ${USE_KINETO}")
   message(STATUS "  USE_GFLAGS            : ${USE_GFLAGS}")
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 1aa8a8b6df8a8..c78e1f665395f 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -2203,6 +2203,8 @@ Call this whenever a new thread is created in order to propagate values from
   set_module_attr("_has_kleidiai", at::hasKleidiAI() ? Py_True : Py_False));
   ASSERT_TRUE(
       set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));
+  ASSERT_TRUE(
+      set_module_attr("_has_eigen_sparse", at::hasEigenSparse() ? Py_True : Py_False));
 
   py_module.def("_valgrind_supported_platform", []() {
 #if defined(USE_VALGRIND)

From e0a0a15c566fb6b23dd87a88e8da18c59b7a3dc9 Mon Sep 17 00:00:00 2001
From: Aidyn-A
Date: Sat, 7 Jun 2025 02:28:24 +0400
Subject: [PATCH 2/8] fix lint and bazel

---
 BUILD.bazel           | 1 +
 torch/csrc/Module.cpp | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/BUILD.bazel b/BUILD.bazel
index 50ffa12576475..58ebc31e243c4 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -279,6 +279,7 @@ header_template_rule(
         "@AT_BLAS_F2C@": "0",
         "@AT_BLAS_USE_CBLAS_DOT@": "1",
         "@AT_KLEIDIAI_ENABLED@": "0",
+        "@AT_USE_EIGEN_SPARSE@": "0",
     },
 )
 
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index c78e1f665395f..d78b4ef40e30b 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -2203,8 +2203,8 @@ Call this whenever a new thread is created in order to propagate values from
   set_module_attr("_has_kleidiai", at::hasKleidiAI() ? Py_True : Py_False));
   ASSERT_TRUE(
       set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));
-  ASSERT_TRUE(
-      set_module_attr("_has_eigen_sparse", at::hasEigenSparse() ? Py_True : Py_False));
+  ASSERT_TRUE(set_module_attr(
+      "_has_eigen_sparse", at::hasEigenSparse() ? Py_True : Py_False));
 
   py_module.def("_valgrind_supported_platform", []() {
 #if defined(USE_VALGRIND)

From 807eb71e0c5bb488d20e098dbf52182af2ee3d67 Mon Sep 17 00:00:00 2001
From: Aidyn-A
Date: Sat, 7 Jun 2025 22:01:47 +0400
Subject: [PATCH 3/8] clean up and enable fastpath

---
 .../native/sparse/eigen/SparseBlasImpl.cpp    | 37 ++++++++++---------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
index 5cd0ba648ddcb..ed5bc4fe3804a 100644
--- a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
@@ -2,8 +2,6 @@
 
 #if AT_USE_EIGEN_SPARSE()
 
-#include <iostream>
-
 #include
 #include
 #include
@@ -194,7 +192,7 @@ void add_out_sparse_eigen(
 
   c10::ScalarType result_index_dtype;
 
-  if(result.layout() == kSparseCsr) {
+  if (result.layout() == kSparseCsr) {
     result_index_dtype = result.col_indices().scalar_type();
   } else if (result.layout() == kSparseCsc) {
     result_index_dtype = result.row_indices().scalar_type();
@@ -211,7 +209,7 @@ void add_out_sparse_eigen(
         typedef Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>
            EigenCsrMatrix;
 
-        if(result.layout() == kSparseCsr) {
+        if (result.layout() == kSparseCsr) {
          const Eigen::Map<EigenCsrMatrix> mat1_eigen =
              Tensor_to_EigenCsr<scalar_t, index_t>(mat1);
          const Eigen::Map<EigenCsrMatrix> mat2_eigen =
@@ -246,15 +244,17 @@ void addmm_out_sparse_eigen(
   }
 
   // If beta is zero NaN and Inf should not be propagated to the result
+  // In addition, beta = 0 lets us enable a fast-path for result = alpha * A @ B
+  bool is_beta_zero = false;
   if (beta.toComplexDouble() == 0.) {
+    is_beta_zero = true;
     result.values().zero_();
   } else {
     result.values().mul_(beta);
   }
 
   c10::ScalarType result_index_dtype;
-
-  if(result.layout() == kSparseCsr) {
+  if (result.layout() == kSparseCsr) {
     result_index_dtype = result.col_indices().scalar_type();
   } else if (result.layout() == kSparseCsc) {
     result_index_dtype = result.row_indices().scalar_type();
@@ -270,17 +270,14 @@ void addmm_out_sparse_eigen(
         typedef Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>
            EigenCsrMatrix;
 
-        if(result.layout() == kSparseCsr) {
-          result.crow_indices().mutable_data_ptr<index_t>()[0] = (index_t) 0;
-          std::cout << result.crow_indices() << std::endl;
-        } else if (result.layout() == kSparseCsc) {
-          result.ccol_indices().mutable_data_ptr<index_t>()[0] = (index_t) 0;
-          std::cout << result.ccol_indices() << std::endl;
-        }
-
-        at::Tensor mat1_mat2 = at::empty_like(result);
+        at::Tensor mat1_mat2;
+        if (is_beta_zero) {
+          mat1_mat2 = result;
+        } else {
+          mat1_mat2 = at::empty_like(result);
+        }
 
         if (mat1_mat2.layout() == kSparseCsr) {
           if (mat1.layout() == kSparseCsr) {
            const Eigen::Map<EigenCsrMatrix> mat1_eigen =
                Tensor_to_EigenCsr<scalar_t, index_t>(mat1);
@@ -342,7 +339,11 @@ void addmm_out_sparse_eigen(
           }
         }
 
-        result.add_(mat1_mat2, alpha.to<scalar_t>());
+        if (is_beta_zero) {
+          result.mul_(alpha.to<scalar_t>());
+        } else {
+          result.add_(mat1_mat2, alpha.to<scalar_t>());
+        }
       });
 }
 
@@ -404,7 +405,7 @@ void addmm_out_sparse(
     const at::Scalar& beta) {
   TORCH_CHECK(
       false,
-      "eigen::addmm_out_sparse: Eigen was not enabled ",
+      "eigen::addmm_out_sparse: Eigen was not enabled for ",
       result.layout(),
       " + ",
       mat1.layout(),
@@ -419,7 +420,7 @@ void add_out_sparse(
     const at::Tensor& result) {
   TORCH_CHECK(
       false,
-      "eigen::add_out_sparse: Eigen was not enabled ",
+      "eigen::add_out_sparse: Eigen was not enabled for ",
       mat1.layout(),
       " + ",
       mat2.layout(),

From ff0ab1c0c9d7c0babeac0d0e7fdb84fd0b255486 Mon Sep 17 00:00:00 2001
From: Aidyn-A
Date: Mon, 14 Jul 2025 12:51:09 +0400
Subject: [PATCH 4/8] apply suggestions

---
 .../native/sparse/SparseCsrTensorMath.cpp     |  4 +-
 .../native/sparse/eigen/SparseBlasImpl.cpp    | 58 +++++--------------
 2 files changed, 19 insertions(+), 43 deletions(-)

diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index 6d7044e7dc92a..4faa135713d65 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -658,7 +658,9 @@ Tensor& addmm_out_sparse_compressed_cpu(
   }
 
 #if AT_USE_EIGEN_SPARSE()
-  if (result.layout() != kStrided && mat1.layout() != kStrided && mat2.layout() != kStrided) {
+  if ((result.layout() == kSparseCsr || result.layout() == kSparseCsc) &&
+      (mat1.layout() == kSparseCsr || mat1.layout() == kSparseCsc) &&
+      (mat2.layout() == kSparseCsr || mat2.layout() == kSparseCsc)) {
     sparse::impl::eigen::addmm_out_sparse(mat1, mat2, result, alpha, beta);
     return result;
   }
diff --git a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
index ed5bc4fe3804a..a255e3511b1b5 100644
--- a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
@@ -30,18 +30,13 @@ void inline sparse_indices_to_result_dtype_inplace(
             input.col_indices().to(dtype),
             input.values(),
             input.sizes());
-  } else if (input.layout() == kSparseCsc) {
+  } else {
     static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
         ->set_member_tensors(
             input.ccol_indices().to(dtype),
             input.row_indices().to(dtype),
             input.values(),
             input.sizes());
-  } else {
-    TORCH_CHECK(
-        false,
-        "Eigen: expected tensor be kSparseCsr or kSparseCsc, but got",
-        input.layout());
-  }
 }
 
@@ -55,18 +50,13 @@ void inline sparse_indices_and_values_resize(
             input.col_indices().resize_({nnz}),
             input.values().resize_({nnz}),
             input.sizes());
-  } else if (input.layout() == kSparseCsc) {
+  } else {
     static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
         ->set_member_tensors(
             input.ccol_indices(),
            input.row_indices().resize_({nnz}),
            input.values().resize_({nnz}),
            input.sizes());
-  } else {
-    TORCH_CHECK(
-        false,
-        "Eigen: expected tensor be kSparseCsr or kSparseCsc, but got",
-        input.layout());
-  }
 }
 
@@ -146,7 +136,7 @@ void EigenCsc_to_Tensor(
       tensor.layout());
 
   int64_t nnz = matrix.nonZeros();
-  int64_t rows = matrix.outerSize();
+  int64_t cols = matrix.outerSize();
   sparse_indices_and_values_resize(tensor, nnz);
 
   if (nnz > 0) {
@@ -159,13 +149,13 @@ void EigenCsc_to_Tensor(
         matrix.innerIndexPtr(),
         nnz * sizeof(index_t));
   }
-  if (rows > 0) {
+  if (cols > 0) {
     std::memcpy(
         tensor.ccol_indices().mutable_data_ptr<index_t>(),
         matrix.outerIndexPtr(),
-        rows * sizeof(index_t));
+        cols * sizeof(index_t));
   }
-  tensor.ccol_indices().mutable_data_ptr<index_t>()[rows] = nnz;
+  tensor.ccol_indices().mutable_data_ptr<index_t>()[cols] = nnz;
 }
 
 template <typename scalar_t>
@@ -190,13 +180,7 @@ void add_out_sparse_eigen(
     return;
   }
 
-  c10::ScalarType result_index_dtype;
-
-  if (result.layout() == kSparseCsr) {
-    result_index_dtype = result.col_indices().scalar_type();
-  } else if (result.layout() == kSparseCsc) {
-    result_index_dtype = result.row_indices().scalar_type();
-  }
+  c10::ScalarType result_index_dtype = at::sparse_csr::getIndexDtype(result);
 
   sparse_indices_to_result_dtype_inplace(result_index_dtype, mat1);
   sparse_indices_to_result_dtype_inplace(result_index_dtype, mat2);
@@ -218,7 +202,7 @@ void add_out_sparse_eigen(
 
           EigenCsr_to_Tensor<scalar_t, index_t>(result, mat1_mat2_eigen);
-        } else if (mat1.layout() == kSparseCsc) {
+        } else {
           const Eigen::Map<EigenCscMatrix> mat1_eigen =
               Tensor_to_EigenCsc<scalar_t, index_t>(mat1);
           const Eigen::Map<EigenCscMatrix> mat2_eigen =
@@ -253,12 +237,7 @@ void addmm_out_sparse_eigen(
     result.values().mul_(beta);
   }
 
-  c10::ScalarType result_index_dtype;
-  if (result.layout() == kSparseCsr) {
-    result_index_dtype = result.col_indices().scalar_type();
-  } else if (result.layout() == kSparseCsc) {
-    result_index_dtype = result.row_indices().scalar_type();
-  }
+  c10::ScalarType result_index_dtype = at::sparse_csr::getIndexDtype(result);
 
   sparse_indices_to_result_dtype_inplace(result_index_dtype, mat1);
   sparse_indices_to_result_dtype_inplace(result_index_dtype, mat2);
@@ -292,7 +271,7 @@ void addmm_out_sparse_eigen(
               const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
               EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
             }
-          } else if (mat1.layout() == kSparseCsc) {
+          } else {
             const Eigen::Map<EigenCscMatrix> mat1_eigen =
                 Tensor_to_EigenCsc<scalar_t, index_t>(mat1);
             if (mat2.layout() == kSparseCsc) {
@@ -307,7 +286,7 @@ void addmm_out_sparse_eigen(
               EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
             }
           }
-        } else if (mat1_mat2.layout() == kSparseCsc) {
+        } else {
           if (mat1.layout() == kSparseCsr) {
             const Eigen::Map<EigenCsrMatrix> mat1_eigen =
                 Tensor_to_EigenCsr<scalar_t, index_t>(mat1);
@@ -322,7 +301,7 @@ void addmm_out_sparse_eigen(
               const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
               EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
             }
-          } else if (mat1.layout() == kSparseCsc) {
+          } else {
             const Eigen::Map<EigenCscMatrix> mat1_eigen =
                 Tensor_to_EigenCsc<scalar_t, index_t>(mat1);
             if (mat2.layout() == kSparseCsc) {
@@ -355,14 +334,9 @@ void addmm_out_sparse(
     const at::Tensor& result,
     const at::Scalar& alpha,
     const at::Scalar& beta) {
-  TORCH_CHECK(
-      result.layout() != kStrided && mat1.layout() != kStrided && mat2.layout() != kStrided,
-      "eigen::addmm_out_sparse: computation on CPU is not implemented for ",
-      result.layout(),
-      " + ",
-      mat1.layout(),
-      " @ ",
-      mat2.layout());
+  AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(mat1.layout(), "eigen::addmm_out_sparse:mat1", [&]{});
+  AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(mat2.layout(), "eigen::addmm_out_sparse:mat2", [&]{});
+  AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(result.layout(), "eigen::addmm_out_sparse:result", [&]{});
 
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
       result.scalar_type(), "addmm_out_sparse_eigen", [&] {
@@ -378,7 +352,7 @@ void add_out_sparse(
   TORCH_CHECK(
       (result.layout() == kSparseCsr && mat1.layout() == kSparseCsr && mat2.layout() == kSparseCsr) ||
       (result.layout() == kSparseCsc && mat1.layout() == kSparseCsc && mat2.layout() == kSparseCsc),
-      "eigen::add_out_sparse: computation on CPU is not implemented for ",
+      "eigen::add_out_sparse: expected the same layout for all operands but got ",
       mat1.layout(),
       " + ",
       mat2.layout(),

From 9d1eb802d915f8dcf85212d057cb9ba1d5c58a29 Mon Sep 17 00:00:00 2001
From: Aidyn-A
Date: Mon, 14 Jul 2025 12:55:24 +0400
Subject: [PATCH 5/8] missed a couple of suggestions

---
 aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
index a255e3511b1b5..4bf670f269534 100644
--- a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
@@ -265,7 +265,7 @@ void addmm_out_sparse_eigen(
                   Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
               const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
               EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
-            } else if (mat2.layout() == kSparseCsc) {
+            } else {
               const Eigen::Map<EigenCscMatrix> mat2_eigen =
                   Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
               const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
@@ -279,7 +279,7 @@ void addmm_out_sparse_eigen(
                   Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
               const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
               EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
-            } else if (mat2.layout() == kSparseCsr) {
+            } else {
               const Eigen::Map<EigenCsrMatrix> mat2_eigen =
                   Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
               const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
               EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
@@ -295,7 +295,7 @@ void addmm_out_sparse_eigen(
                   Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
               const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
               EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
-            } else if (mat2.layout() == kSparseCsc) {
+            } else {
               const Eigen::Map<EigenCscMatrix> mat2_eigen =
                   Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
               const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
               EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
@@ -309,7 +309,7 @@ void addmm_out_sparse_eigen(
                   Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
               const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
               EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
-            } else if (mat2.layout() == kSparseCsr) {
+            } else {
               const Eigen::Map<EigenCsrMatrix> mat2_eigen =
                   Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
               const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
               EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);

From 99f56da726e83e734d022e8b1544f496a06647b1 Mon Sep 17 00:00:00 2001
From: Aidyn-A
Date: Thu, 7 Aug 2025 18:10:37 +0400
Subject: [PATCH 6/8] apply review suggestions

---
 .../native/sparse/eigen/SparseBlasImpl.cpp    | 257 ++++++------------
 1 file changed, 88 insertions(+), 169 deletions(-)

diff --git a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
index 4bf670f269534..a1da733aa110c 100644
--- a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
@@ -23,139 +23,75 @@ namespace {
 void inline sparse_indices_to_result_dtype_inplace(
     const c10::ScalarType& dtype,
     const at::Tensor& input) {
-  if (input.layout() == kSparseCsr) {
-    static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
-        ->set_member_tensors(
-            input.crow_indices().to(dtype),
-            input.col_indices().to(dtype),
-            input.values(),
-            input.sizes());
-  } else {
-    static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
-        ->set_member_tensors(
-            input.ccol_indices().to(dtype),
-            input.row_indices().to(dtype),
-            input.values(),
-            input.sizes());
-  }
+  auto [compressed_indices, plain_indices] =
+      at::sparse_csr::getCompressedPlainIndices(input);
+  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
+    ->set_member_tensors(
+      compressed_indices.to(dtype),
+      plain_indices.to(dtype),
+      input.values(),
+      input.sizes());
 }
 
 void inline sparse_indices_and_values_resize(
     const at::Tensor& input,
     int64_t nnz) {
-  if (input.layout() == kSparseCsr) {
-    static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
-        ->set_member_tensors(
-            input.crow_indices(),
-            input.col_indices().resize_({nnz}),
-            input.values().resize_({nnz}),
-            input.sizes());
-  } else {
-    static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
-        ->set_member_tensors(
-            input.ccol_indices(),
-            input.row_indices().resize_({nnz}),
-            input.values().resize_({nnz}),
-            input.sizes());
-  }
+  auto [compressed_indices, plain_indices] =
+      at::sparse_csr::getCompressedPlainIndices(input);
+  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
+    ->set_member_tensors(
+      compressed_indices,
+      plain_indices.resize_({nnz}),
+      input.values().resize_({nnz}),
+      input.sizes());
 }
 
-template <typename scalar_t, typename index_t>
-const Eigen::Map<Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>>
-Tensor_to_EigenCsc(const at::Tensor& tensor) {
+template <typename scalar_t, int eigen_options, typename index_t>
+const Eigen::Map<Eigen::SparseMatrix<scalar_t, eigen_options, index_t>>
+Tensor_to_Eigen(const at::Tensor& tensor) {
   int64_t rows = tensor.size(0);
   int64_t cols = tensor.size(1);
   int64_t nnz = tensor._nnz();
-
-  TORCH_CHECK(tensor.values().is_contiguous(), "eigen accepts only contiguous tensors");
-
-  index_t* ccol_indices_ptr = tensor.ccol_indices().data_ptr<index_t>();
-  index_t* row_indices_ptr = tensor.row_indices().data_ptr<index_t>();
+  TORCH_CHECK(tensor.values().is_contiguous(), "eigen accepts only contiguous tensor values");
+  auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(tensor);
+  index_t* c_indices_ptr = compressed_indices.data_ptr<index_t>();
+  index_t* p_indices_ptr = plain_indices.data_ptr<index_t>();
   scalar_t* values_ptr = tensor.values().data_ptr<scalar_t>();
-  Eigen::Map<Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>> map(
-      rows, cols, nnz, ccol_indices_ptr, row_indices_ptr, values_ptr);
+  Eigen::Map<Eigen::SparseMatrix<scalar_t, eigen_options, index_t>> map(
+      rows, cols, nnz, c_indices_ptr, p_indices_ptr, values_ptr);
   return map;
 }
 
-template <typename scalar_t, typename index_t>
-const Eigen::Map<Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>>
-Tensor_to_EigenCsr(const at::Tensor& tensor) {
-  int64_t rows = tensor.size(0);
-  int64_t cols = tensor.size(1);
-  int64_t nnz = tensor._nnz();
-
-  TORCH_CHECK(tensor.values().is_contiguous(), "eigen accepts only contiguous tensors");
-
-  index_t* crow_indices_ptr = tensor.crow_indices().data_ptr<index_t>();
-  index_t* col_indices_ptr = tensor.col_indices().data_ptr<index_t>();
-  scalar_t* values_ptr = tensor.values().data_ptr<scalar_t>();
-  Eigen::Map<Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>> map(
-      rows, cols, nnz, crow_indices_ptr, col_indices_ptr, values_ptr);
-  return map;
-}
-
-template <typename scalar_t, typename index_t>
-void EigenCsr_to_Tensor(
+template <typename scalar_t, int eigen_options, typename index_t>
+void Eigen_to_Tensor(
     const at::Tensor& tensor,
-    const Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>& matrix) {
+    const Eigen::SparseMatrix<scalar_t, eigen_options, index_t>& matrix) {
+  const Layout eigen_layout = (eigen_options == Eigen::RowMajor ? kSparseCsr : kSparseCsc);
   TORCH_CHECK(
-      tensor.layout() == kSparseCsr,
-      "EigenCsr_to_Tensor, expected tensor be kSparseCsr, but got",
+      tensor.layout() == eigen_layout,
+      "Eigen_to_Tensor, expected tensor be ", eigen_layout, ", but got ",
       tensor.layout());
-
   int64_t nnz = matrix.nonZeros();
-  int64_t rows = matrix.outerSize();
+  int64_t csize = matrix.outerSize();
   sparse_indices_and_values_resize(tensor, nnz);
-
+  auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(tensor);
   if (nnz > 0) {
     std::memcpy(
         tensor.values().mutable_data_ptr<scalar_t>(),
         matrix.valuePtr(),
         nnz * sizeof(scalar_t));
     std::memcpy(
-        tensor.col_indices().mutable_data_ptr<index_t>(),
+        plain_indices.mutable_data_ptr<index_t>(),
         matrix.innerIndexPtr(),
         nnz * sizeof(index_t));
   }
-  if (rows > 0) {
+  if (csize > 0) {
     std::memcpy(
-        tensor.crow_indices().mutable_data_ptr<index_t>(),
+        compressed_indices.mutable_data_ptr<index_t>(),
         matrix.outerIndexPtr(),
-        rows * sizeof(index_t));
+        csize * sizeof(index_t));
   }
-  tensor.crow_indices().mutable_data_ptr<index_t>()[rows] = nnz;
-}
-
-template <typename scalar_t, typename index_t>
-void EigenCsc_to_Tensor(
-    const at::Tensor& tensor,
-    const Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>& matrix) {
-  TORCH_CHECK(
-      tensor.layout() == kSparseCsc,
-      "EigenCsr_to_Tensor, expected tensor be kSparseCsc, but got",
-      tensor.layout());
-
-  int64_t nnz = matrix.nonZeros();
-  int64_t cols = matrix.outerSize();
-  sparse_indices_and_values_resize(tensor, nnz);
-
-  if (nnz > 0) {
-    std::memcpy(
-        tensor.values().mutable_data_ptr<scalar_t>(),
-        matrix.valuePtr(),
-        nnz * sizeof(scalar_t));
-    std::memcpy(
-        tensor.row_indices().mutable_data_ptr<index_t>(),
-        matrix.innerIndexPtr(),
-        nnz * sizeof(index_t));
-  }
-  if (cols > 0) {
-    std::memcpy(
-        tensor.ccol_indices().mutable_data_ptr<index_t>(),
-        matrix.outerIndexPtr(),
-        cols * sizeof(index_t));
-  }
-  tensor.ccol_indices().mutable_data_ptr<index_t>()[cols] = nnz;
+  compressed_indices.mutable_data_ptr<index_t>()[csize] = nnz;
 }
 
 template <typename scalar_t>
@@ -188,29 +124,17 @@ void add_out_sparse_eigen(
 
   AT_DISPATCH_INDEX_TYPES(
       result_index_dtype, "eigen_sparse_add", [&]() {
         scalar_t _alpha = alpha.to<scalar_t>();
-        typedef Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>
-            EigenCscMatrix;
-        typedef Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>
-            EigenCsrMatrix;
 
         if (result.layout() == kSparseCsr) {
-          const Eigen::Map<EigenCsrMatrix> mat1_eigen =
-              Tensor_to_EigenCsr<scalar_t, index_t>(mat1);
-          const Eigen::Map<EigenCsrMatrix> mat2_eigen =
-              Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
-          const EigenCsrMatrix mat1_mat2_eigen =
-              (mat1_eigen + _alpha * mat2_eigen);
-
-          EigenCsr_to_Tensor<scalar_t, index_t>(result, mat1_mat2_eigen);
+          auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat1);
+          auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
+          auto mat1_mat2_eigen = (mat1_eigen + _alpha * mat2_eigen);
+          Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(result, mat1_mat2_eigen);
         } else {
-          const Eigen::Map<EigenCscMatrix> mat1_eigen =
-              Tensor_to_EigenCsc<scalar_t, index_t>(mat1);
-          const Eigen::Map<EigenCscMatrix> mat2_eigen =
-              Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
-          const EigenCscMatrix mat1_mat2_eigen =
-              (mat1_eigen + _alpha * mat2_eigen);
-
-          EigenCsc_to_Tensor<scalar_t, index_t>(result, mat1_mat2_eigen);
+          auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat1);
+          auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
+          auto mat1_mat2_eigen = (mat1_eigen + _alpha * mat2_eigen);
+          Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(result, mat1_mat2_eigen);
         }
       });
 }
@@ -244,11 +168,6 @@ void addmm_out_sparse_eigen(
 
   AT_DISPATCH_INDEX_TYPES(
       result_index_dtype, "eigen_sparse_mm", [&]() {
-        typedef Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>
-            EigenCscMatrix;
-        typedef Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>
-            EigenCsrMatrix;
-
         at::Tensor mat1_mat2;
         if (is_beta_zero) {
           mat1_mat2 = result;
@@ -258,62 +177,62 @@ void addmm_out_sparse_eigen(
 
         if (mat1_mat2.layout() == kSparseCsr) {
           if (mat1.layout() == kSparseCsr) {
-            const Eigen::Map<EigenCsrMatrix> mat1_eigen =
-                Tensor_to_EigenCsr<scalar_t, index_t>(mat1);
+            const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat1);
             if (mat2.layout() == kSparseCsr) {
+              // Out_csr = M1_csr * M2_csr
-              const Eigen::Map<EigenCsrMatrix> mat2_eigen =
-                  Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
-              const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
-              EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
+              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
             } else {
+              // Out_csr = M1_csr * M2_csc
-              const Eigen::Map<EigenCscMatrix> mat2_eigen =
-                  Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
-              const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
-              EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
+              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
             }
           } else {
-            const Eigen::Map<EigenCscMatrix> mat1_eigen =
-                Tensor_to_EigenCsc<scalar_t, index_t>(mat1);
-            if (mat2.layout() == kSparseCsc) {
+            const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat1);
+            if (mat2.layout() == kSparseCsr) {
+              // Out_csr = M1_csc * M2_csr
+              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
+              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
             } else {
+              // Out_csr = M1_csc * M2_csc
+              // This multiplication will be computationally inefficient, as it will require
+              // additional conversion of the output matrix from CSC to CSR format.
-              const Eigen::Map<EigenCscMatrix> mat2_eigen =
-                  Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
-              const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
-              EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
+              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
            }
          }
        } else {
          if (mat1.layout() == kSparseCsr) {
-            const Eigen::Map<EigenCsrMatrix> mat1_eigen =
-                Tensor_to_EigenCsr<scalar_t, index_t>(mat1);
+            const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat1);
            if (mat2.layout() == kSparseCsr) {
+              // Out_csc = M1_csr * M2_csr
+              // This multiplication will be computationally inefficient, as it will require
+              // additional conversion of the output matrix from CSR to CSC format.
-              const Eigen::Map<EigenCsrMatrix> mat2_eigen =
-                  Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
-              const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
-              EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
+              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
            } else {
+              // Out_csc = M1_csr * M2_csc
-              const Eigen::Map<EigenCscMatrix> mat2_eigen =
-                  Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
-              const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
-              EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
+              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
            }
          } else {
-            const Eigen::Map<EigenCscMatrix> mat1_eigen =
-                Tensor_to_EigenCsc<scalar_t, index_t>(mat1);
-            if (mat2.layout() == kSparseCsc) {
+            const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat1);
+            if (mat2.layout() == kSparseCsr) {
+              // Out_csc = M1_csc * M2_csr
+              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
+              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
            } else {
+              // Out_csc = M1_csc * M2_csc
-              const Eigen::Map<EigenCscMatrix> mat2_eigen =
-                  Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
-              const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
-              EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
+              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
+              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
            }
          }
        }

From 77e97b6487f5b0aeb590589f37129b046073c066 Mon Sep 17 00:00:00 2001
From: Aidyn-A
Date: Thu, 7 Aug 2025 19:32:12 +0400
Subject: [PATCH 7/8] Fix lint

---
 .../native/sparse/eigen/SparseBlasImpl.cpp    | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
index a1da733aa110c..bbfb692d40bf4 100644
--- a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
@@ -25,12 +25,12 @@ void inline sparse_indices_to_result_dtype_inplace(
     const at::Tensor& input) {
   auto [compressed_indices, plain_indices] =
       at::sparse_csr::getCompressedPlainIndices(input);
-  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
-    ->set_member_tensors(
-      compressed_indices.to(dtype),
-      plain_indices.to(dtype),
-      input.values(),
-      input.sizes());
+  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
+      ->set_member_tensors(
+          compressed_indices.to(dtype),
+          plain_indices.to(dtype),
+          input.values(),
+          input.sizes());
 }
 
 void inline sparse_indices_and_values_resize(
@@ -38,12 +38,12 @@ void inline sparse_indices_and_values_resize(
     int64_t nnz) {
   auto [compressed_indices, plain_indices] =
      at::sparse_csr::getCompressedPlainIndices(input);
-  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
-    ->set_member_tensors(
-      compressed_indices,
-      plain_indices.resize_({nnz}),
-      input.values().resize_({nnz}),
-      input.sizes());
+  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
+      ->set_member_tensors(
+          compressed_indices,
+          plain_indices.resize_({nnz}),
+          input.values().resize_({nnz}),
+          input.sizes());
 }
 
@@ -245,7 +245,7 @@ void addmm_out_sparse_eigen(
       });
 }
 
-} // anonymus namespace
+} // anonymous namespace
 
 void addmm_out_sparse(
     const at::Tensor& mat1,

From 6f01edd80878da7707dc7e2e800f88ced4253b2c Mon Sep 17 00:00:00 2001
From: Aidyn-A
Date: Tue, 12 Aug 2025 16:15:05 +0400
Subject: [PATCH 8/8] use explicit type declarations

---
 .../native/sparse/eigen/SparseBlasImpl.cpp    | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
index bbfb692d40bf4..20738992a61d9 100644
--- a/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
@@ -168,6 +168,9 @@ void addmm_out_sparse_eigen(
 
   AT_DISPATCH_INDEX_TYPES(
       result_index_dtype, "eigen_sparse_mm", [&]() {
+        typedef Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t> EigenCsrMatrix;
+        typedef Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t> EigenCscMatrix;
+
         at::Tensor mat1_mat2;
         if (is_beta_zero) {
           mat1_mat2 = result;
@@ -181,12 +184,12 @@ void addmm_out_sparse_eigen(
            if (mat2.layout() == kSparseCsr) {
              // Out_csr = M1_csr * M2_csr
              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
-              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
              Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
            } else {
              // Out_csr = M1_csr * M2_csc
              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
-              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
              Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
            }
          } else {
@@ -194,14 +197,14 @@ void addmm_out_sparse_eigen(
            if (mat2.layout() == kSparseCsr) {
              // Out_csr = M1_csc * M2_csr
              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
-              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
              Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
            } else {
              // Out_csr = M1_csc * M2_csc
              // This multiplication will be computationally inefficient, as it will require
              // additional conversion of the output matrix from CSC to CSR format.
              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
-              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
              Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
            }
          }
@@ -213,12 +216,12 @@ void addmm_out_sparse_eigen(
            if (mat2.layout() == kSparseCsr) {
              // Out_csc = M1_csr * M2_csr
              // This multiplication will be computationally inefficient, as it will require
              // additional conversion of the output matrix from CSR to CSC format.
              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
-              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
              Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
            } else {
              // Out_csc = M1_csr * M2_csc
              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
-              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
              Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
            }
          } else {
@@ -226,12 +229,12 @@ void addmm_out_sparse_eigen(
            const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat1);
            if (mat2.layout() == kSparseCsr) {
              // Out_csc = M1_csc * M2_csr
              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
-              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
              Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
            } else {
              // Out_csc = M1_csc * M2_csc
              const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
-              const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
+              const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
              Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
            }
          }
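
A minimal sketch of how the two kernels wired up by this series are reached from Python, assuming a build configured with -DUSE_EIGEN_SPARSE=ON and a non-MKL BLAS (Dependencies.cmake above disables the flag when BLAS_INFO is "mkl"); the shapes and values are illustrative only, not part of the patches:

    import torch

    # Attribute registered in torch/csrc/Module.cpp above; False on builds
    # without Eigen sparse support.
    print(torch._C._has_eigen_sparse)

    a = torch.eye(3).to_sparse_csr()
    b = torch.diag(torch.tensor([1.0, 2.0, 3.0])).to_sparse_csr()

    # CSR + alpha * CSR goes through add_out_sparse_csr, which now dispatches
    # to eigen::add_out_sparse when MKL sparse is unavailable.
    c = torch.add(a, b, alpha=0.5)

    # CSR @ CSR goes through addmm_out_sparse_compressed_cpu and, with all
    # operands sparse compressed, takes the eigen::addmm_out_sparse path.
    d = a @ b

    print(c.to_dense())
    print(d.to_dense())

Running the same script on an MKL build exercises the existing sparse::impl::mkl path instead, since AT_USE_MKL_SPARSE() is checked before the Eigen branch in both dispatch sites.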