[MRG+1] Make csr row norms support fused types #6785
@@ -23,24 +23,29 @@ ctypedef np.float64_t DOUBLE

 def csr_row_norms(X):
     """L2 norm of each row in CSR matrix X."""
+    if X.dtype != np.float32:
+        X = X.astype(np.float64)
+    return _csr_row_norms(X.data, X.shape, X.indices, X.indptr)
+
+
+def _csr_row_norms(np.ndarray[floating, ndim=1, mode="c"] X_data,
+                   shape,
+                   np.ndarray[int, ndim=1, mode="c"] X_indices,
+                   np.ndarray[int, ndim=1, mode="c"] X_indptr):
     cdef:
-        unsigned int n_samples = X.shape[0]
-        unsigned int n_features = X.shape[1]
+        unsigned int n_samples = shape[0]
+        unsigned int n_features = shape[1]
         np.ndarray[DOUBLE, ndim=1, mode="c"] norms
-        np.ndarray[DOUBLE, ndim=1, mode="c"] data
-        np.ndarray[int, ndim=1, mode="c"] indices = X.indices
-        np.ndarray[int, ndim=1, mode="c"] indptr = X.indptr

         np.npy_intp i, j
         double sum_

     norms = np.zeros(n_samples, dtype=np.float64)
-    data = np.asarray(X.data, dtype=np.float64)  # might copy!

     for i in range(n_samples):
         sum_ = 0.0
-        for j in range(indptr[i], indptr[i + 1]):
-            sum_ += data[j] * data[j]
+        for j in range(X_indptr[i], X_indptr[i + 1]):
+            sum_ += X_data[j] * X_data[j]
Review comments on the line above:

If you cast each entry to

About this, I've timed the
Original: 9.901501894
You can find my test script here.

Probably won't be.

My suggestion is that you can use float64
         norms[i] = sum_

     return norms
Review comments:

You might want to specify the type of `norms` here, i.e. `np.ndarray`. There is some non-negligible overhead by not doing that. Btw, you can check that by compiling using the `cython -a` flag and checking the html.

it is just above, isn't it?

hmm yeah.
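
For anyone following the thread without a Cython build at hand, here is a minimal pure-NumPy sketch of what the loop in the diff computes: for each row, the sum of the squared stored entries (no square root is taken), written into a float64 output array. The helper name `csr_row_sq_norms_reference` and the toy arrays are hypothetical and not part of this PR. The sketch casts each row to float64 before multiplying, which is an assumption made for simplicity and may not match exactly what the compiled loop does for float32 input.

```python
import numpy as np


def csr_row_sq_norms_reference(data, indptr):
    """Sum of squared stored entries for each CSR row (hypothetical helper)."""
    n_samples = len(indptr) - 1
    # The output is float64 regardless of the input dtype, like `norms` in the diff.
    norms = np.zeros(n_samples, dtype=np.float64)
    for i in range(n_samples):
        # Stored values of row i live in data[indptr[i]:indptr[i + 1]].
        row = data[indptr[i]:indptr[i + 1]].astype(np.float64)
        norms[i] = np.dot(row, row)
    return norms


# CSR arrays for the dense matrix [[1, 0, 2], [0, 0, 0], [3, 4, 0]].
data = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
indptr = np.array([0, 2, 2, 4], dtype=np.int32)
result = csr_row_sq_norms_reference(data, indptr)
```

Something like this could serve as a slow reference when checking that the float32 and float64 code paths agree on small inputs. It is not meant as a replacement for the compiled function.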