-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG] Fix for float16 overflow on accumulator operations #13010
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @baluyotraf !
It might be good to add a non regression test for overflow in StandardScaler with float16.
sklearn/utils/__init__.py
Outdated
# Use at least float64 for the accumulating functions to avoid precision issues; | ||
# see https://github.com/numpy/numpy/issues/9393 | ||
# The float64 is also retained as it is in case the float overflows | ||
def safe_acc_op(op, x, *args, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please make this private, maybe more verbose (_safe_accumulate_op
) and move it to utils.extmath
.
…tils.extmath. Also fixed some line lengths to fit the 80 limit (scikit-learn#13007)
Moved the function to extmath and added the test. I also verified that the test fails on master and that it passes in this branch. Thanks for the review. o/ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Please add an entry to the change log at doc/whats_new/v0.21.rst
. Like the other entries there, please reference this pull request with :issue:
and credit yourself (and other contributors if applicable) with :user:
sklearn/utils/extmath.py
Outdated
@@ -723,7 +750,8 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count): | |||
if last_variance is None: | |||
updated_variance = None | |||
else: | |||
new_unnormalized_variance = np.nanvar(X, axis=0) * new_sample_count | |||
new_unnormalized_variance = \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We prefer line continuations to use parentheses rather than backslash where possible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I saw a backslash someone so I kind of went along with it. I'll take note of this.
# Overflow calculations may cause -inf, inf, or nan. Since there is no nan | ||
# input, all of the outputs should be finite. This may be redundant since a | ||
# FloatingPointError exception will be thrown on overflow above. | ||
assert np.all(np.isfinite(X_scaled)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it makes more sense to check that the output is identical to when the input is high precision. Also may want to check that the scaler features are preserving the input dtype (although surely we have another test for that)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tested it out before and found that output is off after 2 or 3 decimal points. Should we cast the input during fit and cast it back to float16? It's kind of similar with to #12333 only this time the imprecision is with the results rather than the mean.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't you expect it to be off after 2 or 3 decimal points with float16?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would a test like this be enough?
def test_scaler_float16_overflow():
# Test if the scaler will not overflow on float16 numpy arrays
rng = np.random.RandomState(0)
# float16 has a maximum of 65500.0. On the worst case 5 * 200000 is 100000
# which is enough to overflow the data type
X = rng.uniform(5, 10, [200000, 1]).astype(np.float16)
with np.errstate(over='raise'):
scaler = StandardScaler().fit(X)
X_scaled = scaler.transform(X)
# Calculate the float64 equivalent to verify result
X_scaled_f64 = StandardScaler().fit_transform(X.astype(np.float64))
# Overflow calculations may cause -inf, inf, or nan. Since there is no nan
# input, all of the outputs should be finite. This may be redundant since a
# FloatingPointError exception will be thrown on overflow above.
assert np.all(np.isfinite(X_scaled))
# The normal distribution is very unlikely to go above 4. At 4.0-8.0 the
# float16 precision is 2^-8 which is around 0.004. Thus only 2 decimals are
# checked to account for precision differences.
assert_array_almost_equal(X_scaled, X_scaled_f64, decimal=2)
There are CI failures, btw. |
Kind of you to show your working. Looks great (especially if it also
passes)!
|
…ult with respect to their precisions (scikit-learn#13007)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @baluyotraf !
this did not fix #5602 ? |
…ler (scikit-learn#13010)" This reverts commit 2ff7649.
…ler (scikit-learn#13010)" This reverts commit 2ff7649.
Reference Issues/PRs
This fixes #13007
What does this implement/fix? Explain your changes.
A dtype of float64 is passed when using numpy based accumulator functions to prevent overflow. This is only done for floating point inputs.