Skip to content

Commit 9760a84

Browse files
committed
ENH: allow list/tuple pad_width in pad
1 parent a96dffb commit 9760a84

File tree

2 files changed

+53
-4
lines changed

2 files changed

+53
-4
lines changed

src/array_api_extra/_funcs.py

+37-4
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
555555

556556
def pad(
557557
x: Array,
558-
pad_width: int,
558+
pad_width: int | tuple | list,
559559
mode: str = "constant",
560560
*,
561561
xp: ModuleType | None = None,
@@ -568,8 +568,12 @@ def pad(
568568
----------
569569
x : array
570570
Input array.
571-
pad_width : int
571+
pad_width : int or tuple of ints or list of pairs of ints
572572
Pad the input array with this many elements from each side.
573+
If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
574+
each pair applies to the corresponding axis of ``x``.
575+
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
576+
copies of this tuple.
573577
mode : str, optional
574578
Only "constant" mode is currently supported, which pads with
575579
the value passed to `constant_values`.
@@ -590,16 +594,45 @@ def pad(
590594

591595
value = constant_values
592596

597+
# make pad_width a list of length-2 tuples of ints
598+
if isinstance(pad_width, int):
599+
pad_width = [(pad_width, pad_width)] * x.ndim
600+
601+
if isinstance(pad_width, tuple):
602+
pad_width = [pad_width] * x.ndim
603+
593604
if xp is None:
594605
xp = array_namespace(x)
595606

607+
slices = []
608+
newshape = []
609+
for ax, w_tpl in enumerate(pad_width):
610+
if len(w_tpl) != 2:
611+
raise ValueError(f"expect a 2-tuple (before, after), got {w_tpl}.")
612+
613+
sh = x.shape[ax]
614+
if w_tpl[0] == 0 and w_tpl[1] == 0:
615+
sl = slice(None, None, None)
616+
else:
617+
start, stop = w_tpl
618+
if stop == 0:
619+
stop = None
620+
else:
621+
stop = -stop
622+
623+
sl = slice(start, stop, None)
624+
sh += w_tpl[0] + w_tpl[1]
625+
626+
newshape.append(sh)
627+
slices.append(sl)
628+
596629
padded = xp.full(
597-
tuple(x + 2 * pad_width for x in x.shape),
630+
tuple(newshape),
598631
fill_value=value,
599632
dtype=x.dtype,
600633
device=_compat.device(x),
601634
)
602-
padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x
635+
padded[tuple(slices)] = x
603636
return padded
604637

605638

tests/test_funcs.py

+16
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,19 @@ def test_device(self):
416416

417417
def test_xp(self):
418418
assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3))
419+
420+
def test_tuple_width(self):
421+
a = xp.reshape(xp.arange(12), (3, 4))
422+
padded = pad(a, (1, 0))
423+
assert padded.shape == (4, 5)
424+
425+
padded = pad(a, (1, 2))
426+
assert padded.shape == (6, 7)
427+
428+
def test_list_of_tuples_width(self):
429+
a = xp.reshape(xp.arange(12), (3, 4))
430+
padded = pad(a, [(1, 0), (0, 2)])
431+
assert padded.shape == (4, 6)
432+
433+
padded = pad(a, [(1, 0), (0, 0)])
434+
assert padded.shape == (4, 4)

0 commit comments

Comments
 (0)