[MRG+1] Make csr row norms support fused types #6785
Conversation
Tests are failing; you have to reduce the precision for the sqrt also.
```cython
np.ndarray[int, ndim=1, mode="c"] indptr = X.indptr
unsigned int n_samples = shape[0]
unsigned int n_features = shape[1]
np.ndarray[double, ndim=1, mode="c"] norms
```
Why is this no longer `DOUBLE`?
You mean your comment or the dtype?
Bah.
Sorry, it's my mistake.

BTW, since the type `double` comes from `import numpy as np` and the type `np.float64_t` comes from `cimport numpy as np`, is there a big performance difference between them?
I don't see how `double` comes from `import numpy as np`. `double` is a C type whose size is officially unspecified, but to be honest I'm not sure if that's the only reason the more precise numpy type is preferred over the C type.
Force-pushed from 8930eb6 to ad211a1.
@jnothman Sorry, my earlier execution-time test was not correct. After debugging and running it again, the result shows that running time increases from 9.9s to 13s if we explicitly cast every entry as we multiply it. It seems that it does indeed cause a big runtime hit ...

However, the test passes if I change its precision from 1e-6 to 1e-4, as @TomDLT suggested.
Oh, I thought I'd replied to this. But perhaps for the same reason that I'm still not sure what to say, I didn't. I did briefly try to look for a BLAS or similar function (I'm not very familiar with what's available) that might support a fast sum-of-squares, potentially with a result in higher precision than the input. I agree that 30% seems a substantial runtime hit. I suppose we can accept the loss in precision. :/
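The precision loss under discussion is easy to reproduce with plain NumPy (a minimal illustration added here, not the PR's benchmark): accumulating a sum of squares entirely in float32 disagrees with the same computation carried out in float64, which is the reason the test tolerance has to be loosened.

```python
import numpy as np

rng = np.random.RandomState(0)
X32 = rng.rand(5, 10_000).astype(np.float32)

# Sum of squares accumulated in float32 (what a fused-type kernel does
# for float32 input) vs. the same computation carried out in float64.
norms32 = np.einsum("ij,ij->i", X32, X32)
norms64 = np.einsum("ij,ij->i", X32.astype(np.float64), X32.astype(np.float64))

rel_err = np.max(np.abs(norms32 - norms64) / norms64)
print(rel_err)  # nonzero, and far above float64 eps (~2.2e-16)
```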
```diff
-    assert_array_almost_equal(sq_norm, row_norms(X, squared=True), 5)
     assert_array_almost_equal(np.sqrt(sq_norm), row_norms(X))
+    assert_array_almost_equal(sq_norm, row_norms(X, squared=True), 4)
```
It might be a good idea to test the `float64` and `float32` dtypes separately, with a higher precision for the float64 dtype. WDYT?
Just a minor comment; +1 otherwise.
Force-pushed from ad211a1 to c7d6f9f.
@jnothman merge?
Let's do it.
Since `csr_row_norms` is called by the `row_norms` function defined in `sklearn/utils/extmath.py`, and `row_norms` is used widely in `k_means_.py`, it will be useful if the `csr_row_norms` function also supports Cython fused types.

However, making this change degrades the precision of the function. In order to pass the local tests, I had to relax the strictness by changing `assert_array_almost_equal`'s decimal argument from 5 to 4.

May @MechCoder and @jnothman give me some advice on this trade-off?
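For readers outside the diff, the behavioral contract being added can be sketched in pure NumPy. This is an illustration of what fused-type support means for callers, not the Cython implementation itself, and `csr_row_norms_like` is a made-up name:

```python
import numpy as np

def csr_row_norms_like(data, indptr, squared=True):
    # Pure-NumPy sketch of the fused-type contract: accumulate each CSR
    # row's sum of squares in the dtype of `data`, so float32 input
    # yields float32 norms (with float32 precision) instead of being
    # silently upcast to float64.
    norms = np.empty(len(indptr) - 1, dtype=data.dtype)
    for i in range(len(indptr) - 1):
        row = data[indptr[i]:indptr[i + 1]]
        norms[i] = np.dot(row, row)
    return norms if squared else np.sqrt(norms)

# Two CSR rows: [3, 4] and [5, 12], stored as float32.
data = np.array([3.0, 4.0, 5.0, 12.0], dtype=np.float32)
indptr = np.array([0, 2, 4])
print(csr_row_norms_like(data, indptr))  # [ 25. 169.] as float32
```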