Skip to content

Commit 0fc081a

Browse files
Fix _fill_or_add_to_diagonal when reshape returns copy (scikit-learn#31445)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent cc526ee commit 0fc081a

File tree

4 files changed

+167
-41
lines changed

4 files changed

+167
-41
lines changed

sklearn/decomposition/_base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from scipy import linalg
1010

1111
from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin
12-
from ..utils._array_api import _fill_or_add_to_diagonal, device, get_namespace
12+
from ..utils._array_api import _add_to_diagonal, device, get_namespace
1313
from ..utils.validation import check_is_fitted, validate_data
1414

1515

@@ -47,7 +47,7 @@ def get_covariance(self):
4747
xp.asarray(0.0, device=device(exp_var), dtype=exp_var.dtype),
4848
)
4949
cov = (components_.T * exp_var_diff) @ components_
50-
_fill_or_add_to_diagonal(cov, self.noise_variance_, xp)
50+
_add_to_diagonal(cov, self.noise_variance_, xp)
5151
return cov
5252

5353
def get_precision(self):
@@ -89,10 +89,10 @@ def get_precision(self):
8989
xp.asarray(0.0, device=device(exp_var)),
9090
)
9191
precision = components_ @ components_.T / self.noise_variance_
92-
_fill_or_add_to_diagonal(precision, 1.0 / exp_var_diff, xp)
92+
_add_to_diagonal(precision, 1.0 / exp_var_diff, xp)
9393
precision = components_.T @ linalg_inv(precision) @ components_
9494
precision /= -(self.noise_variance_**2)
95-
_fill_or_add_to_diagonal(precision, 1.0 / self.noise_variance_, xp)
95+
_add_to_diagonal(precision, 1.0 / self.noise_variance_, xp)
9696
return precision
9797

9898
@abstractmethod

