-
-
Notifications
You must be signed in to change notification settings - Fork 156
/
Copy pathmultiple_shooting.jl
216 lines (180 loc) · 9.41 KB
/
multiple_shooting.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""
    multiple_shoot(p, ode_data, tsteps, prob, loss_function,
        [continuity_loss = _default_continuity_loss], solver, group_size;
        continuity_term = 100, kwargs...)

Evaluate a direct multiple shooting loss for `prob` against `ode_data` and return it
together with the predictions of every group.

The data points are partitioned with `group_ranges` into overlapping groups of at most
`group_size` points. Each group is integrated separately, starting from the observed
state at the group's first time point. The total loss is the sum of `loss_function` over
all groups plus, for every group after the first, `continuity_term` times
`continuity_loss` evaluated between the last predicted state of the previous group and
the first observed state of the current group.

# Arguments

  - `p`: parameters of the model being trained; passed to `remake(prob; p, ...)`.
  - `ode_data`: data to be modelled, indexed as `ode_data[:, j]` for time point `j`.
  - `tsteps`: time points on which `ode_data` was recorded.
  - `prob`: ODE problem that is re-made and solved on every group.
  - `loss_function`: called as `loss_function(u, û)` with the data and the prediction
    of one group.
  - `continuity_loss`: called as `continuity_loss(û_end, u_0)` with the last predicted
    state of group `k` and the first data state of group `k + 1`. The default is
    `sum(abs, û_end - u_0)`.
  - `solver`: ODE solver algorithm.
  - `group_size`: maximum number of data points per group.

