Skip to content

Commit deb845f

Browse files
Added optional argument to specify time step to contrib.integrate.odeint_fixed.
PiperOrigin-RevId: 200220800
1 parent df1f2a0 commit deb845f

File tree

2 files changed

+147
-30
lines changed

2 files changed

+147
-30
lines changed

tensorflow/contrib/integrate/python/ops/odes.py

Lines changed: 106 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from tensorflow.python.framework import ops
2929
from tensorflow.python.ops import array_ops
3030
from tensorflow.python.ops import control_flow_ops
31-
from tensorflow.python.ops import functional_ops
3231
from tensorflow.python.ops import math_ops
3332
from tensorflow.python.ops import tensor_array_ops
3433

@@ -279,13 +278,27 @@ def _assert_increasing(t):
279278
return ops.control_dependencies([assert_increasing])
280279

281280

282-
def _check_input_types(t, y0):
281+
def _check_input_types(y0, t, dt=None):
283282
if not (y0.dtype.is_floating or y0.dtype.is_complex):
284283
raise TypeError('`y0` must have a floating point or complex floating '
285284
'point dtype')
286285
if not t.dtype.is_floating:
287286
raise TypeError('`t` must have a floating point dtype')
288287

288+
if dt is not None and not dt.dtype.is_floating:
289+
raise TypeError('`dt` must have a floating point dtype')
290+
291+
292+
def _check_input_sizes(t, dt):
293+
if len(t.get_shape().as_list()) > 1:
294+
raise ValueError('t must be a 1D tensor')
295+
296+
if len(dt.get_shape().as_list()) > 1:
297+
raise ValueError('t must be a 1D tensor')
298+
299+
if t.get_shape()[0] != dt.get_shape()[0] + 1:
300+
raise ValueError('t and dt have incompatible lengths, must be N and N-1')
301+
289302

290303
def _dopri5(func,
291304
y0,
@@ -510,7 +523,7 @@ def odeint(func,
510523
# avoiding the need to pack/unpack in user functions.
511524
y0 = ops.convert_to_tensor(y0, name='y0')
512525
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
513-
_check_input_types(t, y0)
526+
_check_input_types(y0, t)
514527

515528
error_dtype = abs(y0).dtype
516529
rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
@@ -530,31 +543,82 @@ def odeint(func,
530543
class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)):
531544
"""Base class for fixed-grid ODE integrators."""
532545

533-
def integrate(self, evol_func, y0, time_grid):
534-
time_delta_grid = time_grid[1:] - time_grid[:-1]
535-
536-
scan_func = self._make_scan_func(evol_func)
546+
def integrate(self, evol_func, y0, time_grid, dt_grid, steps_on_intervals):
547+
"""Returns integrated values of differential equation on the `time grid`.
548+
549+
Numerically integrates differential equation defined via time derivative
550+
evaluator `evol_func` using fixed time steps specified in dt_grid.
551+
552+
Args:
553+
evol_func: Callable, evaluates time derivative of y at a given time.
554+
y0: N-D Tensor holds initial values of the solution.
555+
time_grid: 1-D Tensor holding the time points at which the solution
556+
will be recorded, must have a floating dtype.
557+
dt_grid: 1-D Tensor holds fixed time steps to be used on time_grid
558+
intervals. Must be a floating dtype and have one less element than that
559+
of the time_grid.
560+
steps_on_intervals: 1-D Tensor of integer dtype, must have the same size
561+
as dt_grid. Specifies number of steps needed for every interval. Assumes
562+
steps_on_intervals * dt_grid == time intervals.
563+
564+
Returns:
565+
(N+1)-D tensor, where the first dimension corresponds to different
566+
time points. Contains the solved value of y for each desired time point in
567+
`t`, with the initial value `y0` being the first element along the first
568+
dimension.
569+
"""
537570

538-
y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid),
539-
y0)
540-
return array_ops.concat([[y0], y_grid], axis=0)
571+
iteration_func = self._make_iteration_func(evol_func, dt_grid)
572+
integrate_interval = self._make_interval_integrator(iteration_func,
573+
steps_on_intervals)
541574

542-
def _make_scan_func(self, evol_func):
575+
num_times = array_ops.size(time_grid)
576+
current_time = time_grid[0]
577+
solution_array = tensor_array_ops.TensorArray(y0.dtype, num_times)
578+
solution_array = solution_array.write(0, y0)
543579

