-
-
Notifications
You must be signed in to change notification settings - Fork 10.8k
MAINT: Remove similar branches from linalg.lstsq #9986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
||
# as documented | ||
if rank != n or m <= n: | ||
resids = array([], result_real_t) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a bizarre interface, and resids
already contains 0
in the m <= n
case, which is a more meaningful way to say "no residual" than []
. But we're stuck with it, because that's how it's documented.
dtype=result_real_t) | ||
else: | ||
resids = array([sum((ravel(bstar)[n:])**2)], | ||
dtype=result_real_t) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In what makes no sense at all, this branch produces the same effect as the one that follows it.
b563394
to
d813f1a
Compare
d813f1a
to
a311a8d
Compare
|
||
st = s[:min(n, m)].astype(result_real_t, copy=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This slice was pointless, because len(s) == min(n,m)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This all looks good. My comments are mostly nitpicks.
numpy/linalg/linalg.py
Outdated
@@ -1915,7 +1915,7 @@ def lstsq(a, b, rcond="warn"): | |||
x : {(N,), (N, K)} ndarray | |||
Least-squares solution. If `b` is two-dimensional, | |||
the solutions are in the `K` columns of `x`. | |||
residuals : {(), (1,), (K,)} ndarray | |||
residuals : {(0,), (1,), (K,)} ndarray |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd just remove the (0,)
- now it is just confusing and the note states what happens for wrong input.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It may be confusing, but that's because the behavior is confusing too!
And it's not even for "wrong input" - just for cases when an exact match is possible. I could move the (0,)
to the last item in the list.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, maybe having it as the last item is best, since it is not the most common output.
numpy/linalg/linalg.py
Outdated
@@ -1997,8 +2001,6 @@ def lstsq(a, b, rcond="warn"): | |||
if rcond is None: | |||
rcond = finfo(t).eps * ldb | |||
|
|||
result_real_t = _realType(result_t) | |||
real_t = _linalgRealType(t) | |||
bstar = zeros((ldb, n_rhs), t) | |||
bstar[:b.shape[0], :n_rhs] = b.copy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not this PR, but when I looked at this before, I wondered what would be the point of .copy()
; it is not like a view gets taken and this cannot be of much speed benefit for the whole routine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, the copy
stuff here is weird.
numpy/linalg/linalg.py
Outdated
x = b_out[:n,:] | ||
r_parts = b_out[n:,:] | ||
if isComplexType(t): | ||
resids = sum(abs(r_parts)**2, axis=-2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wish we had a sensible power
or so, but to avoid a needless square root, one can do
sum(r_parts.real**2 + r_parts.imag**2, axis=-2)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
r_parts * r_parts.conj()
is probably a little faster, and also removes the branching. I'd rather leave this untouched though for now, since that would probably change results by a ULP.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not (at least on my machine), but fine to let this be.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not faster, or it's not guilty of introducing the ULP error?
Seems to me that there must be some value for which abs(x)**2 != x * x.conj()
. Of course, the x * x.conj()
value is closer to the true result.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just meant that x*x.conj()
is slower than x.real**2 + x.imag**2
(which makes sense, as the former does a few useless multiplications that cancel). I do agree that there must be values of abs(x)**2
that are slightly less correct, given the sqrt and square after calculating x.real**2+x.imag**2
Anyway, fine to not worry about it here!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've contemplated adding a ufunc for the squared absolute value, the main problem seems to be the name.
numpy/linalg/linalg.py
Outdated
resids = array([], result_real_t) | ||
|
||
# coerce output arrays | ||
s = s.astype(result_real_t, copy=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why the copy=True
here; s
is created in this routine, so no need to copy, it would seem.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree - kept only because it was there before.
numpy/linalg/linalg.py
Outdated
# coerce output arrays | ||
s = s.astype(result_real_t, copy=True) | ||
resids = resids.astype(result_real_t, copy=False) # array is temporary | ||
x = x.astype(result_t, copy=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x
is a view, so I guess it makes sense to copy. Maybe note that? (Also, copy=True
is the default.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The copy=True
s confuse me, since as you note, they're the default. In fact, before #9888 there was a reasonable amount of code devoted to passing that argument.
Maybe this is trying to deal with a subclass that has a different default for copy
?
Comment seems reasonable here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's fine. If one were to design this from scratch, one would do the coercion only if an output array was given...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking back, the copy=False
arguments were introduced in #5909, and the =True
is deliberate and for clarity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If one were to design this from scratch, one would do the coercion only if an output array was given
Or maybe just work with the dtype passed in, rather than always promoting to double before handing off to the ufunc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That works only if one also uses different LAPACK routines (which is fine, of course), and would be less precise. But seems more logical in any case; just a different rcond
.
Do you want me to go through and remove the needless copies in another commit, or leave that for another PR? |
@eric-wieser - it is a bit up to you whether you want to bother. If you think you get to it anyway with the |
86489a4
to
69d5d6c
Compare
Nits addressed. |
This takes numpygh-5909 a little further.
69d5d6c
to
e3a50a9
Compare
Looks all OK. Maybe squash the commits? |
numpy/linalg/linalg.py
Outdated
b_out = bstar.T | ||
|
||
# b_out contains both the solution and the components of the residuals | ||
x = b_out[:n,:] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PEP8, no alignment like this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, PEP8 shows a bunch of other whitespace violations in linalg.py
, so we could probably use a style PR to clean those up at some point.
LGTM apart from PEP8 nit. |
I'll fix that nit here and put this in. Thanks Eric. |
Working towards being able to fix #8720.
This doesn't change any behavior, but does add comments pointing out the existing somewhat-questionable behaviour