Skip to content

MAINT/BUG: improve gradient dtype handling #9411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 21, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 21 additions & 27 deletions numpy/lib/function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1714,33 +1714,27 @@ def gradient(f, *varargs, **kwargs):
slice3 = [slice(None)]*N
slice4 = [slice(None)]*N

otype = f.dtype.char
if otype not in ['f', 'd', 'F', 'D', 'm', 'M']:
otype = 'd'

# Difference of datetime64 elements results in timedelta64
if otype == 'M':
# Need to use the full dtype name because it contains unit information
otype = f.dtype.name.replace('datetime', 'timedelta')
elif otype == 'm':
# Needs to keep the specific units, can't be a general unit
otype = f.dtype

# Convert datetime64 data into ints. Make dummy variable `y`
# that is a view of ints if the data is datetime64, otherwise
# just set y equal to the array `f`.
if f.dtype.char in ["M", "m"]:
y = f.view('int64')
otype = f.dtype
if otype.type is np.datetime64:
# the timedelta dtype with the same unit information
otype = np.dtype(otype.name.replace('datetime', 'timedelta'))
# view as timedelta to allow addition
f = f.view(otype)
elif otype.type is np.timedelta64:
pass
elif np.issubdtype(otype, np.inexact):
pass
else:
y = f
# all other types convert to floating point
otype = np.double

for i, axis in enumerate(axes):
if y.shape[axis] < edge_order + 1:
if f.shape[axis] < edge_order + 1:
raise ValueError(
"Shape of array too small to calculate a numerical gradient, "
"at least (edge_order + 1) elements are required.")
# result allocation
out = np.empty_like(y, dtype=otype)
out = np.empty_like(f, dtype=otype)

uniform_spacing = np.isscalar(dx[i])

Expand Down Expand Up @@ -1771,15 +1765,15 @@ def gradient(f, *varargs, **kwargs):
slice2[axis] = 1
slice3[axis] = 0
dx_0 = dx[i] if uniform_spacing else dx[i][0]
# 1D equivalent -- out[0] = (y[1] - y[0]) / (x[1] - x[0])
out[slice1] = (y[slice2] - y[slice3]) / dx_0
# 1D equivalent -- out[0] = (f[1] - f[0]) / (x[1] - x[0])
out[slice1] = (f[slice2] - f[slice3]) / dx_0

slice1[axis] = -1
slice2[axis] = -1
slice3[axis] = -2
dx_n = dx[i] if uniform_spacing else dx[i][-1]
# 1D equivalent -- out[-1] = (y[-1] - y[-2]) / (x[-1] - x[-2])
out[slice1] = (y[slice2] - y[slice3]) / dx_n
# 1D equivalent -- out[-1] = (f[-1] - f[-2]) / (x[-1] - x[-2])
out[slice1] = (f[slice2] - f[slice3]) / dx_n

# Numerical differentiation: 2nd order edges
else:
Expand All @@ -1797,8 +1791,8 @@ def gradient(f, *varargs, **kwargs):
a = -(2. * dx1 + dx2)/(dx1 * (dx1 + dx2))
b = (dx1 + dx2) / (dx1 * dx2)
c = - dx1 / (dx2 * (dx1 + dx2))
# 1D equivalent -- out[0] = a * y[0] + b * y[1] + c * y[2]
out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4]
# 1D equivalent -- out[0] = a * f[0] + b * f[1] + c * f[2]
out[slice1] = a * f[slice2] + b * f[slice3] + c * f[slice4]

slice1[axis] = -1
slice2[axis] = -3
Expand All @@ -1815,7 +1809,7 @@ def gradient(f, *varargs, **kwargs):
b = - (dx2 + dx1) / (dx1 * dx2)
c = (2. * dx2 + dx1) / (dx2 * (dx1 + dx2))
# 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1]
out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4]
out[slice1] = a * f[slice2] + b * f[slice3] + c * f[slice4]

outvals.append(out)

Expand Down
6 changes: 6 additions & 0 deletions numpy/lib/tests/test_function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,12 @@ def test_timedelta64(self):
assert_array_equal(gradient(x), dx)
assert_(dx.dtype == np.dtype('timedelta64[D]'))

def test_inexact_dtypes(self):
for dt in [np.float16, np.float32, np.float64]:
# dtypes should not be promoted in a different way to what diff does
x = np.array([1, 2, 3], dtype=dt)
assert_equal(gradient(x).dtype, np.diff(x).dtype)

def test_values(self):
# needs at least 2 points for edge_order ==1
gradient(np.arange(2), edge_order=1)
Expand Down