
Array API backends support for MLX #29673

Open
@awni


It would be great to make scikit-learn's Array API support compatible with MLX, which is mostly conformant with the Array API standard.

Here is an example that currently does not work, for a few reasons:

```python
from sklearn.datasets import make_classification
from sklearn import config_context
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import mlx.core as mx

X_np, y_np = make_classification(random_state=0)
X_mx = mx.array(X_np)
y_mx = mx.array(y_np)

with config_context(array_api_dispatch=True):
    lda = LinearDiscriminantAnalysis()
    X_trans = lda.fit_transform(X_mx, y_mx)

print(type(X_trans))
```

The reasons it does not work:

  • MLX does not have a float64 data type (similar to the PyTorch MPS backend). Setting mx.float64 = mx.float32 is a bit hacky, so it may be better to handle this in scikit-learn or in a compatibility layer.

  • MLX does not support operations with data-dependent output shapes, e.g. unique_values. Since these are optional in the Array API standard, should we try to avoid them in scikit-learn to get maximal compatibility with other frameworks?

  • A couple of functions are still missing in MLX, such as mx.asarray and mx.isdtype (those are pretty easy for us to add).
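A minimal sketch of how the first two points might be handled on the scikit-learn side (or in a compatibility layer). `widest_float_dtype` and `encode_labels` are hypothetical helper names, not existing scikit-learn functions. NumPy stands in for an Array API namespace. The sketch assumes the label set can be computed once on the host, so that the only device-side operation is `searchsorted` (part of the 2023.12 revision of the standard; whether MLX provides it is an assumption). Its output shape equals the input shape regardless of the data.

```python
import types

import numpy as np


def widest_float_dtype(xp):
    """Return the widest real floating dtype that namespace ``xp`` exposes.

    Hypothetical helper: a namespace without float64 (MLX, per this issue)
    gets float32, instead of aliasing ``xp.float64 = xp.float32`` globally.
    """
    float64 = getattr(xp, "float64", None)
    return float64 if float64 is not None else xp.float32


def encode_labels(xp, y, classes):
    """Map each label in ``y`` to its index in the sorted 1-D array ``classes``.

    Hypothetical helper: ``classes`` is assumed to be computed up front
    (e.g. with NumPy on the host). ``searchsorted`` returns an array with the
    same shape as ``y``, so no data-dependent output shape is needed on the
    device (unlike ``unique_values``).
    """
    return xp.searchsorted(classes, y)


# Usage with NumPy standing in for an Array API namespace.
y = np.asarray([10, 30, 20, 10])
classes = np.unique(y)  # data-dependent shape, computed once on the host
print(encode_labels(np, y, classes))  # [0 2 1 0]
print(widest_float_dtype(np))         # <class 'numpy.float64'>

# A stand-in namespace without float64, mimicking MLX as described above.
no_f64 = types.SimpleNamespace(float32="float32")
print(widest_float_dtype(no_f64))     # float32
```

Pushing the data-dependent step (finding the classes) to the host keeps it out of backends that cannot express it. The trade-off is an extra host round trip when the labels live on the device.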

Relevant discussion in MLX ml-explore/mlx#1289

CC @betatim
