mean_squared_error giving wrong results #28827


Closed
TheGuy42 opened this issue Apr 13, 2024 · 2 comments
Labels
Bug, Needs Triage

Comments

@TheGuy42

Describe the bug

I have recently noticed a bug in the implementation of mean_squared_error in sklearn.metrics.
The current implementation of the function basically calculates the MSE as follows:

output_errors = np.average((y_true - y_pred) ** 2, axis=0, weights=sample_weight)

This is reasonable in most cases, but it may return wrong results when y_true and y_pred have a narrow integer dtype, for example np.uint8, which ranges from 0 to 255.
The reason is that when the calculation is done on arrays of a dtype like np.uint8, overflows are very likely to occur (and are not reported in any way!), resulting in wrong results.
To resolve this, y_true and y_pred should first be cast to a dtype wide enough that overflows cannot occur for reasonable errors, such as float64.
For example:

def mse(image_1: np.ndarray, image_2: np.ndarray) -> float:
    # Cast to float64 first so neither the difference nor the square can overflow.
    return np.square(image_1.astype(np.float64) - image_2.astype(np.float64)).mean()
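
With the cast in place, calling it on the reproduction arrays below gives the expected value (assuming numpy has been imported as np):

mse(np.array([0], dtype=np.uint8), np.array([16], dtype=np.uint8))  # returns 256.0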

Steps/Code to Reproduce

import numpy as np
from sklearn.metrics import mean_squared_error

true = np.array([0], dtype=np.uint8)
pred = np.array([16], dtype=np.uint8)
mean_squared_error(true, pred)

Expected Results

The expected result is 256, since (0 - 16)**2 = 256.

Actual Results

The result of mean_squared_error is 0
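
For what it's worth, the zero can be traced to two silent wraparounds in uint8 arithmetic. A small sketch using the same arrays as above (it relies on NumPy's modular wraparound for unsigned integer arrays):

import numpy as np

true = np.array([0], dtype=np.uint8)
pred = np.array([16], dtype=np.uint8)

diff = true - pred  # 0 - 16 wraps around to [240]
sq = diff ** 2      # 240**2 = 57600 = 225 * 256, which wraps around to [0]
print(diff, sq)     # [240] [0]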

Versions

System:
    python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
executable: /usr/bin/python3
   machine: Linux-6.1.58+-x86_64-with-glibc2.35

Python dependencies:
      sklearn: 1.4.2
          pip: 23.1.2
   setuptools: 67.7.2
        numpy: 1.25.2
        scipy: 1.11.4
       Cython: 3.0.10
       pandas: 2.0.3
   matplotlib: 3.7.1
       joblib: 1.4.0
threadpoolctl: 3.4.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 2
         prefix: libopenblas
       filepath: /usr/local/lib/python3.10/dist-packages/numpy.libs/libopenblas64_p-r0-5007b62f.3.23.dev.so
        version: 0.3.23.dev
threading_layer: pthreads
   architecture: Haswell

       user_api: openmp
   internal_api: openmp
    num_threads: 2
         prefix: libgomp
       filepath: /usr/local/lib/python3.10/dist-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None

       user_api: blas
   internal_api: openblas
    num_threads: 2
         prefix: libopenblas
       filepath: /usr/local/lib/python3.10/dist-packages/scipy.libs/libopenblasp-r0-23e5df77.3.21.dev.so
        version: 0.3.21.dev
threading_layer: pthreads
   architecture: Haswell
@TheGuy42 added the Bug and Needs Triage labels Apr 13, 2024
@adrinjalali
Member

One could argue that about any dtype, since overflow can happen in any of them. So I think it's the user who needs to know what dtype makes sense for such an operation. Otherwise we'd need to upcast every input in every method to the highest available dtype / precision, and we can't do that.
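
For the snippet above, a minimal sketch of that user-side cast could look like this:

import numpy as np
from sklearn.metrics import mean_squared_error

true = np.array([0], dtype=np.uint8)
pred = np.array([16], dtype=np.uint8)

# Cast to a wide float dtype before handing the arrays to the metric.
mean_squared_error(true.astype(np.float64), pred.astype(np.float64))  # 256.0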

cc @ogrisel @betatim since this is similar to the case we had for the array API.

Closing for now, happy to have it reopened if we think we should do something about it.

@adrinjalali closed this as not planned Apr 15, 2024
@TheGuy42
Author

Obviously it's your choice to make, but I want to highlight a few points:

  1. MSE is one of the most widely used functions in AI, ML, signal processing, and many other fields! Since many people may use this library only for its mean_squared_error implementation, it is important that it behaves as expected.
  2. It can be very hard to tell whether the returned MSE is even in the right range, so it is very unlikely that someone would notice that the result they got is wrong.
  3. Many people rely on these results for their research or work!! If a random forest does not perform as well as it could because of an overflow, that is fine because the user gets what they see. But if someone gets a wrong result for an MSE (or another metric), they will be led to believe something wrong based on it, which is made worse by the fact that the error can be arbitrarily far from the real answer.
  4. Many people now use or research quantization, where using small dtypes for the calculations (and knowingly accepting overflows) is the whole point, but even then you still need correct answers from the metrics and losses.

I also want to highlight that for integer types the situation is much worse:
5. There is a related open issue in numpy about the fact that there is no way to get a warning about overflows for integer dtypes (it is possible for floats with np.seterr); see the sketch after this list.
6. Many images load with an np.uint8 dtype, so these errors are common for data loaded from images.
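
As a small illustration of point 5 (it assumes NumPy's default error handling, where np.seterr only affects floating-point operations):

import numpy as np

np.seterr(over='warn')  # only applies to floating-point operations

f = np.array([3e38], dtype=np.float32)
f * f  # RuntimeWarning: overflow encountered in multiply -> [inf]

i = np.array([200], dtype=np.uint8)
i * i  # silently wraps to [64]; there is no way to get a warning here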

In my opinion, and as argued in points 3 and 4, it is reasonable to upcast to a floating-point dtype only in the calculation of metrics like mean_squared_error. This is what the implementations of mean_squared_error and other metrics in scikit-image do.
And even if you prefer to keep the implementation as is, I think this should at least be noted in the documentation of mean_squared_error so people are aware of the problem.
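
For illustration, a hedged sketch of that upcast-then-delegate approach (the name mse_upcast is made up here, and this is not scikit-learn's or scikit-image's actual code):

import numpy as np
from sklearn.metrics import mean_squared_error

def mse_upcast(y_true, y_pred, **kwargs):
    # Upcast only inside the metric so the caller's dtypes are left untouched.
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    return mean_squared_error(y_true, y_pred, **kwargs)

mse_upcast(np.array([0], dtype=np.uint8), np.array([16], dtype=np.uint8))  # 256.0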
