Skip to content

BUG: Fix norm type promotion #10667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

kmaehashi
Copy link
Member

I tested np.linalg.norm with the following code, but it seems the type is not preserved correctly.

import numpy as np

for dtype in [np.float16, np.float32, np.float64]:
    a = np.ones((2,), dtype=dtype)
    ret = np.linalg.norm(a, 3, None, False)
    if not (dtype == ret.dtype):
        print('expected', dtype, 'but got', ret.dtype)

Without this fix (NumPy 1.14.1):

expected <class 'numpy.float16'> but got float64
expected <class 'numpy.float32'> but got float64

With NumPy 1.14.1 + this fix, nothing should be printed.

Related to #10368 and cupy/cupy#875 (comment)

@eric-wieser
Copy link
Member

eric-wieser commented Feb 26, 2018

This doesn't affect the result dtype at all, AFAICT

**= operates in-place, so should not result in a type promotion (edit: except on scalars, where it does not)

@eric-wieser
Copy link
Member

eric-wieser commented Feb 26, 2018

Edit: I'm wrong.

Thanks to #10374, np.reciprocal is unusable in generic code. You can use np.true_divide(1, ord, dtype=...) here instead

I think that there's a deeper problem here with numpy type promotion, and how scalar op= val gives a different promotion to arr op= val

@kmaehashi
Copy link
Member Author

kmaehashi commented Feb 26, 2018

Ah, I see.

x = np.ndarray(1, dtype=np.float32)
x **= 1.0 / 2.0
x.dtype  # => float32

x = np.float32(1)
x **= 1.0 / 2.0
x.dtype  # => float64

Thanks to #10374, np.reciprocal is unusable in generic code.

I think the code in this PR is generic (i.e., works with both scalar and array).
Do you mean such use of np.reciprocal is discouraged (even with dtype=ret.dtype)?

import numpy as np
print('scalar', np.linalg.norm(np.ones((2,), dtype=np.float32), 3, 0, False).dtype)
print('array', np.linalg.norm(np.ones((2,2), dtype=np.float32), 3, 0, False).dtype)

Output from NumPy 1.14.1:

scalar float64
array float32

Output from NumPy 1.14.1 + this PR:

scalar float32
array float32

@charris charris changed the title fix norm type promotion BUG: Fix norm type promotion Feb 27, 2018
@eric-wieser
Copy link
Member

Fix looks good, but this needs a test.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants