-
-
Notifications
You must be signed in to change notification settings - Fork 10.8k
ENH, MAINT: Refactor PyArray_InnerProduct
to use PyArray_MatrixProduct2
#6968
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Yeah, that is what I mean. I will let someone else decide further, but I think the transpose may need to be a bit more complicated. To figure that out/test it, can you add a test for multiplying larger then 2D arrays? |
770b9bc
to
9e8a47f
Compare
My understanding is that this isn't intended for public consumption as it only works on 2D arrays or smaller just as This only gets called in one place ( numpy/numpy/core/src/multiarray/multiarraymodule.c Lines 848 to 854 in eeba2cb
cblas_innerproduct .
|
Ah, right. Frankly, I would much prefer if we can do it for PyArray_InnerProduct making it just call PyArray_MatrixProduct2. Then remove the whole cblas_innerproduct function. Though I am not sure it makes sense, my guess would be it does. |
It looks like there are some Other things this probably needs are some benchmarks as I am claiming one can now get a speedup with |
It is probably removable, but what I liked about your original proposal (at least how I read it) is everything stays pretty much the same in terms of how everything builds. Also, how many, which, and where functions are remains the same. As soon as we start messing with that, we open ourselves up to spending time hunting down potentially weird build errors. At present, this just works. |
Well, look at MatrixProduct2 ;), it is a monster of annoying stuff and all that is needed to call it from inner is transposing the last two axes of op2 (if op2.ndim >= 2). The added complexity is only that the transpose is more complicated because you cannot pass in NULL. Other then that, you just remove more complexity. Note that the function you changed can be removed completely, it is not public. |
It would be nice if there was a |
Well, I don't know, you probably have to go way back to figure that one out ;). That is why I would like it refactored away. You save a single tranpose (array creation) for all I figure. |
Sorry, I realized I was asking a question that was basically becoming "why do we have |
BTW, did you see this ( #5859 )? |
cblas_innerproduct
to use cblas_matrixproduct
cblas_innerproduct
to use cblas_matrixproduct
6a4903e
to
d2ea634
Compare
So, if I make this simple change ( jakirkham@2288e34 ), I get a segmentation fault in the test suite. Is there something that I am doing wrong here? I'm not very familiar with the C API so I wouldn't be surprised if I am. I just need a few pointers. |
d2ea634
to
81d4275
Compare
Can you do me the favor and try to refactor all of PyArray_Inner? At least to me doing the transpose specifically in this subfunction seems half baked. Doing the full transpose should not be too difficult, see for example https://github.com/numpy/numpy/blob/master/numpy/core/src/multiarray/mapping.c#L135 for a "complex" Transpose operation, but it gives you the idea that you can just iterate and set all ndim (if ndim >= 2), and then switch the last two ndim around. |
So, I actually did that too. Though I have not pushed it to the PR yet. I am struggling with a segmentation fault there, as well. Here is the commit ( jakirkham@8f5464c ). Any pointers on why the segmentation fault occurs in either case would be helpful. |
cblas_innerproduct
to use cblas_matrixproduct
cblas_innerproduct
to use cblas_matrixproduct
a5220ef
to
c616a02
Compare
Tried to add some tests for the exception, but they fail |
c616a02
to
0c708d8
Compare
So, I have removed the changed exception as it appears to be somewhat controversial and it is in a different PR ( #6987 ) with exception tests. I have also removed the exceptions tests as they seem to have issues that at the moment I cannot seem to figure out and placed them in another PR ( #6988 ). If we can get these working, I am willing to combine them back into this PR. However, I don't want this PR's fate to be determined by trying to come up with acceptable solutions to these more minor issues. |
Yes, they went in the PR ( #6986 ), which is already merged. This PR has been rebased on master after those commits were added so includes them. |
df15b9a
to
920b70c
Compare
Failures on AppVeyor ( #6991 ) had not previously occurred for this PR with exception of the segmentation fault issue (also happened on Travis), which has since been resolved. I believe these to be unrelated to the content of this PR. |
920b70c
to
2ba0898
Compare
As AppVeyor is merging with master, which is broken, it is currently failing the tests there. So, I ran my own AppVeyor build without this merge (this is rebased on a commit on |
…to a common type.
…into a common type.
…Object_Repr`. Also, do a better job of handling any errors raised while constructing the error message.
…nspose and calls `PyArray_MatrixProduct2`.
2ba0898
to
223513a
Compare
ENH, MAINT: Refactor `PyArray_InnerProduct` to use `PyArray_MatrixProduct2`
Thanks @jakirkham . The |
Thanks! |
Thanks everyone. Alright, @charris, I'll try to look at this at some point soon. |
The benchmark shows the |
Very nice! |
Fixes #6948
Related: #6932
Related: #6977
Related: #6986
Related: #6987
Related: #6988
This follows @seberg's suggestion ( #6948 (comment) ) and does the simplest thing. Namely, refactors
PyArray_InnerProduct
to usePyArray_MatrixProduct2
. As a consequence,np.inner
will see a speedup in cases where the problem contains portions like thesea @ a.T
anda.T @ a
as this already is optimized incblas_matrixproduct
because of this PR ( #6932 ).Todo: - [x] Refactor so that `PyArray_InnerProduct` just calls `PyArray_MatrixProduct2`. - [x] Add more tests of different dimensions for `np.inner` to make sure it still behaves correctly. - [x] Add benchmarks so that cases where `np.inner` can now see a speedup are shown vs those it can't as was done with `np.dot`. - [x] Add test for type mismatch exception.