@@ -555,7 +555,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
555
555
556
556
def pad (
557
557
x : Array ,
558
- pad_width : int ,
558
+ pad_width : int | tuple | list ,
559
559
mode : str = "constant" ,
560
560
* ,
561
561
xp : ModuleType | None = None ,
@@ -568,8 +568,12 @@ def pad(
568
568
----------
569
569
x : array
570
570
Input array.
571
- pad_width : int
571
+ pad_width : int or tuple of ints or list of pairs of ints
572
572
Pad the input array with this many elements from each side.
573
+ If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
574
+ each pair applies to the corresponding axis of ``x``.
575
+ A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
576
+ copies of this tuple.
573
577
mode : str, optional
574
578
Only "constant" mode is currently supported, which pads with
575
579
the value passed to `constant_values`.
@@ -590,16 +594,45 @@ def pad(
590
594
591
595
value = constant_values
592
596
597
+ # make pad_width a list of length-2 tuples of ints
598
+ if isinstance (pad_width , int ):
599
+ pad_width = [(pad_width , pad_width )] * x .ndim
600
+
601
+ if isinstance (pad_width , tuple ):
602
+ pad_width = [pad_width ] * x .ndim
603
+
593
604
if xp is None :
594
605
xp = array_namespace (x )
595
606
607
+ slices = []
608
+ newshape = []
609
+ for ax , w_tpl in enumerate (pad_width ):
610
+ if len (w_tpl ) != 2 :
611
+ raise ValueError (f"expect a 2-tuple (before, after), got { w_tpl } ." )
612
+
613
+ sh = x .shape [ax ]
614
+ if w_tpl [0 ] == 0 and w_tpl [1 ] == 0 :
615
+ sl = slice (None , None , None )
616
+ else :
617
+ start , stop = w_tpl
618
+ if stop == 0 :
619
+ stop = None
620
+ else :
621
+ stop = - stop
622
+
623
+ sl = slice (start , stop , None )
624
+ sh += w_tpl [0 ] + w_tpl [1 ]
625
+
626
+ newshape .append (sh )
627
+ slices .append (sl )
628
+
596
629
padded = xp .full (
597
- tuple (x + 2 * pad_width for x in x . shape ),
630
+ tuple (newshape ),
598
631
fill_value = value ,
599
632
dtype = x .dtype ,
600
633
device = _compat .device (x ),
601
634
)
602
- padded [( slice ( pad_width , - pad_width , None ),) * x . ndim ] = x
635
+ padded [tuple ( slices ) ] = x
603
636
return padded
604
637
605
638
0 commit comments