diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index ebc157fb169d1..29e7b79bb1d65 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -136,6 +136,25 @@ automatically skipped. Therefore it's important to run the tests with the pip install array-api-compat # and other libraries as needed pytest -k "array_api" -v +Note on `float64` support +------------------------- + +`float64` precision can sometimes be mandatory, where `float32` or inferior precision +might result in catastrophic numerical errors. For this reason there are occurrences +in scikit-learn of _upcasting_, where some computational steps will always use `float64` +precision, even if the input data is given with inferior precision. + +However, some array manipulation libraries might leverage devices that do not support +operations with `float64` precision, such as MPS on macOS or IntelĀ® integrated GPUs. +scikit-learn always favors consistency of the numerical stability across all +use-cases, and it will locally dispatch the compute to numpy for steps that have +a minimal precision requirement that is not supported by the device, at the costs of +a transfer to CPU. + +Minimizing the usage of `float64` upcasting in scikit-learn is an open improvement +direction, to maybe yield better performance from devices that do not support it, +since it avoids data copies and benefits from a higher FLOPS. + Note on MPS device support --------------------------