Skip to content

Commit 10098da

Browse files
committed
Merge pull request numpy#4454 from jurnix/namedargs
ENH: apply_along_axis accepts named arguments
2 parents c346477 + 7f8aae0 commit 10098da

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

numpy/lib/shape_base.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from numpy.core.fromnumeric import product, reshape
1313
from numpy.core import hstack, vstack, atleast_3d
1414

15-
def apply_along_axis(func1d,axis,arr,*args):
15+
def apply_along_axis(func1d, axis, arr, *args, **kwargs):
1616
"""
1717
Apply a function to 1-D slices along the given axis.
1818
@@ -30,6 +30,11 @@ def apply_along_axis(func1d,axis,arr,*args):
3030
Input array.
3131
args : any
3232
Additional arguments to `func1d`.
33+
kwargs: any
34+
Additional named arguments to `func1d`.
35+
36+
.. versionadded:: 1.9.0
37+
3338
3439
Returns
3540
-------
@@ -78,7 +83,7 @@ def apply_along_axis(func1d,axis,arr,*args):
7883
i[axis] = slice(None, None)
7984
outshape = asarray(arr.shape).take(indlist)
8085
i.put(indlist, ind)
81-
res = func1d(arr[tuple(i.tolist())],*args)
86+
res = func1d(arr[tuple(i.tolist())], *args, **kwargs)
8287
# if res is a number, then we have a smaller output array
8388
if isscalar(res):
8489
outarr = zeros(outshape, asarray(res).dtype)
@@ -94,7 +99,7 @@ def apply_along_axis(func1d,axis,arr,*args):
9499
ind[n] = 0
95100
n -= 1
96101
i.put(indlist, ind)
97-
res = func1d(arr[tuple(i.tolist())],*args)
102+
res = func1d(arr[tuple(i.tolist())], *args, **kwargs)
98103
outarr[tuple(ind)] = res
99104
k += 1
100105
return outarr
@@ -115,7 +120,7 @@ def apply_along_axis(func1d,axis,arr,*args):
115120
ind[n] = 0
116121
n -= 1
117122
i.put(indlist, ind)
118-
res = func1d(arr[tuple(i.tolist())],*args)
123+
res = func1d(arr[tuple(i.tolist())], *args, **kwargs)
119124
outarr[tuple(i.tolist())] = res
120125
k += 1
121126
return outarr

numpy/ma/tests/test_extras.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,16 @@ def myfunc(b):
479479
xa = apply_along_axis(myfunc, 2, a)
480480
assert_equal(xa, [[1, 4], [7, 10]])
481481

482+
# Tests kwargs functions
483+
def test_3d_kwargs(self):
484+
a = arange(12).reshape(2, 2, 3)
485+
486+
def myfunc(b, offset=0):
487+
return b[1+offset]
488+
489+
xa = apply_along_axis(myfunc, 2, a, offset=1)
490+
assert_equal(xa, [[2, 5], [8, 11]])
491+
482492

483493
class TestApplyOverAxes(TestCase):
484494
# Tests apply_over_axes

0 commit comments

Comments
 (0)