Skip to content

Commit 2c69b71

Browse files
committed
apply review suggestions
1 parent 014f8d3 commit 2c69b71

File tree

1 file changed

+88
-169
lines changed

1 file changed

+88
-169
lines changed

aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp

Lines changed: 88 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -23,139 +23,75 @@ namespace {
2323
void inline sparse_indices_to_result_dtype_inplace(
2424
const c10::ScalarType& dtype,
2525
const at::Tensor& input) {
26-
if (input.layout() == kSparseCsr) {
27-
static_cast<at::SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
28-
->set_member_tensors(
29-
input.crow_indices().to(dtype),
30-
input.col_indices().to(dtype),
31-
input.values(),
32-
input.sizes());
33-
} else {
34-
static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
35-
->set_member_tensors(
36-
input.ccol_indices().to(dtype),
37-
input.row_indices().to(dtype),
38-
input.values(),
39-
input.sizes());
40-
}
26+
auto [compressed_indices, plain_indices] =
27+
at::sparse_csr::getCompressedPlainIndices(input);
28+
static_cast<at::SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
29+
->set_member_tensors(
30+
compressed_indices.to(dtype),
31+
plain_indices.to(dtype),
32+
input.values(),
33+
input.sizes());
4134
}
4235

4336
void inline sparse_indices_and_values_resize(
4437
const at::Tensor& input,
4538
int64_t nnz) {
46-
if (input.layout() == kSparseCsr) {
47-
static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
48-
->set_member_tensors(
49-
input.crow_indices(),
50-
input.col_indices().resize_({nnz}),
51-
input.values().resize_({nnz}),
52-
input.sizes());
53-
} else {
54-
static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
55-
->set_member_tensors(
56-
input.ccol_indices(),
57-
input.row_indices().resize_({nnz}),
58-
input.values().resize_({nnz}),
59-
input.sizes());
60-
}
39+
auto [compressed_indices, plain_indices] =
40+
at::sparse_csr::getCompressedPlainIndices(input);
41+
static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
42+
->set_member_tensors(
43+
compressed_indices,
44+
plain_indices.resize_({nnz}),
45+
input.values().resize_({nnz}),
46+
input.sizes());
6147
}
6248

63-
template <typename scalar_t, typename index_t>
64-
const Eigen::Map<Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>>
65-
Tensor_to_EigenCsc(const at::Tensor& tensor) {
49+
template <typename scalar_t, int eigen_options, typename index_t>
50+
const Eigen::Map<Eigen::SparseMatrix<scalar_t, eigen_options, index_t>>
51+
Tensor_to_Eigen(const at::Tensor& tensor) {
6652
int64_t rows = tensor.size(0);
6753
int64_t cols = tensor.size(1);
6854
int64_t nnz = tensor._nnz();
69-
70-
TORCH_CHECK(tensor.values().is_contiguous(), "eigen accepts only contiguous tensors");
71-
72-
index_t* ccol_indices_ptr = tensor.ccol_indices().data_ptr<index_t>();
73-
index_t* row_indices_ptr = tensor.row_indices().data_ptr<index_t>();
55+
TORCH_CHECK(tensor.values().is_contiguous(), "eigen accepts only contiguous tensor values");
56+
auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(tensor);
57+
index_t* c_indices_ptr = compressed_indices.data_ptr<index_t>();
58+
index_t* p_indices_ptr = plain_indices.data_ptr<index_t>();
7459
scalar_t* values_ptr = tensor.values().data_ptr<scalar_t>();
75-
Eigen::Map<Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>> map(
76-
rows, cols, nnz, ccol_indices_ptr, row_indices_ptr, values_ptr);
60+
Eigen::Map<Eigen::SparseMatrix<scalar_t, eigen_options, index_t>> map(
61+
rows, cols, nnz, c_indices_ptr, p_indices_ptr, values_ptr);
7762
return map;
7863
}
7964

80-
template <typename scalar_t, typename index_t>
81-
const Eigen::Map<Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>>
82-
Tensor_to_EigenCsr(const at::Tensor& tensor) {
83-
int64_t rows = tensor.size(0);
84-
int64_t cols = tensor.size(1);
85-
int64_t nnz = tensor._nnz();
86-
87-
TORCH_CHECK(tensor.values().is_contiguous(), "eigen accepts only contiguous tensors");
88-
89-
index_t* crow_indices_ptr = tensor.crow_indices().data_ptr<index_t>();
90-
index_t* col_indices_ptr = tensor.col_indices().data_ptr<index_t>();
91-
scalar_t* values_ptr = tensor.values().data_ptr<scalar_t>();
92-
Eigen::Map<Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>> map(
93-
rows, cols, nnz, crow_indices_ptr, col_indices_ptr, values_ptr);
94-
return map;
95-
}
96-
97-
template <typename scalar_t, typename index_t>
98-
void EigenCsr_to_Tensor(
65+
template <typename scalar_t, int eigen_options, typename index_t>
66+
void Eigen_to_Tensor(
9967
const at::Tensor& tensor,
100-
const Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>& matrix) {
68+
const Eigen::SparseMatrix<scalar_t, eigen_options, index_t>& matrix) {
69+
const Layout eigen_layout = (eigen_options == Eigen::RowMajor ? kSparseCsr : kSparseCsc);
10170
TORCH_CHECK(
102-
tensor.layout() == kSparseCsr,
103-
"EigenCsr_to_Tensor, expected tensor be kSparseCsr, but got",
71+
tensor.layout() == eigen_layout,
72+
"Eigen_to_Tensor, expected tensor be ", eigen_layout, ", but got ",
10473
tensor.layout());
105-
10674
int64_t nnz = matrix.nonZeros();
107-
int64_t rows = matrix.outerSize();
75+
int64_t csize = matrix.outerSize();
10876
sparse_indices_and_values_resize(tensor, nnz);
109-
77+
auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(tensor);
11078
if (nnz > 0) {
11179
std::memcpy(
11280
tensor.values().mutable_data_ptr<scalar_t>(),
11381
matrix.valuePtr(),
11482
nnz * sizeof(scalar_t));
11583
std::memcpy(
116-
tensor.col_indices().mutable_data_ptr<index_t>(),
84+
plain_indices.mutable_data_ptr<index_t>(),
11785
matrix.innerIndexPtr(),
11886
nnz * sizeof(index_t));
11987
}
120-
if (rows > 0) {
88+
if (csize > 0) {
12189
std::memcpy(
122-
tensor.crow_indices().mutable_data_ptr<index_t>(),
90+
compressed_indices.mutable_data_ptr<index_t>(),
12391
matrix.outerIndexPtr(),
124-
rows * sizeof(index_t));
92+
csize * sizeof(index_t));
12593
}
126-
tensor.crow_indices().mutable_data_ptr<index_t>()[rows] = nnz;
127-
}
128-
129-
template <typename scalar_t, typename index_t>
130-
void EigenCsc_to_Tensor(
131-
const at::Tensor& tensor,
132-
const Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>& matrix) {
133-
TORCH_CHECK(
134-
tensor.layout() == kSparseCsc,
135-
"EigenCsr_to_Tensor, expected tensor be kSparseCsc, but got",
136-
tensor.layout());
137-
138-
int64_t nnz = matrix.nonZeros();
139-
int64_t cols = matrix.outerSize();
140-
sparse_indices_and_values_resize(tensor, nnz);
141-
142-
if (nnz > 0) {
143-
std::memcpy(
144-
tensor.values().mutable_data_ptr<scalar_t>(),
145-
matrix.valuePtr(),
146-
nnz * sizeof(scalar_t));
147-
std::memcpy(
148-
tensor.row_indices().mutable_data_ptr<index_t>(),
149-
matrix.innerIndexPtr(),
150-
nnz * sizeof(index_t));
151-
}
152-
if (cols > 0) {
153-
std::memcpy(
154-
tensor.ccol_indices().mutable_data_ptr<index_t>(),
155-
matrix.outerIndexPtr(),
156-
cols * sizeof(index_t));
157-
}
158-
tensor.ccol_indices().mutable_data_ptr<index_t>()[cols] = nnz;
94+
compressed_indices.mutable_data_ptr<index_t>()[csize] = nnz;
15995
}
16096

16197
template <typename scalar_t>
@@ -188,29 +124,17 @@ void add_out_sparse_eigen(
188124
AT_DISPATCH_INDEX_TYPES(
189125
result_index_dtype, "eigen_sparse_add", [&]() {
190126
scalar_t _alpha = alpha.to<scalar_t>();
191-
typedef Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>
192-
EigenCscMatrix;
193-
typedef Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>
194-
EigenCsrMatrix;
195127

196128
if (result.layout() == kSparseCsr) {
197-
const Eigen::Map<EigenCsrMatrix> mat1_eigen =
198-
Tensor_to_EigenCsr<scalar_t, index_t>(mat1);
199-
const Eigen::Map<EigenCsrMatrix> mat2_eigen =
200-
Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
201-
const EigenCsrMatrix mat1_mat2_eigen =
202-
(mat1_eigen + _alpha * mat2_eigen);
203-
204-
EigenCsr_to_Tensor<scalar_t, index_t>(result, mat1_mat2_eigen);
129+
auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat1);
130+
auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
131+
auto mat1_mat2_eigen = (mat1_eigen + _alpha * mat2_eigen);
132+
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(result, mat1_mat2_eigen);
205133
} else {
206-
const Eigen::Map<EigenCscMatrix> mat1_eigen =
207-
Tensor_to_EigenCsc<scalar_t, index_t>(mat1);
208-
const Eigen::Map<EigenCscMatrix> mat2_eigen =
209-
Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
210-
const EigenCscMatrix mat1_mat2_eigen =
211-
(mat1_eigen + _alpha * mat2_eigen);
212-
213-
EigenCsc_to_Tensor<scalar_t, index_t>(result, mat1_mat2_eigen);
134+
auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat1);
135+
auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
136+
auto mat1_mat2_eigen = (mat1_eigen + _alpha * mat2_eigen);
137+
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(result, mat1_mat2_eigen);
214138
}
215139
});
216140
}
@@ -244,11 +168,6 @@ void addmm_out_sparse_eigen(
244168

245169
AT_DISPATCH_INDEX_TYPES(
246170
result_index_dtype, "eigen_sparse_mm", [&]() {
247-
typedef Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t>
248-
EigenCscMatrix;
249-
typedef Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t>
250-
EigenCsrMatrix;
251-
252171
at::Tensor mat1_mat2;
253172
if (is_beta_zero) {
254173
mat1_mat2 = result;
@@ -258,62 +177,62 @@ void addmm_out_sparse_eigen(
258177

259178
if (mat1_mat2.layout() == kSparseCsr) {
260179
if (mat1.layout() == kSparseCsr) {
261-
const Eigen::Map<EigenCsrMatrix> mat1_eigen =
262-
Tensor_to_EigenCsr<scalar_t, index_t>(mat1);
180+
const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat1);
263181
if (mat2.layout() == kSparseCsr) {
264-
const Eigen::Map<EigenCsrMatrix> mat2_eigen =
265-
Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
266-
const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
267-
EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
182+
// Out_csr = M1_csr * M2_csr
183+
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
184+
const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
185+
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
268186
} else {
269-
const Eigen::Map<EigenCscMatrix> mat2_eigen =
270-
Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
271-
const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
272-
EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
187+
// Out_csr = M1_csr * M2_csc
188+
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
189+
const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
190+
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
273191
}
274192
} else {
275-
const Eigen::Map<EigenCscMatrix> mat1_eigen =
276-
Tensor_to_EigenCsc<scalar_t, index_t>(mat1);
277-
if (mat2.layout() == kSparseCsc) {
278-
const Eigen::Map<EigenCscMatrix> mat2_eigen =
279-
Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
280-
const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
281-
EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
193+
const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat1);
194+
if (mat2.layout() == kSparseCsr) {
195+
// Out_csr = M1_csc * M2_csr
196+
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
197+
const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
198+
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
282199
} else {
283-
const Eigen::Map<EigenCsrMatrix> mat2_eigen =
284-
Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
285-
const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
286-
EigenCsr_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
200+
// Out_csr = M1_csc * M2_csc
201+
// This multiplication will be computationally inefficient, as it will require
202+
// additional conversion of the output matrix from CSC to CSR format.
203+
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
204+
const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
205+
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
287206
}
288207
}
289208
} else {
290209
if (mat1.layout() == kSparseCsr) {
291-
const Eigen::Map<EigenCsrMatrix> mat1_eigen =
292-
Tensor_to_EigenCsr<scalar_t, index_t>(mat1);
210+
const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat1);
293211
if (mat2.layout() == kSparseCsr) {
294-
const Eigen::Map<EigenCsrMatrix> mat2_eigen =
295-
Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
296-
const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
297-
EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
212+
// Out_csc = M1_csr * M2_csr
213+
// This multiplication will be computationally inefficient, as it will require
214+
// additional conversion of the output matrix from CSR to CSC format.
215+
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
216+
const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
217+
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
298218
} else {
299-
const Eigen::Map<EigenCscMatrix> mat2_eigen =
300-
Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
301-
const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
302-
EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
219+
// Out_csc = M1_csr * M2_csc
220+
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
221+
const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
222+
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
303223
}
304224
} else {
305-
const Eigen::Map<EigenCscMatrix> mat1_eigen =
306-
Tensor_to_EigenCsc<scalar_t, index_t>(mat1);
307-
if (mat2.layout() == kSparseCsc) {
308-
const Eigen::Map<EigenCscMatrix> mat2_eigen =
309-
Tensor_to_EigenCsc<scalar_t, index_t>(mat2);
310-
const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
311-
EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
225+
const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat1);
226+
if (mat2.layout() == kSparseCsr) {
227+
// Out_csc = M1_csc * M2_csr
228+
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
229+
const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
230+
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
312231
} else {
313-
const Eigen::Map<EigenCsrMatrix> mat2_eigen =
314-
Tensor_to_EigenCsr<scalar_t, index_t>(mat2);
315-
const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
316-
EigenCsc_to_Tensor<scalar_t, index_t>(mat1_mat2, mat1_mat2_eigen);
232+
// Out_csc = M1_csc * M2_csc
233+
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
234+
const auto mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
235+
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
317236
}
318237
}
319238
}

0 commit comments

Comments
 (0)