28
28
from tensorflow .python .framework import ops
29
29
from tensorflow .python .ops import array_ops
30
30
from tensorflow .python .ops import control_flow_ops
31
- from tensorflow .python .ops import functional_ops
32
31
from tensorflow .python .ops import math_ops
33
32
from tensorflow .python .ops import tensor_array_ops
34
33
@@ -279,13 +278,27 @@ def _assert_increasing(t):
279
278
return ops .control_dependencies ([assert_increasing ])
280
279
281
280
282
- def _check_input_types (t , y0 ):
281
+ def _check_input_types (y0 , t , dt = None ):
283
282
if not (y0 .dtype .is_floating or y0 .dtype .is_complex ):
284
283
raise TypeError ('`y0` must have a floating point or complex floating '
285
284
'point dtype' )
286
285
if not t .dtype .is_floating :
287
286
raise TypeError ('`t` must have a floating point dtype' )
288
287
288
+ if dt is not None and not dt .dtype .is_floating :
289
+ raise TypeError ('`dt` must have a floating point dtype' )
290
+
291
+
292
+ def _check_input_sizes (t , dt ):
293
+ if len (t .get_shape ().as_list ()) > 1 :
294
+ raise ValueError ('t must be a 1D tensor' )
295
+
296
+ if len (dt .get_shape ().as_list ()) > 1 :
297
+ raise ValueError ('t must be a 1D tensor' )
298
+
299
+ if t .get_shape ()[0 ] != dt .get_shape ()[0 ] + 1 :
300
+ raise ValueError ('t and dt have incompatible lengths, must be N and N-1' )
301
+
289
302
290
303
def _dopri5 (func ,
291
304
y0 ,
@@ -510,7 +523,7 @@ def odeint(func,
510
523
# avoiding the need to pack/unpack in user functions.
511
524
y0 = ops .convert_to_tensor (y0 , name = 'y0' )
512
525
t = ops .convert_to_tensor (t , preferred_dtype = dtypes .float64 , name = 't' )
513
- _check_input_types (t , y0 )
526
+ _check_input_types (y0 , t )
514
527
515
528
error_dtype = abs (y0 ).dtype
516
529
rtol = ops .convert_to_tensor (rtol , dtype = error_dtype , name = 'rtol' )
@@ -530,31 +543,82 @@ def odeint(func,
530
543
class _FixedGridIntegrator (six .with_metaclass (abc .ABCMeta )):
531
544
"""Base class for fixed-grid ODE integrators."""
532
545
533
- def integrate (self , evol_func , y0 , time_grid ):
534
- time_delta_grid = time_grid [1 :] - time_grid [:- 1 ]
535
-
536
- scan_func = self ._make_scan_func (evol_func )
546
+ def integrate (self , evol_func , y0 , time_grid , dt_grid , steps_on_intervals ):
547
+ """Returns integrated values of differential equation on the `time grid`.
548
+
549
+ Numerically integrates differential equation defined via time derivative
550
+ evaluator `evol_func` using fixed time steps specified in dt_grid.
551
+
552
+ Args:
553
+ evol_func: Callable, evaluates time derivative of y at a given time.
554
+ y0: N-D Tensor holds initial values of the solution.
555
+ time_grid: 1-D Tensor holding the time points at which the solution
556
+ will be recorded, must have a floating dtype.
557
+ dt_grid: 1-D Tensor holds fixed time steps to be used on time_grid
558
+ intervals. Must be a floating dtype and have one less element than that
559
+ of the time_grid.
560
+ steps_on_intervals: 1-D Tensor of integer dtype, must have the same size
561
+ as dt_grid. Specifies number of steps needed for every interval. Assumes
562
+ steps_on_intervals * dt_grid == time intervals.
563
+
564
+ Returns:
565
+ (N+1)-D tensor, where the first dimension corresponds to different
566
+ time points. Contains the solved value of y for each desired time point in
567
+ `t`, with the initial value `y0` being the first element along the first
568
+ dimension.
569
+ """
537
570
538
- y_grid = functional_ops . scan ( scan_func , ( time_grid [: - 1 ], time_delta_grid ),
539
- y0 )
540
- return array_ops . concat ([[ y0 ], y_grid ], axis = 0 )
571
+ iteration_func = self . _make_iteration_func ( evol_func , dt_grid )
572
+ integrate_interval = self . _make_interval_integrator ( iteration_func ,
573
+ steps_on_intervals )
541
574
542
- def _make_scan_func (self , evol_func ):
575
+ num_times = array_ops .size (time_grid )
576
+ current_time = time_grid [0 ]
577
+ solution_array = tensor_array_ops .TensorArray (y0 .dtype , num_times )
578
+ solution_array = solution_array .write (0 , y0 )
543
579
544
- def scan_func (y , t_and_dt ):
545
- t , dt = t_and_dt
580
+ solution_array , _ , _ , _ = control_flow_ops .while_loop (
581
+ lambda _ , __ , ___ , i : i < num_times ,
582
+ integrate_interval ,
583
+ (solution_array , y0 , current_time , 1 )
584
+ )
585
+ solution_array = solution_array .stack ()
586
+ solution_array .set_shape (time_grid .get_shape ().concatenate (y0 .get_shape ()))
587
+ return solution_array
588
+
589
+ def _make_iteration_func (self , evol_func , dt_grid ):
590
+ """Returns a function that builds operations of a single time step."""
591
+
592
+ def iteration_func (y , t , dt_step , interval_step ):
593
+ """Performs a single time step advance."""
594
+ dt = dt_grid [interval_step - 1 ]
546
595
dy = self ._step_func (evol_func , t , dt , y )
547
596
dy = math_ops .cast (dy , dtype = y .dtype )
548
- return y + dy
597
+ return y + dy , t + dt , dt_step + 1 , interval_step
598
+
599
+ return iteration_func
600
+
601
+ def _make_interval_integrator (self , iteration_func , interval_sizes ):
602
+ """Returns a function that builds operations for interval integration."""
549
603
550
- return scan_func
604
+ def integrate_interval (solution_array , y , t , interval_num ):
605
+ """Integrates y with fixed time step on interval `interval_num`."""
606
+ y , t , _ , _ = control_flow_ops .while_loop (
607
+ lambda _ , __ , j , interval_num : j < interval_sizes [interval_num - 1 ],
608
+ iteration_func ,
609
+ (y , t , 0 , interval_num )
610
+ )
611
+ return solution_array .write (interval_num , y ), y , t , interval_num + 1
612
+
613
+ return integrate_interval
551
614
552
615
@abc .abstractmethod
553
616
def _step_func (self , evol_func , t , dt , y ):
554
617
pass
555
618
556
619
557
620
class _MidpointFixedGridIntegrator (_FixedGridIntegrator ):
621
+ """Fixed grid integrator implementing midpoint scheme."""
558
622
559
623
def _step_func (self , evol_func , t , dt , y ):
560
624
dt_cast = math_ops .cast (dt , y .dtype )
@@ -563,6 +627,7 @@ def _step_func(self, evol_func, t, dt, y):
563
627
564
628
565
629
class _RK4FixedGridIntegrator (_FixedGridIntegrator ):
630
+ """Fixed grid integrator implementing RK4 scheme."""
566
631
567
632
def _step_func (self , evol_func , t , dt , y ):
568
633
k1 = evol_func (y , t )
@@ -575,7 +640,7 @@ def _step_func(self, evol_func, t, dt, y):
575
640
return math_ops .add_n ([k1 , 2 * k2 , 2 * k3 , k4 ]) * (dt_cast / 6 )
576
641
577
642
578
- def odeint_fixed (func , y0 , t , method = 'rk4' , name = None ):
643
+ def odeint_fixed (func , y0 , t , dt = None , method = 'rk4' , name = None ):
579
644
"""ODE integration on a fixed grid (with no step size control).
580
645
581
646
Useful in certain scenarios to avoid the overhead of adaptive step size
@@ -590,6 +655,14 @@ def odeint_fixed(func, y0, t, method='rk4', name=None):
590
655
`y`. The initial time point should be the first element of this sequence,
591
656
and each time must be larger than the previous time. May have any floating
592
657
point dtype.
658
+ dt: 0-D or 1-D Tensor providing time step suggestion to be used on time
659
+ integration intervals in `t`. 1-D Tensor should provide values
660
+ for all intervals, must have 1 less element than that of `t`.
661
+ If given a 0-D Tensor, the value is interpreted as time step suggestion
662
+ same for all intervals. If passed None, then time step is set to be the
663
+ t[1:] - t[:-1]. Defaults to None. The actual step size is obtained by
664
+ insuring an integer number of steps per interval, potentially reducing the
665
+ time step.
593
666
method: One of 'midpoint' or 'rk4'.
594
667
name: Optional name for the resulting operation.
595
668
@@ -602,16 +675,29 @@ def odeint_fixed(func, y0, t, method='rk4', name=None):
602
675
Raises:
603
676
ValueError: Upon caller errors.
604
677
"""
605
- with ops .name_scope (name , 'odeint_fixed' , [y0 , t ]):
678
+ with ops .name_scope (name , 'odeint_fixed' , [y0 , t , dt ]):
606
679
t = ops .convert_to_tensor (t , preferred_dtype = dtypes .float64 , name = 't' )
607
680
y0 = ops .convert_to_tensor (y0 , name = 'y0' )
608
- _check_input_types (t , y0 )
681
+
682
+ intervals = t [1 :] - t [:- 1 ]
683
+ if dt is None :
684
+ dt = intervals
685
+ dt = ops .convert_to_tensor (dt , preferred_dtype = dtypes .float64 , name = 'dt' )
686
+
687
+ steps_on_intervals = math_ops .ceil (intervals / dt )
688
+ dt = intervals / steps_on_intervals
689
+ steps_on_intervals = math_ops .cast (steps_on_intervals , dtype = dtypes .int32 )
690
+
691
+ _check_input_types (y0 , t , dt )
692
+ _check_input_sizes (t , dt )
609
693
610
694
with _assert_increasing (t ):
611
695
with ops .name_scope (method ):
612
696
if method == 'midpoint' :
613
- return _MidpointFixedGridIntegrator ().integrate (func , y0 , t )
697
+ return _MidpointFixedGridIntegrator ().integrate (func , y0 , t , dt ,
698
+ steps_on_intervals )
614
699
elif method == 'rk4' :
615
- return _RK4FixedGridIntegrator ().integrate (func , y0 , t )
700
+ return _RK4FixedGridIntegrator ().integrate (func , y0 , t , dt ,
701
+ steps_on_intervals )
616
702
else :
617
703
raise ValueError ('method not supported: {!s}' .format (method ))
0 commit comments