sklearn/metrics/pairwise.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ..preprocessing import normalize
2020
from ..utils import check_array, gen_batches, gen_even_slices
2121
from ..utils._array_api import (
22-
_fill_or_add_to_diagonal,
22+
_fill_diagonal,
2323
_find_matching_floating_dtype,
2424
_is_numpy_namespace,
2525
_max_precision_float_dtype,
@@ -439,7 +439,7 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
439439
# Ensure that distances between vectors and themselves are set to 0.0.
440440
# This may not be the case due to floating point rounding errors.
441441
if X is Y:
442-
distances = _fill_or_add_to_diagonal(distances, 0, xp=xp, add_value=False)
442+
_fill_diagonal(distances, 0, xp=xp)
443443

444444
if squared:
445445
return distances
@@ -1177,7 +1177,7 @@ def cosine_distances(X, Y=None):
11771177
if X is Y or Y is None:
11781178
# Ensure that distances between vectors and themselves are set to 0.0.
11791179
# This may not be the case due to floating point rounding errors.
1180-
S = _fill_or_add_to_diagonal(S, 0.0, xp, add_value=False)
1180+
_fill_diagonal(S, 0.0, xp)
11811181
return S
11821182

11831183

@@ -1982,7 +1982,7 @@ def _parallel_pairwise(X, Y, func, n_jobs, **kwds):
19821982
if (X is Y or Y is None) and func is euclidean_distances:
19831983
# zeroing diagonal for euclidean norm.
19841984
# TODO: do it also for other norms.
1985-
ret = _fill_or_add_to_diagonal(ret, 0, xp=xp, add_value=False)
1985+
_fill_diagonal(ret, 0, xp=xp)
19861986

19871987
# Transform output back
19881988
return ret.T

sklearn/utils/_array_api.py

Lines changed: 67 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -527,40 +527,80 @@ def _expit(X, xp=None):
527527
return 1.0 / (1.0 + xp.exp(-X))
528528

529529

530-
def _fill_or_add_to_diagonal(array, value, xp, add_value=True, wrap=False):
531-
"""Implementation to facilitate adding or assigning specified values to the
532-
diagonal of a 2-d array.
533-
534-
If ``add_value`` is `True` then the values will be added to the diagonal
535-
elements otherwise the values will be assigned to the diagonal elements.
536-
By default, ``add_value`` is set to `True. This is currently only
537-
supported for 2-d arrays.
538-
539-
The implementation is taken from the `numpy.fill_diagonal` function:
540-
https://github.com/numpy/numpy/blob/v2.0.0/numpy/lib/_index_tricks_impl.py#L799-L929
541-
"""
530+
def _validate_diagonal_args(array, value, xp):
531+
"""Validate arguments to `_fill_diagonal`/`_add_to_diagonal`."""
542532
if array.ndim != 2:
543533
raise ValueError(
544-
f"array should be 2-d. Got array with shape {tuple(array.shape)}"
534+
f"`array` should be 2D. Got array with shape {tuple(array.shape)}"
545535
)
546536

547537
value = xp.asarray(value, dtype=array.dtype, device=device(array))
548-
end = None
549-
# Explicit, fast formula for the common case. For 2-d arrays, we
550-
# accept rectangular ones.
551-
step = array.shape[1] + 1
552-
if not wrap:
553-
end = array.shape[1] * array.shape[1]
538+
if value.ndim not in [0, 1]:
539+
raise ValueError(
540+
"`value` needs to be a scalar or a 1D array, "
541+
f"got a {value.ndim}D array instead."
542+
)
543+
min_rows_columns = min(array.shape)
544+
if value.ndim == 1 and value.shape[0] != min_rows_columns:
545+
raise ValueError(
546+
"`value` needs to be a scalar or 1D array of the same length as the "
547+
f"diagonal of `array` ({min_rows_columns}). Got {value.shape[0]}"
548+
)
549+
550+
return value, min_rows_columns
551+
552+
553+
def _fill_diagonal(array, value, xp):
554+
"""Minimal implementation of `numpy.fill_diagonal`.
555+
556+
`wrap` is not supported (i.e. always False). `value` should be a scalar or
557+
a 1D array of the same length as the diagonal (i.e., `value` is never repeated
558+
or truncated; a length mismatch raises a `ValueError`).
559+
560+
Note `array` is altered in place.
561+
"""
562+
value, min_rows_columns = _validate_diagonal_args(array, value, xp)
554563

555-
array_flat = xp.reshape(array, (-1,))
556-
if add_value:
557-
array_flat[:end:step] += value
564+
if _is_numpy_namespace(xp):
565+
xp.fill_diagonal(array, value, wrap=False)
558566
else:
559-
array_flat[:end:step] = value
560-
# `array_flat` is not always a view on `array` (e.g. for certain array types that
561-
# were filled via parallel processing i.e., in `_parallel_pairwise`), thus we need
562-
# to return reshaped `array_flat`.
563-
return xp.reshape(array_flat, array.shape)
567+
# TODO: when array libraries support `reshape(copy)`, use
568+
# `reshape(array, (-1,), copy=False)`, then fill with `[:end:step]` (within
569+
# `try/except`). This is faster than a for loop when no copy needs to be
570+
# made within `reshape`. See #31445 for details.
571+
if value.ndim == 0:
572+
for i in range(min_rows_columns):
573+
array[i, i] = value
574+
else:
575+
for i in range(min_rows_columns):
576+
array[i, i] = value[i]
577+
578+
579+
def _add_to_diagonal(array, value, xp):
580+
"""Add `value` to diagonal of `array`.
581+
582+
Related to `fill_diagonal`. `value` should be a scalar or
583+
a 1D array of the same length as the diagonal (i.e., `value` is never repeated
584+
or truncated; a length mismatch raises a `ValueError`).
585+
586+
Note `array` is altered in place.
587+
"""
588+
value, min_rows_columns = _validate_diagonal_args(array, value, xp)
589+
590+
if _is_numpy_namespace(xp):
591+
step = array.shape[1] + 1
592+
# Ensure we do not wrap
593+
end = array.shape[1] * array.shape[1]
594+
array.flat[:end:step] += value
595+
return
596+
597+
# TODO: when array libraries support `reshape(copy)`, use
598+
# `reshape(array, (-1,), copy=False)`, then fill with `[:end:step]` (within
599+
# `try/except`). This is faster than a for loop when no copy needs to be
600+
# made within `reshape`. See #31445 for details.
601+
value = xp.linalg.diagonal(array) + value
602+
for i in range(min_rows_columns):
603+
array[i, i] = value[i]
564604

565605

566606
def _is_xp_namespace(xp, name):

sklearn/utils/tests/test_array_api.py

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
from sklearn._config import config_context
1010
from sklearn.base import BaseEstimator
1111
from sklearn.utils._array_api import (
12+
_add_to_diagonal,
1213
_asarray_with_order,
1314
_atol_for_type,
1415
_average,
1516
_convert_to_numpy,
1617
_count_nonzero,
1718
_estimator_with_converted_arrays,
18-
_fill_or_add_to_diagonal,
19+
_fill_diagonal,
1920
_get_namespace_device_dtype_ids,
2021
_is_numpy_namespace,
2122
_isin,
@@ -26,6 +27,7 @@
2627
_nanmean,
2728
_nanmin,
2829
_ravel,
30+
_validate_diagonal_args,
2931
device,
3032
get_namespace,
3133
get_namespace_and_device,
@@ -576,21 +578,105 @@ def test_count_nonzero(
576578
assert device(array_xp) == device(result)
577579

578580

581+
@pytest.mark.parametrize(
582+
"array, value, match",
583+
[
584+
(numpy.array([1, 2, 3]), 1, "`array` should be 2D"),
585+
(numpy.array([[1, 2], [3, 4]]), numpy.array([1, 2, 3]), "`value` needs to be"),
586+
(numpy.array([[1, 2], [3, 4]]), [1, 2, 3], "`value` needs to be"),
587+
(
588+
numpy.array([[1, 2], [3, 4]]),
589+
numpy.array([[1, 2], [3, 4]]),
590+
"`value` needs to be a",
591+
),
592+
],
593+
)
594+
def test_validate_diagonal_args(array, value, match):
595+
"""Check `_validate_diagonal_args` raises the correct errors."""
596+
xp = _array_api_for_tests("numpy", None)
597+
with pytest.raises(ValueError, match=match):
598+
_validate_diagonal_args(array, value, xp)
599+
600+
601+
@pytest.mark.parametrize("function", ["fill", "add"])
602+
@pytest.mark.parametrize("c_contiguity", [True, False])
603+
def test_fill_and_add_to_diagonal(c_contiguity, function):
604+
"""Check `_fill_diagonal`/`_add_to_diagonal` behave correctly with numpy arrays."""
605+
xp = _array_api_for_tests("numpy", None)
606+
if c_contiguity:
607+
array = numpy.zeros((3, 4))
608+
else:
609+
array = numpy.zeros((3, 4)).T
610+
assert array.flags["C_CONTIGUOUS"] == c_contiguity
611+
612+
if function == "fill":
613+
func = _fill_diagonal
614+
else:
615+
func = _add_to_diagonal
616+
617+
func(array, 1, xp)
618+
assert_allclose(array.diagonal(), numpy.ones((3,)))
619+
620+
func(array, [0, 1, 2], xp)
621+
if function == "fill":
622+
expected_diag = numpy.arange(3)
623+
else:
624+
expected_diag = numpy.ones((3,)) + numpy.arange(3)
625+
assert_allclose(array.diagonal(), expected_diag)
626+
627+
fill_array = numpy.array([11, 12, 13])
628+
func(array, fill_array, xp)
629+
if function == "fill":
630+
expected_diag = fill_array
631+
else:
632+
expected_diag = fill_array + numpy.arange(3) + numpy.ones((3,))
633+
assert_allclose(array.diagonal(), expected_diag)
634+
635+
636+
@pytest.mark.parametrize("array", ["standard", "transposed", "non-contiguous"])
637+
@pytest.mark.parametrize(
638+
"array_namespace, device_, dtype_name",
639+
yield_namespace_device_dtype_combinations(),
640+
ids=_get_namespace_device_dtype_ids,
641+
)
642+
def test_fill_diagonal(array, array_namespace, device_, dtype_name):
643+
"""Check array API `_fill_diagonal` is consistent with `numpy.fill_diagonal`."""
644+
xp = _array_api_for_tests(array_namespace, device_)
645+
array_np = numpy.zeros((4, 5), dtype=dtype_name)
646+
647+
if array == "transposed":
648+
array_xp = xp.asarray(array_np.copy(), device=device_).T
649+
array_np = array_np.T
650+
elif array == "non-contiguous":
651+
array_xp = xp.asarray(array_np.copy(), device=device_)[::2, ::2]
652+
array_np = array_np[::2, ::2]
653+
else:
654+
array_xp = xp.asarray(array_np.copy(), device=device_)
655+
656+
numpy.fill_diagonal(array_np, val=1)
657+
with config_context(array_api_dispatch=True):
658+
_fill_diagonal(array_xp, value=1, xp=xp)
659+
660+
assert_array_equal(_convert_to_numpy(array_xp, xp=xp), array_np)
661+
662+
579663
@pytest.mark.parametrize(
580664
"array_namespace, device_, dtype_name",
581665
yield_namespace_device_dtype_combinations(),
582666
ids=_get_namespace_device_dtype_ids,
583667
)
584-
@pytest.mark.parametrize("wrap", [True, False])
585-
def test_fill_or_add_to_diagonal(array_namespace, device_, dtype_name, wrap):
668+
def test_add_to_diagonal(array_namespace, device_, dtype_name):
669+
"""Check `_add_to_diagonal` consistent between array API xp and numpy namespace."""
586670
xp = _array_api_for_tests(array_namespace, device_)
671+
np_xp = _array_api_for_tests("numpy", None)
587672

588-
array_np = numpy.zeros((5, 4), dtype=dtype_name)
673+
array_np = numpy.zeros((3, 4), dtype=dtype_name)
589674
array_xp = xp.asarray(array_np.copy(), device=device_)
590675

591-
numpy.fill_diagonal(array_np, val=1, wrap=wrap)
676+
add_val = [1, 2, 3]
677+
_fill_diagonal(array_np, value=add_val, xp=np_xp)
592678
with config_context(array_api_dispatch=True):
593-
_fill_or_add_to_diagonal(array_xp, value=1, xp=xp, add_value=False, wrap=wrap)
679+
_fill_diagonal(array_xp, value=add_val, xp=xp)
594680

595681
assert_array_equal(_convert_to_numpy(array_xp, xp=xp), array_np)
596682

0 commit comments

Comments
 (0)