13
13
from . import _sparsetools
14
14
from ._sparsetools import (get_csr_submatrix , csr_sample_offsets , csr_todense ,
15
15
csr_sample_values , csr_row_index , csr_row_slice ,
16
- csr_column_index1 , csr_column_index2 )
16
+ csr_column_index1 , csr_column_index2 , csr_diagonal ,
17
+ expandptr , csr_has_canonical_format , csr_eliminate_zeros ,
18
+ csr_sum_duplicates , csr_has_sorted_indices , csr_sort_indices ,
19
+ csr_matmat_maxnnz , csr_matmat )
17
20
from ._index import IndexMixin
18
21
from ._sputils import (upcast , upcast_char , to_native , isdense , isshape ,
19
22
getdtype , isscalarlike , isintlike , downcast_intp_index ,
@@ -560,39 +563,40 @@ def _matmul_sparse(self, other):
560
563
new_shape += (N ,)
561
564
faux_shape = (M if self .ndim == 2 else 1 , N if o_ndim == 2 else 1 )
562
565
563
- major_dim = self ._swap ((M , N ))[0 ]
564
566
other = self .__class__ (other ) # convert to this format
567
+ index_arrays = (self .indptr , self .indices , other .indptr , other .indices )
565
568
566
- idx_dtype = self ._get_index_dtype ((self .indptr , self .indices ,
567
- other .indptr , other .indices ))
568
-
569
- fn = getattr (_sparsetools , self .format + '_matmat_maxnnz' )
570
- nnz = fn (M , N ,
571
- np .asarray (self .indptr , dtype = idx_dtype ),
572
- np .asarray (self .indices , dtype = idx_dtype ),
573
- np .asarray (other .indptr , dtype = idx_dtype ),
574
- np .asarray (other .indices , dtype = idx_dtype ))
569
+ M , N = self ._swap ((M , N ))
570
+ s , o = self ._swap ((self , other ))
571
+
572
+ idx_dtype = self ._get_index_dtype (index_arrays )
573
+ s_indptr = np .asarray (s .indptr , dtype = idx_dtype )
574
+ s_indices = np .asarray (s .indices , dtype = idx_dtype )
575
+ o_indptr = np .asarray (o .indptr , dtype = idx_dtype )
576
+ o_indices = np .asarray (o .indices , dtype = idx_dtype )
577
+
578
+ nnz = csr_matmat_maxnnz (M , N , s_indptr , s_indices , o_indptr , o_indices )
575
579
if nnz == 0 :
576
580
if new_shape == ():
577
581
return np .array (0 , dtype = upcast (self .dtype , other .dtype ))
578
582
return self .__class__ (new_shape , dtype = upcast (self .dtype , other .dtype ))
579
583
580
- idx_dtype = self ._get_index_dtype ((self .indptr , self .indices ,
581
- other .indptr , other .indices ),
582
- maxval = nnz )
584
+ new_idx_dtype = self ._get_index_dtype (index_arrays , maxval = nnz )
585
+ if new_idx_dtype != idx_dtype :
586
+ idx_dtype = new_idx_dtype
587
+ s_indptr = np .asarray (s .indptr , dtype = idx_dtype )
588
+ s_indices = np .asarray (s .indices , dtype = idx_dtype )
589
+ o_indptr = np .asarray (o .indptr , dtype = idx_dtype )
590
+ o_indices = np .asarray (o .indices , dtype = idx_dtype )
583
591
584
- indptr = np .empty (major_dim + 1 , dtype = idx_dtype )
592
+ indptr = np .empty (M + 1 , dtype = idx_dtype )
585
593
indices = np .empty (nnz , dtype = idx_dtype )
586
594
data = np .empty (nnz , dtype = upcast (self .dtype , other .dtype ))
587
595
588
- fn = getattr (_sparsetools , self .format + '_matmat' )
589
- fn (M , N , np .asarray (self .indptr , dtype = idx_dtype ),
590
- np .asarray (self .indices , dtype = idx_dtype ),
591
- self .data ,
592
- np .asarray (other .indptr , dtype = idx_dtype ),
593
- np .asarray (other .indices , dtype = idx_dtype ),
594
- other .data ,
595
- indptr , indices , data )
596
+ csr_matmat (M , N ,
597
+ s_indptr , s_indices , s .data ,
598
+ o_indptr , o_indices , o .data ,
599
+ indptr , indices , data )
596
600
597
601
if new_shape == ():
598
602
return np .array (data [0 ])
@@ -604,14 +608,13 @@ def _matmul_sparse(self, other):
604
608
return res
605
609
606
610
def diagonal (self , k = 0 ):
607
- rows , cols = self .shape
608
- if k <= - rows or k >= cols :
611
+ M , N = self ._swap (self .shape )
612
+ k , _ = self ._swap ((k , - k ))
613
+
614
+ if k <= - M or k >= N :
609
615
return np .empty (0 , dtype = self .data .dtype )
610
- fn = getattr (_sparsetools , self .format + "_diagonal" )
611
- y = np .empty (min (rows + min (k , 0 ), cols - max (k , 0 )),
612
- dtype = upcast (self .dtype ))
613
- fn (k , self .shape [0 ], self .shape [1 ], self .indptr , self .indices ,
614
- self .data , y )
616
+ y = np .empty (min (M + min (k , 0 ), N - max (k , 0 )), dtype = upcast (self .dtype ))
617
+ csr_diagonal (k , M , N , self .indptr , self .indices , self .data , y )
615
618
return y
616
619
617
620
diagonal .__doc__ = _spbase .diagonal .__doc__
@@ -1151,7 +1154,7 @@ def tocoo(self, copy=True):
1151
1154
major_dim , minor_dim = self ._swap (self .shape )
1152
1155
minor_indices = self .indices
1153
1156
major_indices = np .empty (len (minor_indices ), dtype = self .indices .dtype )
1154
- _sparsetools . expandptr (major_dim , self .indptr , major_indices )
1157
+ expandptr (major_dim , self .indptr , major_indices )
1155
1158
coords = self ._swap ((major_indices , minor_indices ))
1156
1159
1157
1160
return self ._coo_container (
@@ -1189,7 +1192,7 @@ def eliminate_zeros(self):
1189
1192
This is an *in place* operation.
1190
1193
"""
1191
1194
M , N = self ._swap (self ._shape_as_2d )
1192
- _sparsetools . csr_eliminate_zeros (M , N , self .indptr , self .indices , self .data )
1195
+ csr_eliminate_zeros (M , N , self .indptr , self .indices , self .data )
1193
1196
self .prune () # nnz may have changed
1194
1197
1195
1198
@property
@@ -1209,10 +1212,10 @@ def has_canonical_format(self) -> bool:
1209
1212
# not sorted => not canonical
1210
1213
self ._has_canonical_format = False
1211
1214
elif not hasattr (self , '_has_canonical_format' ):
1215
+ M = len (self .indptr ) - 1
1212
1216
self .has_canonical_format = bool (
1213
- _sparsetools .csr_has_canonical_format (
1214
- len (self .indptr ) - 1 , self .indptr , self .indices )
1215
- )
1217
+ csr_has_canonical_format (M , self .indptr , self .indices )
1218
+ )
1216
1219
return self ._has_canonical_format
1217
1220
1218
1221
@has_canonical_format .setter
@@ -1231,7 +1234,7 @@ def sum_duplicates(self):
1231
1234
self .sort_indices ()
1232
1235
1233
1236
M , N = self ._swap (self ._shape_as_2d )
1234
- _sparsetools . csr_sum_duplicates (M , N , self .indptr , self .indices , self .data )
1237
+ csr_sum_duplicates (M , N , self .indptr , self .indices , self .data )
1235
1238
1236
1239
self .prune () # nnz may have changed
1237
1240
self .has_canonical_format = True
@@ -1246,10 +1249,10 @@ def has_sorted_indices(self) -> bool:
1246
1249
"""
1247
1250
# first check to see if result was cached
1248
1251
if not hasattr (self , '_has_sorted_indices' ):
1252
+ M = len (self .indptr ) - 1
1249
1253
self ._has_sorted_indices = bool (
1250
- _sparsetools .csr_has_sorted_indices (
1251
- len (self .indptr ) - 1 , self .indptr , self .indices )
1252
- )
1254
+ csr_has_sorted_indices (M , self .indptr , self .indices )
1255
+ )
1253
1256
return self ._has_sorted_indices
1254
1257
1255
1258
@has_sorted_indices .setter
@@ -1271,10 +1274,9 @@ def sorted_indices(self):
1271
1274
def sort_indices (self ):
1272
1275
"""Sort the indices of this array/matrix *in place*
1273
1276
"""
1274
-
1275
1277
if not self .has_sorted_indices :
1276
- _sparsetools . csr_sort_indices ( len (self .indptr ) - 1 , self . indptr ,
1277
- self .indices , self .data )
1278
+ M = len (self .indptr ) - 1
1279
+ csr_sort_indices ( M , self . indptr , self .indices , self .data )
1278
1280
self .has_sorted_indices = True
1279
1281
1280
1282
def prune (self ):
@@ -1353,7 +1355,7 @@ def _binopt(self, other, op):
1353
1355
other = self .__class__ (other )
1354
1356
1355
1357
# e.g. csr_plus_csr, csr_minus_csr, etc.
1356
- fn = getattr (_sparsetools , self . format + op + self . format )
1358
+ fn = getattr (_sparsetools , "csr" + op + "csr" )
1357
1359
1358
1360
maxnnz = self .nnz + other .nnz
1359
1361
idx_dtype = self ._get_index_dtype ((self .indptr , self .indices ,
@@ -1368,7 +1370,7 @@ def _binopt(self, other, op):
1368
1370
else :
1369
1371
data = np .empty (maxnnz , dtype = upcast (self .dtype , other .dtype ))
1370
1372
1371
- M , N = self ._shape_as_2d
1373
+ M , N = self ._swap ( self . _shape_as_2d )
1372
1374
fn (M , N ,
1373
1375
np .asarray (self .indptr , dtype = idx_dtype ),
1374
1376
np .asarray (self .indices , dtype = idx_dtype ),
0 commit comments