RFC/API (Array API) mixing devices and data types with estimators #26083

@adrinjalali

Right now, if a user fits an estimator on a pandas.DataFrame but passes a numpy.ndarray to predict, they get a warning about missing feature names.

The situation is only going to get more complicated as we add support for more array types via the Array API.

Some related issues here are:

  • device: data sits on a GPU during fit, but predict runs on a CPU (with the same data type)
  • types: using one type to fit and another during predict. How do we handle this, both in terms of device and type? Do we let the operator figure out whether it can coerce the data into a type it can use?
  • persistence: how do we let users fit on one device but load the fitted estimator on another?
  • estimator conversion: do we let users convert an estimator fitted with one type/device into one compatible with another type/device?
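One way to make the "device" and "types" questions concrete is a per-estimator policy for predict-time mismatches. The sketch below is hypothetical: `ArraySignature`, `MismatchPolicy` and `resolve_mismatch` are illustrative names, not scikit-learn API. Namespaces and devices are plain strings rather than real array objects.

```python
import warnings
from dataclasses import dataclass
from enum import Enum


class MismatchPolicy(Enum):
    """Hypothetical options for handling a predict-time input mismatch."""
    ERROR = "error"      # refuse; the user must convert X themselves
    WARN = "warn"        # convert X, but tell the user
    CONVERT = "convert"  # silently convert X to the fit-time namespace/device


@dataclass(frozen=True)
class ArraySignature:
    """What an estimator could record about X at fit time (hypothetical)."""
    namespace: str  # e.g. "numpy", "torch", "cupy"
    device: str     # e.g. "cpu", "cuda:0"


def resolve_mismatch(fitted, incoming, policy):
    """Return the action predict would take on incoming data (hypothetical)."""
    if fitted == incoming:
        return "use as-is"
    differences = []
    if fitted.namespace != incoming.namespace:
        differences.append(f"namespace {incoming.namespace!r} != {fitted.namespace!r}")
    if fitted.device != incoming.device:
        differences.append(f"device {incoming.device!r} != {fitted.device!r}")
    message = "input differs from fit-time data: " + ", ".join(differences)
    if policy is MismatchPolicy.ERROR:
        raise ValueError(message)
    if policy is MismatchPolicy.WARN:
        warnings.warn(message + "; converting input", UserWarning)
    return f"convert to {fitted.namespace} on {fitted.device}"


fit_sig = ArraySignature("torch", "cuda:0")
print(resolve_mismatch(fit_sig, ArraySignature("torch", "cpu"), MismatchPolicy.CONVERT))
```

Recording such a signature at fit time, much like `feature_names_in_` is recorded today, would let predict make this decision cheaply. Which policy should be the default is exactly the open question.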

I vaguely remember us talking about some of these issues, but I don't see any active discussion. I might have missed something.

Related: in a library like PyTorch, you can decide which device a model's weights are loaded onto at load time.
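PyTorch exposes this choice through the `map_location` argument of `torch.load`. A rough analogue for scikit-learn could be a hook applied to fitted attributes right after unpickling. The sketch below assumes scikit-learn and NumPy. `load_with_map_location` is a hypothetical helper, and casting to `float32` only stands in for a real device or namespace transfer.

```python
import pickle

import numpy as np
from sklearn.preprocessing import StandardScaler


def load_with_map_location(blob, map_location):
    """Unpickle an estimator, then pass each fitted array through map_location.

    Hypothetical helper: relies on the scikit-learn convention that fitted
    attributes end with "_", and only touches top-level ndarray attributes.
    """
    est = pickle.loads(blob)
    for name, value in list(vars(est).items()):
        if name.endswith("_") and isinstance(value, np.ndarray):
            setattr(est, name, map_location(value))
    return est


scaler = StandardScaler().fit(np.array([[1.0, 2.0], [3.0, 4.0]]))
blob = pickle.dumps(scaler)

# Stand-in for "move to another device/type": cast fitted arrays to float32.
restored = load_with_map_location(blob, lambda a: a.astype(np.float32))
print(restored.mean_, restored.mean_.dtype)
```

The trailing-underscore convention makes a generic hook possible, but estimators that keep arrays in nested objects or private attributes would need their own conversion logic, which is the "estimator conversion" question above.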

cc @thomasjpfan