Description
It would be great to make scikit-learn's array API support compatible with MLX (which mostly conforms to the array API standard).
Here is an example which currently does not work for a few reasons:
from sklearn.datasets import make_classification
from sklearn import config_context
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import mlx.core as mx
X_np, y_np = make_classification(random_state=0)
X_mx = mx.array(X_np)
y_mx = mx.array(y_np)
with config_context(array_api_dispatch=True):
    lda = LinearDiscriminantAnalysis()
    X_trans = lda.fit_transform(X_mx, y_mx)
print(type(X_trans))
The reasons it does not work:
- MLX does not have a `float64` data type (similar to the PyTorch MPS backend). It's a bit hacky to set `mx.float64 = mx.float32`, so it may be better to handle this in scikit-learn or in a compatibility layer.
- MLX does not support operations with data-dependent output shapes, e.g. `unique_values`. Since these are optional in the array API, should we attempt to avoid using them in scikit-learn to get maximal compatibility with other frameworks?
- There are still a couple of functions missing in MLX, like `mx.asarray` and `mx.isdtype` (those are pretty easy for us to add).
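To make the "compatibility layer" idea a bit more concrete, here is a rough sketch of a wrapper that forwards attribute access to an array namespace and only steps in for names the namespace lacks. `CompatNamespace` and the fallback table are hypothetical names made up for illustration. They are not part of scikit-learn, MLX, or array-api-compat. A stand-in namespace is used so the snippet runs without MLX installed. With MLX one would presumably pass `mlx.core` instead.

```python
from types import SimpleNamespace


class CompatNamespace:
    """Hypothetical shim: forwards to an array namespace, filling named gaps."""

    def __init__(self, xp, fallbacks):
        self._xp = xp                # the wrapped namespace, e.g. mlx.core
        self._fallbacks = fallbacks  # name -> callable(xp) returning a substitute

    def __getattr__(self, name):
        # Only reached when normal attribute lookup on the shim itself fails.
        if name.startswith("_"):
            raise AttributeError(name)
        try:
            return getattr(self._xp, name)
        except AttributeError:
            if name in self._fallbacks:
                return self._fallbacks[name](self._xp)
            raise


# Stand-in namespace (an assumption) so the sketch runs without MLX.
fake_xp = SimpleNamespace(float32="float32", array=lambda data: list(data))

xp = CompatNamespace(
    fake_xp,
    {
        # Assumption: silently substituting float32 is acceptable (loses precision).
        "float64": lambda ns: ns.float32,
        # Assumption: array() is a good enough substitute for asarray().
        "asarray": lambda ns: ns.array,
    },
)

print(xp.float64)             # resolved through the fallback table
print(xp.asarray((1, 2, 3)))  # forwarded to fake_xp.array
```

This keeps the substitution explicit and local rather than monkeypatching `mx.float64`, but it does nothing for data-dependent shapes such as `unique_values`, which would still need changes on the scikit-learn side.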
Relevant discussion in MLX: ml-explore/mlx#1289
CC @betatim