Skip to content

Commit 433ac81

Browse files
Modified Bessel functions of order zero and one.
The functions are tf.math.bessel_i0(x), tf.math.bessel_i0e(x), tf.math.bessel_i1(x) and tf.math.bessel_i1e(x). The exponentially scaled versions tf.math.bessel_i0e(x) and tf.math.bessel_i1e(x) are more numerically stable. This code wraps the implementation that was recently added to Eigen. PiperOrigin-RevId: 200186968
1 parent da88bfa commit 433ac81

File tree

14 files changed

+300
-0
lines changed

14 files changed

+300
-0
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
op {
2+
graph_op_name: "BesselI0e"
3+
summary: "Computes the Bessel i0e function of `x` element-wise."
4+
description: <<END
5+
Exponentially scaled modified Bessel function of order 0 defined as
6+
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.
7+
8+
This function is faster and numerically stabler than `bessel_i0(x)`.
9+
END
10+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
op {
2+
graph_op_name: "BesselI1e"
3+
summary: "Computes the Bessel i1e function of `x` element-wise."
4+
description: <<END
5+
Exponentially scaled modified Bessel function of order 0 defined as
6+
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.
7+
8+
This function is faster and numerically stabler than `bessel_i1(x)`.
9+
END
10+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
op {
2+
graph_op_name: "BesselI0e"
3+
visibility: HIDDEN
4+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
op {
2+
graph_op_name: "BesselI1e"
3+
visibility: HIDDEN
4+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/kernels/cwise_ops_common.h"
17+
18+
namespace tensorflow {
19+
REGISTER3(UnaryOp, CPU, "BesselI0e", functor::bessel_i0e, Eigen::half, float,
20+
double);
21+
REGISTER3(UnaryOp, CPU, "BesselI1e", functor::bessel_i1e, Eigen::half, float,
22+
double);
23+
#if GOOGLE_CUDA
24+
REGISTER3(UnaryOp, GPU, "BesselI0e", functor::bessel_i0e, Eigen::half, float,
25+
double);
26+
REGISTER3(UnaryOp, GPU, "BesselI1e", functor::bessel_i1e, Eigen::half, float,
27+
double);
28+
#endif
29+
} // namespace tensorflow
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#if GOOGLE_CUDA
17+
18+
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
19+
20+
namespace tensorflow {
21+
namespace functor {
22+
DEFINE_UNARY3(bessel_i0e, Eigen::half, float, double);
23+
DEFINE_UNARY3(bessel_i1e, Eigen::half, float, double);
24+
} // namespace functor
25+
} // namespace tensorflow
26+
27+
#endif // GOOGLE_CUDA

tensorflow/core/kernels/cwise_ops.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,12 @@ struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {};
616616
template <typename T>
617617
struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {};
618618

619+
template <typename T>
620+
struct bessel_i0e : base<T, Eigen::internal::scalar_i0e_op<T>> {};
621+
622+
template <typename T>
623+
struct bessel_i1e : base<T, Eigen::internal::scalar_i1e_op<T>> {};
624+
619625
struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> {
620626
};
621627

tensorflow/core/ops/math_ops.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,10 @@ REGISTER_OP("Acos").UNARY();
239239

240240
REGISTER_OP("Atan").UNARY();
241241

242+
REGISTER_OP("BesselI0e").UNARY_REAL();
243+
244+
REGISTER_OP("BesselI1e").UNARY_REAL();
245+
242246
#undef UNARY
243247
#undef UNARY_REAL
244248
#undef UNARY_COMPLEX

tensorflow/python/kernel_tests/cwise_ops_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,12 @@ def testFloatBasic(self):
241241
math_ops.lgamma)
242242
self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
243243
self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
244+
try:
245+
from scipy import special # pylint: disable=g-import-not-at-top
246+
self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
247+
self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
248+
except ImportError as e:
249+
tf_logging.warn("Cannot test special functions: %s" % str(e))
244250

245251
self._compareBothSparse(x, np.abs, math_ops.abs)
246252
self._compareBothSparse(x, np.negative, math_ops.negative)
@@ -286,6 +292,12 @@ def testFloatEmpty(self):
286292
self._compareBoth(x, np.arcsin, math_ops.asin)
287293
self._compareBoth(x, np.arccos, math_ops.acos)
288294
self._compareBoth(x, np.arctan, math_ops.atan)
295+
try:
296+
from scipy import special # pylint: disable=g-import-not-at-top
297+
self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
298+
self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
299+
except ImportError as e:
300+
tf_logging.warn("Cannot test special functions: %s" % str(e))
289301

290302
self._compareBothSparse(x, np.abs, math_ops.abs)
291303
self._compareBothSparse(x, np.negative, math_ops.negative)
@@ -334,6 +346,12 @@ def testDoubleBasic(self):
334346
self._compareBoth(k, np.arcsin, math_ops.asin)
335347
self._compareBoth(k, np.arccos, math_ops.acos)
336348
self._compareBoth(k, np.tan, math_ops.tan)
349+
try:
350+
from scipy import special # pylint: disable=g-import-not-at-top
351+
self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
352+
self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
353+
except ImportError as e:
354+
tf_logging.warn("Cannot test special functions: %s" % str(e))
337355

338356
self._compareBothSparse(x, np.abs, math_ops.abs)
339357
self._compareBothSparse(x, np.negative, math_ops.negative)
@@ -370,6 +388,12 @@ def testHalfBasic(self):
370388
math_ops.lgamma)
371389
self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
372390
self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
391+
try:
392+
from scipy import special # pylint: disable=g-import-not-at-top
393+
self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
394+
self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
395+
except ImportError as e:
396+
tf_logging.warn("Cannot test special functions: %s" % str(e))
373397

374398
self._compareBothSparse(x, np.abs, math_ops.abs)
375399
self._compareBothSparse(x, np.negative, math_ops.negative)

tensorflow/python/ops/math_grad.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,35 @@ def _DigammaGrad(op, grad):
620620
return grad * math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x)
621621

622622

623+
@ops.RegisterGradient("BesselI0e")
624+
def _BesselI0eGrad(op, grad):
625+
"""Compute gradient of bessel_i0e(x) with respect to its argument."""
626+
x = op.inputs[0]
627+
y = op.outputs[0]
628+
with ops.control_dependencies([grad]):
629+
return grad * (math_ops.bessel_i1e(x) - math_ops.sign(x) * y)
630+
631+
632+
@ops.RegisterGradient("BesselI1e")
633+
def _BesselI1eGrad(op, grad):
634+
"""Compute gradient of bessel_i1e(x) with respect to its argument."""
635+
x = op.inputs[0]
636+
y = op.outputs[0]
637+
with ops.control_dependencies([grad]):
638+
# For x = 0, the correct gradient is 0.5.
639+
# However, the main branch gives NaN because of the division by x, so
640+
# we impute the gradient manually.
641+
# An alternative solution is to express the gradient via bessel_i0e and
642+
# bessel_i2e, but the latter is not yet implemented in Eigen.
643+
eps = np.finfo(x.dtype.as_numpy_dtype).eps
644+
zeros = array_ops.zeros_like(x)
645+
x_is_not_tiny = math_ops.abs(x) > eps
646+
safe_x = array_ops.where(x_is_not_tiny, x, eps + zeros)
647+
dy_dx = math_ops.bessel_i0e(safe_x) - y * (
648+
math_ops.sign(safe_x) + math_ops.reciprocal(safe_x))
649+
return grad * array_ops.where(x_is_not_tiny, dy_dx, 0.5 + zeros)
650+
651+
623652
@ops.RegisterGradient("Igamma")
624653
def _IgammaGrad(op, grad):
625654
"""Returns gradient of igamma(a, x) with respect to x."""

tensorflow/python/ops/math_ops.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2954,6 +2954,67 @@ def polyval(coeffs, x, name=None):
29542954
p = c + p * x
29552955
return p
29562956

2957+
2958+
@tf_export("math.bessel_i0e")
2959+
def bessel_i0e(x, name=None):
2960+
"""Computes the Bessel i0e function of `x` element-wise.
2961+
2962+
Exponentially scaled modified Bessel function of order 0 defined as
2963+
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.
2964+
2965+
This function is faster and numerically stabler than `bessel_i0(x)`.
2966+
2967+
Args:
2968+
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
2969+
`float32`, `float64`.
2970+
name: A name for the operation (optional).
2971+
2972+
Returns:
2973+
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
2974+
2975+
@compatibility(scipy)
2976+
Equivalent to scipy.special.i0e
2977+
@end_compatibility
2978+
"""
2979+
with ops.name_scope(name, "bessel_i0e", [x]) as name:
2980+
if isinstance(x, sparse_tensor.SparseTensor):
2981+
x_i0e = gen_math_ops.bessel_i0e(x.values, name=name)
2982+
return sparse_tensor.SparseTensor(
2983+
indices=x.indices, values=x_i0e, dense_shape=x.dense_shape)
2984+
else:
2985+
return gen_math_ops.bessel_i0e(x, name=name)
2986+
2987+
2988+
@tf_export("math.bessel_i1e")
2989+
def bessel_i1e(x, name=None):
2990+
"""Computes the Bessel i1e function of `x` element-wise.
2991+
2992+
Exponentially scaled modified Bessel function of order 1 defined as
2993+
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.
2994+
2995+
This function is faster and numerically stabler than `bessel_i1(x)`.
2996+
2997+
Args:
2998+
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
2999+
`float32`, `float64`.
3000+
name: A name for the operation (optional).
3001+
3002+
Returns:
3003+
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
3004+
3005+
@compatibility(scipy)
3006+
Equivalent to scipy.special.i1e
3007+
@end_compatibility
3008+
"""
3009+
with ops.name_scope(name, "bessel_i1e", [x]) as name:
3010+
if isinstance(x, sparse_tensor.SparseTensor):
3011+
x_i1e = gen_math_ops.bessel_i1e(x.values, name=name)
3012+
return sparse_tensor.SparseTensor(
3013+
indices=x.indices, values=x_i1e, dense_shape=x.dense_shape)
3014+
else:
3015+
return gen_math_ops.bessel_i1e(x, name=name)
3016+
3017+
29573018
# FFT ops were moved to tf.spectral. tf.fft symbols were part of the TensorFlow
29583019
# 1.0 API so we leave these here for backwards compatibility.
29593020
fft = gen_spectral_ops.fft

tensorflow/python/ops/special_math_ops.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,54 @@ def lbeta(x, name='lbeta'):
8282
return result
8383

8484

85+
@tf_export('math.bessel_i0')
86+
def bessel_i0(x, name='bessel_i0'):
87+
"""Computes the Bessel i0 function of `x` element-wise.
88+
89+
Modified Bessel function of order 0.
90+
91+
It is preferable to use the numerically stabler function `i0e(x)` instead.
92+
93+
Args:
94+
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
95+
`float32`, `float64`.
96+
name: A name for the operation (optional).
97+
98+
Returns:
99+
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
100+
101+
@compatibility(scipy)
102+
Equivalent to scipy.special.i0
103+
@end_compatibility
104+
"""
105+
with ops.name_scope(name, [x]):
106+
return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i0e(x)
107+
108+
109+
@tf_export('math.bessel_i1')
110+
def bessel_i1(x, name='bessel_i1'):
111+
"""Computes the Bessel i1 function of `x` element-wise.
112+
113+
Modified Bessel function of order 1.
114+
115+
It is preferable to use the numerically stabler function `i1e(x)` instead.
116+
117+
Args:
118+
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
119+
`float32`, `float64`.
120+
name: A name for the operation (optional).
121+
122+
Returns:
123+
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
124+
125+
@compatibility(scipy)
126+
Equivalent to scipy.special.i1
127+
@end_compatibility
128+
"""
129+
with ops.name_scope(name, [x]):
130+
return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i1e(x)
131+
132+
85133
@tf_export('einsum', 'linalg.einsum')
86134
def einsum(equation, *inputs, **kwargs):
87135
"""A generalized contraction between tensors of arbitrary dimension.

tensorflow/python/ops/special_math_ops_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from tensorflow.python.ops import math_ops
3030
from tensorflow.python.ops import special_math_ops
3131
from tensorflow.python.platform import test
32+
from tensorflow.python.platform import tf_logging
3233

3334

3435
class LBetaTest(test.TestCase):
@@ -150,6 +151,33 @@ def test_empty_rank2_with_zero_batch_dim_returns_empty(self):
150151
self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape())
151152

152153

154+
class BesselTest(test.TestCase):
155+
156+
def test_bessel_i0(self):
157+
x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
158+
x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
159+
try:
160+
from scipy import special # pylint: disable=g-import-not-at-top
161+
self.assertAllClose(special.i0(x_single),
162+
self.evaluate(special_math_ops.bessel_i0(x_single)))
163+
self.assertAllClose(special.i0(x_double),
164+
self.evaluate(special_math_ops.bessel_i0(x_double)))
165+
except ImportError as e:
166+
tf_logging.warn('Cannot test special functions: %s' % str(e))
167+
168+
def test_bessel_i1(self):
169+
x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
170+
x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
171+
try:
172+
from scipy import special # pylint: disable=g-import-not-at-top
173+
self.assertAllClose(special.i1(x_single),
174+
self.evaluate(special_math_ops.bessel_i1(x_single)))
175+
self.assertAllClose(special.i1(x_double),
176+
self.evaluate(special_math_ops.bessel_i1(x_double)))
177+
except ImportError as e:
178+
tf_logging.warn('Cannot test special functions: %s' % str(e))
179+
180+
153181
class EinsumTest(test.TestCase):
154182

155183
simple_cases = [

tensorflow/tools/api/golden/tensorflow.math.pbtxt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
11
path: "tensorflow.math"
22
tf_module {
3+
member_method {
4+
name: "bessel_i0"
5+
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'bessel_i0\'], "
6+
}
7+
member_method {
8+
name: "bessel_i0e"
9+
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
10+
}
11+
member_method {
12+
name: "bessel_i1"
13+
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'bessel_i1\'], "
14+
}
15+
member_method {
16+
name: "bessel_i1e"
17+
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
18+
}
319
member_method {
420
name: "polyval"
521
argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

0 commit comments

Comments
 (0)