From a05521573eaab8ea6dfdb763db36d8a34e8f62de Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 20 Jan 2025 20:02:30 +0000 Subject: [PATCH] pad_width --- src/array_api_extra/_delegation.py | 7 ++++--- src/array_api_extra/_lib/_funcs.py | 18 +++++++++++++----- tests/test_funcs.py | 8 ++++---- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index c3d77f8e..195dd88f 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -1,5 +1,6 @@ """Delegation to existing implementations for Public API Functions.""" +from collections.abc import Sequence from types import ModuleType from typing import Literal @@ -31,7 +32,7 @@ def _delegate(xp: ModuleType, *backends: Backend) -> bool: def pad( x: Array, - pad_width: int | tuple[int, int] | list[tuple[int, int]], + pad_width: int | tuple[int, int] | Sequence[tuple[int, int]], mode: Literal["constant"] = "constant", *, constant_values: bool | int | float | complex = 0, @@ -44,9 +45,9 @@ def pad( ---------- x : array Input array. - pad_width : int or tuple of ints or list of pairs of ints + pad_width : int or tuple of ints or sequence of pairs of ints Pad the input array with this many elements from each side. - If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``, + If a sequence of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``, each pair applies to the corresponding axis of ``x``. A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim`` copies of this tuple. diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 85778356..c1c39f58 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -5,6 +5,7 @@ import math import warnings +from collections.abc import Sequence from types import ModuleType from typing import cast @@ -448,7 +449,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array: def pad( x: Array, - pad_width: int | tuple[int, int] | list[tuple[int, int]], + pad_width: int | tuple[int, int] | Sequence[tuple[int, int]], *, constant_values: bool | int | float | complex = 0, xp: ModuleType, @@ -456,15 +457,22 @@ def pad( """See docstring in `array_api_extra._delegation.py`.""" # make pad_width a list of length-2 tuples of ints x_ndim = cast(int, x.ndim) + if isinstance(pad_width, int): - pad_width = [(pad_width, pad_width)] * x_ndim - if isinstance(pad_width, tuple): - pad_width = [pad_width] * x_ndim + pad_width_seq = [(pad_width, pad_width)] * x_ndim + elif ( + isinstance(pad_width, tuple) + and len(pad_width) == 2 + and all(isinstance(i, int) for i in pad_width) + ): + pad_width_seq = [cast(tuple[int, int], pad_width)] * x_ndim + else: + pad_width_seq = cast(list[tuple[int, int]], list(pad_width)) # https://github.com/python/typeshed/issues/13376 slices: list[slice] = [] # type: ignore[no-any-explicit] newshape: list[int] = [] - for ax, w_tpl in enumerate(pad_width): + for ax, w_tpl in enumerate(pad_width_seq): if len(w_tpl) != 2: msg = f"expect a 2-tuple (before, after), got {w_tpl}." raise ValueError(msg) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 4847cb9f..9be01508 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -390,12 +390,12 @@ def test_tuple_width(self, xp: ModuleType): with pytest.raises((ValueError, RuntimeError)): pad(a, [(1, 2, 3)]) # type: ignore[list-item] # pyright: ignore[reportArgumentType] - def test_list_of_tuples_width(self, xp: ModuleType): + def test_sequence_of_tuples_width(self, xp: ModuleType): a = xp.reshape(xp.arange(12), (3, 4)) - padded = pad(a, [(1, 0), (0, 2)]) - assert padded.shape == (4, 6) - padded = pad(a, [(1, 0), (0, 0)]) + padded = pad(a, ((1, 0), (0, 2))) + assert padded.shape == (4, 6) + padded = pad(a, ((1, 0), (0, 0))) assert padded.shape == (4, 4)