# Keywords

  - `continuity_term`: weight applied to every continuity penalty (default `100`).
  - `kwargs`: additional arguments splatted to the ODE solver. Refer to the
    [Local Sensitivity Analysis](https://docs.sciml.ai/SciMLSensitivity/stable/) and
    [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/)
    documentation for more details.

# Returns

The tuple `(loss, group_predictions)`, where `group_predictions[k]` is `Array` of the
solution of group `k`. If any solve finishes with an unsuccessful return code, the
result is `(Inf, group_predictions)`.

# Throws

  - `DomainError` if `group_size < 2` or `group_size > size(ode_data, 2)`.

!!! note

    `continuity_term` should be a relatively big number to enforce a large penalty
    whenever the last point of any group doesn't coincide with the first point of the
    next group.
"""
function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
        continuity_loss::C, solver::SciMLBase.AbstractODEAlgorithm,
        group_size::Integer; continuity_term::Real = 100, kwargs...) where {F, C}
    npoints = size(ode_data, 2)
    2 ≤ group_size ≤ npoints ||
        throw(DomainError(group_size, "group_size can't be < 2 or > number of data points"))

    # Index ranges of the groups; consecutive groups share their boundary point.
    ranges = group_ranges(npoints, group_size)

    # Solve each group on its own, starting from the observed state at its first point.
    # `kwargs` come last so that callers may override `saveat`.
    sols = map(ranges) do rg
        start, stop = first(rg), last(rg)
        segment = remake(prob; p = p, tspan = (tsteps[start], tsteps[stop]),
            u0 = ode_data[:, start])
        return solve(segment, solver; saveat = tsteps[rg], kwargs...)
    end
    group_predictions = Array.(sols)

    # Abort and return an infinite loss if one of the integrations failed.
    retcodes = map(sol -> sol.retcode, sols)
    if !all(SciMLBase.successful_retcode, retcodes)
        return Inf, group_predictions
    end

    # Accumulate the data-fit loss of every group and the continuity penalties.
    loss = 0
    for (k, rg) in enumerate(ranges)
        observed = ode_data[:, rg]
        predicted = group_predictions[k]
        loss += loss_function(observed, predicted)
        if k > 1
            # Last predicted state of the previous group versus the observed state
            # the current group was started from.
            previous_end = group_predictions[k - 1][:, end]
            loss += continuity_term * continuity_loss(previous_end, observed[:, 1])
        end
    end
    return loss, group_predictions
end
# Convenience method: same as the method taking an explicit `continuity_loss`, but using
# `_default_continuity_loss`. All keyword arguments are forwarded unchanged.
function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
        solver::SciMLBase.AbstractODEAlgorithm, group_size::Integer;
        kwargs...) where {F}
    default_loss = _default_continuity_loss
    return multiple_shoot(
        p, ode_data, tsteps, prob, loss_function, default_loss, solver, group_size;
        kwargs...)
end
"""
    multiple_shoot(p, ode_data, tsteps, ensembleprob, ensemblealg, loss_function,
        [continuity_loss = _default_continuity_loss], solver, group_size;
        continuity_term = 100, kwargs...)

Evaluate a direct multiple shooting loss for an ensemble of trajectories and return it
together with the predictions of every group.

The time points are partitioned with `group_ranges` into overlapping groups of at most
`group_size` points. For every group a new `EnsembleProblem` is solved in which
trajectory `i` starts from the observed state `ode_data[:, first(rg), i]`. The total
loss is the sum of `loss_function` over all groups plus, for every group after the
first, `continuity_term` times `continuity_loss` evaluated between the last predicted
states of the previous group and the first observed states of the current group.

# Arguments

  - `p`: parameters of the model being trained; passed to `remake(prob; p, ...)`.
  - `ode_data`: data to be modelled, a three-dimensional array with
    `size(ode_data) = (problem_dimension, length(tsteps), trajectories)`.
  - `tsteps`: time points on which `ode_data` was recorded.
  - `ensembleprob`: ensemble problem. Its `prob` field is the problem that is re-made
    for every group; its `output_func`, `reduction`, `u_init` and `safetycopy` are
    reused. Its `prob_func` is replaced by one that sets `u0` from `ode_data`.
  - `ensemblealg`: ensemble algorithm, e.g. `EnsembleThreads()`.
  - `loss_function`: called as `loss_function(u, û)` with the data and the prediction
    of one group (all trajectories at once).
  - `continuity_loss`: called as `continuity_loss(û_end, u_0)` with the last predicted
    states of group `k` and the first data states of group `k + 1`. The default is
    `sum(abs, û_end - u_0)`.
  - `solver`: ODE solver algorithm.
  - `group_size`: maximum number of time points per group.

# Keywords

  - `continuity_term`: weight applied to every continuity penalty (default `100`).
  - `kwargs`: additional arguments splatted to the ensemble solve. `trajectories` is
    required and must equal `size(ode_data, 3)`. Refer to the
    [Local Sensitivity Analysis](https://docs.sciml.ai/SciMLSensitivity/stable/) and
    [Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/)
    documentation for more details.

# Returns

The tuple `(loss, group_predictions)`, where `group_predictions[k]` is `Array` of the
ensemble solution of group `k`. If any ensemble solution reports `converged == false`,
the result is `(Inf, group_predictions)`.

# Throws

  - `DomainError` if `group_size < 2` or `group_size > size(ode_data, 2)`.
  - `ArgumentError` if `ode_data` is not three-dimensional.
  - `DimensionMismatch` if `size(ode_data, 2) != length(tsteps)` or
    `size(ode_data, 3) != kwargs[:trajectories]`.

!!! note

    `continuity_term` should be a relatively big number to enforce a large penalty
    whenever the last point of any group doesn't coincide with the first point of the
    next group.
"""
function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
        ensemblealg::SciMLBase.BasicEnsembleAlgorithm, loss_function::F,
        continuity_loss::C, solver::SciMLBase.AbstractODEAlgorithm,
        group_size::Integer; continuity_term::Real = 100, kwargs...) where {F, C}
    datasize = size(ode_data, 2)
    prob = ensembleprob.prob

    if group_size < 2 || group_size > datasize
        throw(DomainError(
            group_size, "group_size can't be < 2 or > number of data points"))
    end
    # Input validation uses explicit throws rather than `@assert`, because assertions
    # are a debugging aid that may be disabled and must not guard user input.
    ndims(ode_data) == 3 || throw(ArgumentError(
        "ode_data must have three dimensions: " *
        "`size(ode_data) = (problem_dimension, length(tsteps), trajectories)`"))
    size(ode_data, 2) == length(tsteps) || throw(DimensionMismatch(
        "size(ode_data, 2) must be equal to length(tsteps)"))
    size(ode_data, 3) == kwargs[:trajectories] || throw(DimensionMismatch(
        "size(ode_data, 3) must be equal to the `trajectories` keyword argument"))

    # Get ranges that partition data to groups of size group_size
    ranges = group_ranges(datasize, group_size)

    # Multiple shooting predictions; by using map we avoid mutating an array
    sols = map(
        rg -> begin
            newprob = remake(prob; p = p, tspan = (tsteps[first(rg)], tsteps[last(rg)]))
            # Trajectory `i` of this group starts from its own observed state.
            function prob_func(prob, i, repeat)
                return remake(prob; u0 = ode_data[:, first(rg), i])
            end
            newensembleprob = EnsembleProblem(
                newprob, prob_func, ensembleprob.output_func, ensembleprob.reduction,
                ensembleprob.u_init, ensembleprob.safetycopy)
            # `kwargs` come last so that callers may override `saveat`.
            solve(newensembleprob, solver, ensemblealg; saveat = tsteps[rg], kwargs...)
        end,
        ranges)
    group_predictions = Array.(sols)

    # Abort and return infinite loss if one of the ensemble solves did not converge
    convergeds = [sol.converged for sol in sols]
    any(.!convergeds) && return Inf, group_predictions

    # Calculate multiple shooting loss
    loss = 0
    for (i, rg) in enumerate(ranges)
        û = group_predictions[i]
        u = ode_data[:, rg, :] # trajectories are at dims 3
        # just summing up losses for all trajectories
        # but other alternatives might be considered
        loss += loss_function(u, û)
        if i > 1
            # Ensure continuity between last state in previous prediction
            # and current initial condition in ode_data
            loss += continuity_term *
                    continuity_loss(group_predictions[i - 1][:, end, :], u[:, 1, :])
        end
    end
    return loss, group_predictions
end
# Convenience method: same as the ensemble method taking an explicit `continuity_loss`,
# but using `_default_continuity_loss`. `continuity_term` and all other keyword
# arguments are forwarded unchanged.
function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
        ensemblealg::SciMLBase.BasicEnsembleAlgorithm, loss_function::F,
        solver::SciMLBase.AbstractODEAlgorithm, group_size::Integer;
        continuity_term::Real = 100, kwargs...) where {F}
    default_loss = _default_continuity_loss
    return multiple_shoot(
        p, ode_data, tsteps, ensembleprob, ensemblealg, loss_function,
        default_loss, solver, group_size;
        continuity_term = continuity_term, kwargs...)
end
"""
    group_ranges(datasize, groupsize)

Return the index ranges that partition `datasize` observations into groups of
`groupsize` observations. Consecutive groups overlap by one index: each group starts on
the last index of the previous one. If the data is not perfectly divisible, the last
group contains the remaining observations.

# Arguments

  - `datasize`: number of data points to be partitioned.
  - `groupsize`: maximum number of observations in each group.

# Throws

  - `DomainError` unless `2 ≤ groupsize ≤ datasize`.

# Example

```julia-repl
julia> group_ranges(10, 5)
3-element Vector{UnitRange{Int64}}:
 1:5
 5:9
 9:10
```
"""
function group_ranges(datasize::Integer, groupsize::Integer)
    if !(2 ≤ groupsize ≤ datasize)
        throw(DomainError(groupsize,
            "datasize must be positive and groupsize must to be within [2, datasize]"))
    end
    # Group starts advance by `groupsize - 1`, so neighbouring groups share one index.
    starts = 1:(groupsize - 1):(datasize - 1)
    return map(i -> i:min(datasize, i + groupsize - 1), starts)
end
# Default continuity loss: sum of absolute differences between the last state of the
# previous prediction (`û_end`) and the current initial condition in ode_data (`u_0`).
function _default_continuity_loss(û_end, u_0)
    gap = û_end - u_0
    return sum(abs, gap)
end