544-
def scan_func(y, t_and_dt):
545-
t, dt = t_and_dt
580+
solution_array, _, _, _ = control_flow_ops.while_loop(
581+
lambda _, __, ___, i: i < num_times,
582+
integrate_interval,
583+
(solution_array, y0, current_time, 1)
584+
)
585+
solution_array = solution_array.stack()
586+
solution_array.set_shape(time_grid.get_shape().concatenate(y0.get_shape()))
587+
return solution_array
588+
589+
def _make_iteration_func(self, evol_func, dt_grid):
590+
"""Returns a function that builds operations of a single time step."""
591+
592+
def iteration_func(y, t, dt_step, interval_step):
593+
"""Performs a single time step advance."""
594+
dt = dt_grid[interval_step - 1]
546595
dy = self._step_func(evol_func, t, dt, y)
547596
dy = math_ops.cast(dy, dtype=y.dtype)
548-
return y + dy
597+
return y + dy, t + dt, dt_step + 1, interval_step
598+
599+
return iteration_func
600+
601+
def _make_interval_integrator(self, iteration_func, interval_sizes):
602+
"""Returns a function that builds operations for interval integration."""
549603

550-
return scan_func
604+
def integrate_interval(solution_array, y, t, interval_num):
605+
"""Integrates y with fixed time step on interval `interval_num`."""
606+
y, t, _, _ = control_flow_ops.while_loop(
607+
lambda _, __, j, interval_num: j < interval_sizes[interval_num - 1],
608+
iteration_func,
609+
(y, t, 0, interval_num)
610+
)
611+
return solution_array.write(interval_num, y), y, t, interval_num + 1
612+
613+
return integrate_interval
551614

552615
@abc.abstractmethod
553616
def _step_func(self, evol_func, t, dt, y):
554617
pass
555618

556619

557620
class _MidpointFixedGridIntegrator(_FixedGridIntegrator):
621+
"""Fixed grid integrator implementing midpoint scheme."""
558622

559623
def _step_func(self, evol_func, t, dt, y):
560624
dt_cast = math_ops.cast(dt, y.dtype)
@@ -563,6 +627,7 @@ def _step_func(self, evol_func, t, dt, y):
563627

564628

565629
class _RK4FixedGridIntegrator(_FixedGridIntegrator):
630+
"""Fixed grid integrator implementing RK4 scheme."""
566631

567632
def _step_func(self, evol_func, t, dt, y):
568633
k1 = evol_func(y, t)
@@ -575,7 +640,7 @@ def _step_func(self, evol_func, t, dt, y):
575640
return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6)
576641

577642

578-
def odeint_fixed(func, y0, t, method='rk4', name=None):
643+
def odeint_fixed(func, y0, t, dt=None, method='rk4', name=None):
579644
"""ODE integration on a fixed grid (with no step size control).
580645
581646
Useful in certain scenarios to avoid the overhead of adaptive step size
@@ -590,6 +655,14 @@ def odeint_fixed(func, y0, t, method='rk4', name=None):
590655
`y`. The initial time point should be the first element of this sequence,
591656
and each time must be larger than the previous time. May have any floating
592657
point dtype.
658+
dt: 0-D or 1-D Tensor providing time step suggestion to be used on time
659+
integration intervals in `t`. 1-D Tensor should provide values
660+
for all intervals, must have 1 less element than that of `t`.
661+
If given a 0-D Tensor, the value is interpreted as time step suggestion
662+
same for all intervals. If passed None, then time step is set to be the
663+
t[1:] - t[:-1]. Defaults to None. The actual step size is obtained by
664+
insuring an integer number of steps per interval, potentially reducing the
665+
time step.
593666
method: One of 'midpoint' or 'rk4'.
594667
name: Optional name for the resulting operation.
595668
@@ -602,16 +675,29 @@ def odeint_fixed(func, y0, t, method='rk4', name=None):
602675
Raises:
603676
ValueError: Upon caller errors.
604677
"""
605-
with ops.name_scope(name, 'odeint_fixed', [y0, t]):
678+
with ops.name_scope(name, 'odeint_fixed', [y0, t, dt]):
606679
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
607680
y0 = ops.convert_to_tensor(y0, name='y0')
608-
_check_input_types(t, y0)
681+
682+
intervals = t[1:] - t[:-1]
683+
if dt is None:
684+
dt = intervals
685+
dt = ops.convert_to_tensor(dt, preferred_dtype=dtypes.float64, name='dt')
686+
687+
steps_on_intervals = math_ops.ceil(intervals / dt)
688+
dt = intervals / steps_on_intervals
689+
steps_on_intervals = math_ops.cast(steps_on_intervals, dtype=dtypes.int32)
690+
691+
_check_input_types(y0, t, dt)
692+
_check_input_sizes(t, dt)
609693

