MAINT: Remove similar branches from linalg.lstsq #9986

Merged: 5 commits merged on Nov 9, 2017

Conversation

eric-wieser
Member

@eric-wieser eric-wieser commented Nov 8, 2017

Working towards being able to fix #8720.

This doesn't change any behavior, but it does add comments pointing out the existing, somewhat questionable behavior.


    # as documented
    if rank != n or m <= n:
        resids = array([], result_real_t)
Member Author

@eric-wieser eric-wieser Nov 8, 2017

This is a bizarre interface, and resids already contains 0 in the m <= n case, which is a more meaningful way to say "no residual" than []. But we're stuck with it, because that's how it's documented.
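As an illustration of the documented behavior described above, here is a minimal sketch using the public `np.linalg.lstsq` API. The matrices are made up for the example, and `rcond=None` assumes a NumPy release that accepts it (1.14 or later).

```python
import numpy as np

# Overdetermined, full-rank system (M=3 rows, N=2 columns).
a = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
b = np.array([1.0, 2.0, 4.0])

x, resids, rank, s = np.linalg.lstsq(a, b, rcond=None)
print(resids.shape)  # one residual for the single right-hand side

# Square system (M == N): residuals comes back empty rather than as [0.].
_, resids_square, _, _ = np.linalg.lstsq(a[:2], b[:2], rcond=None)
print(resids_square.shape)

# Rank-deficient system (rank 1 < N): residuals is also empty.
a_deficient = np.ones((3, 2))
_, resids_deficient, rank_deficient, _ = np.linalg.lstsq(
    a_deficient, b, rcond=None)
print(resids_deficient.shape, rank_deficient)
```

Callers that want a numeric residual in every case have to check for the empty array themselves, for example with `resids.size == 0`.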

                           dtype=result_real_t)
        else:
            resids = array([sum((ravel(bstar)[n:])**2)],
                           dtype=result_real_t)
Member Author

In what makes no sense at all, this branch produces the same effect as the one that follows it.

@eric-wieser eric-wieser force-pushed the simplify-lstsq branch 2 times, most recently from b563394 to d813f1a, November 8, 2017 06:47

st = s[:min(n, m)].astype(result_real_t, copy=True)
Member Author

This slice was pointless, because len(s) == min(n,m)
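A quick way to check the claim that `len(s) == min(n, m)` from the outside is to look at the singular values returned by the public function. This is only a sketch with arbitrary random shapes; it does not touch the internal variables of `lstsq`.

```python
import numpy as np

rng = np.random.default_rng(0)
lengths = {}
for m, n in [(5, 3), (3, 5), (4, 4)]:
    a = rng.standard_normal((m, n))
    b = rng.standard_normal(m)
    _, _, _, s = np.linalg.lstsq(a, b, rcond=None)
    # One singular value per min(m, n), so s[:min(n, m)] is s itself.
    lengths[(m, n)] = len(s)
    assert np.array_equal(s[:min(n, m)], s)

print(lengths)
```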

Contributor

@mhvk mhvk left a comment

This all looks good. My comments are mostly nitpicks.

@@ -1915,7 +1915,7 @@ def lstsq(a, b, rcond="warn"):
     x : {(N,), (N, K)} ndarray
         Least-squares solution. If `b` is two-dimensional,
         the solutions are in the `K` columns of `x`.
-    residuals : {(), (1,), (K,)} ndarray
+    residuals : {(0,), (1,), (K,)} ndarray
Contributor

I'd just remove the (0,) - now it is just confusing and the note states what happens for wrong input.

Member Author

It may be confusing, but that's because the behavior is confusing too!

And it's not even for "wrong input" - just for cases when an exact match is possible. I could move the (0,) to the last item in the list.

Contributor

Yes, maybe having it as the last item is best, since it is not the most common output.

@@ -1997,8 +2001,6 @@ def lstsq(a, b, rcond="warn"):
    if rcond is None:
        rcond = finfo(t).eps * ldb

    result_real_t = _realType(result_t)
    real_t = _linalgRealType(t)
    bstar = zeros((ldb, n_rhs), t)
    bstar[:b.shape[0], :n_rhs] = b.copy()
Contributor

Not this PR, but when I looked at this before, I wondered what would be the point of .copy(); it is not like a view gets taken and this cannot be of much speed benefit for the whole routine.

Member Author

Yeah, the copy stuff here is weird.
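To make the point about `.copy()` concrete, the sketch below shows that slice assignment already copies values into the destination buffer. The small arrays stand in for `b` and `bstar` and are illustrative only.

```python
import numpy as np

b = np.array([[1.0], [2.0]])
bstar = np.zeros((3, 1))

# Slice assignment copies element values into bstar's own buffer,
# so an explicit b.copy() on the right-hand side only adds a temporary.
bstar[:b.shape[0], :1] = b

b[0, 0] = 99.0  # mutate the source afterwards
print(bstar[0, 0])
print(np.shares_memory(bstar, b))
```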

    x = b_out[:n,:]
    r_parts = b_out[n:,:]
    if isComplexType(t):
        resids = sum(abs(r_parts)**2, axis=-2)
Contributor

I wish we had a sensible function for the squared magnitude, but to avoid a needless square root, one can do

    sum(r_parts.real**2 + r_parts.imag**2, axis=-2)

Member Author

@eric-wieser eric-wieser Nov 8, 2017

r_parts * r_parts.conj() is probably a little faster, and also removes the branching. I'd rather leave this untouched though for now, since that would probably change results by a ULP.

Contributor

It's not (at least on my machine), but fine to let this be.

Member Author

@eric-wieser eric-wieser Nov 9, 2017

It's not faster, or it's not guilty of introducing the ULP error?

Seems to me that there must be some value for which abs(x)**2 != x * x.conj(). Of course, the x * x.conj() value is closer to the true result.

Contributor

I just meant that x*x.conj() is slower than x.real**2 + x.imag**2 (which makes sense, as the former does a few useless multiplications that cancel). I do agree that there must be values for which abs(x)**2 is slightly less accurate, given that it takes a square root and then squares again after calculating x.real**2 + x.imag**2.

Anyway, fine to not worry about it here!
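The three formulations discussed in this thread can be compared with plain Python complex scalars. This is a rough sketch, not a benchmark and not the NumPy array code path: the helper name `abs_squared_variants` is hypothetical, and the number of differing samples depends on the platform's floating-point behavior.

```python
import random


def abs_squared_variants(z):
    """Return three candidate squared magnitudes of a complex number."""
    via_abs = abs(z) ** 2                    # square root, then square
    via_parts = z.real ** 2 + z.imag ** 2    # no square root
    via_conj = (z * z.conjugate()).real      # complex multiply, keep real part
    return via_abs, via_parts, via_conj


first = abs_squared_variants(3 + 4j)
print(first)

random.seed(0)
samples = [complex(random.uniform(-1, 1), random.uniform(-1, 1))
           for _ in range(10_000)]
results = [abs_squared_variants(z) for z in samples]
# Count inputs where the sqrt-then-square route rounds differently.
mismatches = sum(1 for r in results if r[0] != r[1])
print(mismatches, "of", len(samples), "samples differ")
```

Any differences found this way are at the level of the last bit or two, which is the ULP-sized change the thread is reluctant to introduce in a maintenance PR.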

Member

I've contemplated adding a ufunc for the squared absolute value, the main problem seems to be the name.

        resids = array([], result_real_t)

    # coerce output arrays
    s = s.astype(result_real_t, copy=True)
Contributor

Why the copy=True here? s is created in this routine, so there seems to be no need to copy.

Member Author

I agree - kept only because it was there before.

    # coerce output arrays
    s = s.astype(result_real_t, copy=True)
    resids = resids.astype(result_real_t, copy=False)  # array is temporary
    x = x.astype(result_t, copy=True)
Contributor

x is a view, so I guess it makes sense to copy. Maybe note that? (Also, copy=True is the default.)

Member Author

The copy=Trues confuse me, since as you note, they're the default. In fact, before #9888 there was a reasonable amount of code devoted to passing that argument.

Maybe this is trying to deal with a subclass that has a different default for copy?

Comment seems reasonable here

Contributor

Yes, that's fine. If one were to design this from scratch, one would do the coercion only if an output array was given...

Member Author

Looking back, the copy=False arguments were introduced in #5909, and the =True is deliberate and for clarity.
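The difference between the two `copy` settings of `ndarray.astype` can be seen in a small sketch. The arrays here are stand-ins, with `x` built as a view in the same way as `b_out[:n, :]`; they are not the actual internals of `lstsq`.

```python
import numpy as np

base = np.zeros((4, 3))
x = base[:2, :]  # a view into base, analogous to x = b_out[:n, :]

no_copy = x.astype(np.float64, copy=False)  # dtype already matches
copied = x.astype(np.float64)               # copy=True is the default

print(np.shares_memory(no_copy, base))
print(np.shares_memory(copied, base))
```

With `copy=False` and a matching dtype no new buffer is allocated, so the result still aliases the padded work array. That is why a copy makes sense for `x`, while `resids`, being a temporary, can skip it.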

Member Author

> If one were to design this from scratch, one would do the coercion only if an output array was given

Or maybe just work with the dtype passed in, rather than always promoting to double before handing off to the ufunc.

Contributor

That works only if one also uses different LAPACK routines (which is fine, of course), and would be less precise. But seems more logical in any case; just a different rcond.

@eric-wieser
Member Author

Do you want me to go through and remove the needless copies in another commit, or leave that for another PR?

@mhvk
Contributor

mhvk commented Nov 8, 2017

@eric-wieser - it is a bit up to you whether you want to bother. If you think you get to it anyway with the gufunc implementation, I'm happy also to just merge this.

@eric-wieser
Member Author

Nits addressed.

@mhvk
Contributor

mhvk commented Nov 9, 2017

Looks all OK. Maybe squash the commits?

    b_out = bstar.T

    # b_out contains both the solution and the components of the residuals
    x = b_out[:n,:]
Member

PEP8, no alignment like this.

Member

Actually, PEP8 shows a bunch of other whitespace violations in linalg.py, so we could probably use a style PR to clean those up at some point.

@charris
Member

charris commented Nov 9, 2017

LGTM apart from PEP8 nit.

@charris
Member

charris commented Nov 9, 2017

I'll fix that nit here and put this in. Thanks Eric.

@charris charris merged commit d185ece into numpy:master Nov 9, 2017

Successfully merging this pull request may close these issues.

ENH: broadcast lstsq