diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 65503a0674a70..44a4af6397642 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -619,11 +619,10 @@ def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,)) skip_remove_kwargs = dict(remove_none=False, remove_types=[]) xp, is_array_api = get_namespace(*array_list, **skip_remove_kwargs) - arrays_device = device(*array_list, **skip_remove_kwargs) if is_array_api: - return xp, is_array_api, arrays_device + return xp, is_array_api, device(*array_list, **skip_remove_kwargs) else: - return xp, False, arrays_device + return xp, False, None def _expit(X, xp=None):