610694
with _assert_increasing(t):
611695
with ops.name_scope(method):
612696
if method == 'midpoint':
613-
return _MidpointFixedGridIntegrator().integrate(func, y0, t)
697+
return _MidpointFixedGridIntegrator().integrate(func, y0, t, dt,
698+
steps_on_intervals)
614699
elif method == 'rk4':
615-
return _RK4FixedGridIntegrator().integrate(func, y0, t)
700+
return _RK4FixedGridIntegrator().integrate(func, y0, t, dt,
701+
steps_on_intervals)
616702
else:
617703
raise ValueError('method not supported: {!s}'.format(method))

tensorflow/contrib/integrate/python/ops/odes_test.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -242,47 +242,78 @@ def test_5th_order_polynomial(self):
242242

243243
class OdeIntFixedTest(test.TestCase):
244244

245-
def _test_integrate_sine(self, method):
245+
def _test_integrate_sine(self, method, t, dt=None):
246246

247247
def evol_func(y, t):
248248
del t
249249
return array_ops.stack([y[1], -y[0]])
250250

251251
y0 = [0., 1.]
252-
time_grid = np.linspace(0., 10., 200)
253-
y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method)
252+
y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
254253

255254
with self.test_session() as sess:
256255
y_grid_array = sess.run(y_grid)
257256

258257
np.testing.assert_allclose(
259-
y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2)
258+
y_grid_array[:, 0], np.sin(t), rtol=1e-2, atol=1e-2)
260259

261-
def _test_integrate_gaussian(self, method):
260+
def _test_integrate_gaussian(self, method, t, dt=None):
262261

263262
def evol_func(y, t):
264263
return -math_ops.cast(t, dtype=y.dtype) * y[0]
265264

266265
y0 = [1.]
267-
time_grid = np.linspace(0., 2., 100)
268-
y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method)
266+
y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
269267

270268
with self.test_session() as sess:
271269
y_grid_array = sess.run(y_grid)
272270

273271
np.testing.assert_allclose(
274-
y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2)
272+
y_grid_array[:, 0], np.exp(-t**2 / 2), rtol=1e-2, atol=1e-2)
273+
274+
def _test_integrate_sine_all(self, method):
275+
uniform_time_grid = np.linspace(0., 10., 200)
276+
non_uniform_time_grid = np.asarray([0.0, 0.4, 4.7, 5.2, 7.0])
277+
uniform_dt = 0.02
278+
non_uniform_dt = np.asarray([0.01, 0.001, 0.05, 0.03])
279+
self._test_integrate_sine(method, uniform_time_grid)
280+
self._test_integrate_sine(method, non_uniform_time_grid, uniform_dt)
281+
self._test_integrate_sine(method, non_uniform_time_grid, non_uniform_dt)
282+
283+
def _test_integrate_gaussian_all(self, method):
284+
uniform_time_grid = np.linspace(0., 2., 100)
285+
non_uniform_time_grid = np.asarray([0.0, 0.1, 0.7, 1.2, 2.0])
286+
uniform_dt = 0.01
287+
non_uniform_dt = np.asarray([0.01, 0.001, 0.1, 0.03])
288+
self._test_integrate_gaussian(method, uniform_time_grid)
289+
self._test_integrate_gaussian(method, non_uniform_time_grid, uniform_dt)
290+
self._test_integrate_gaussian(method, non_uniform_time_grid, non_uniform_dt)
275291

276292
def _test_everything(self, method):
277-
self._test_integrate_sine(method)
278-
self._test_integrate_gaussian(method)
293+
self._test_integrate_sine_all(method)
294+
self._test_integrate_gaussian_all(method)
279295

280296
def test_midpoint(self):
281297
self._test_everything('midpoint')
282298

283299
def test_rk4(self):
284300
self._test_everything('rk4')
285301

302+
def test_dt_size_exceptions(self):
303+
times = np.linspace(0., 2., 100)
304+
dt = np.ones(99) * 0.01
305+
dt_wrong_length = np.asarray([0.01, 0.001, 0.1, 0.03])
306+
dt_wrong_dim = np.expand_dims(np.linspace(0., 2., 99), axis=0)
307+
times_wrong_dim = np.expand_dims(np.linspace(0., 2., 100), axis=0)
308+
with self.assertRaises(ValueError):
309+
self._test_integrate_gaussian('midpoint', times, dt_wrong_length)
310+
311+
with self.assertRaises(ValueError):
312+
self._test_integrate_gaussian('midpoint', times, dt_wrong_dim)
313+
314+
with self.assertRaises(ValueError):
315+
self._test_integrate_gaussian('midpoint', times_wrong_dim, dt)
316+
286317

287318
if __name__ == '__main__':
288319
test.main()

0 commit comments

Comments
 (0)