forked from SciML/Optimization.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcache.jl
84 lines (77 loc) · 3.16 KB
/
cache.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Property access for optimization caches: `u0` and `p` are not stored on the cache
# itself but inside its `reinit_cache` field (which `reinit!` overwrites), so those two
# names are redirected there; every other name is read as an ordinary struct field.
function Base.getproperty(cache::SciMLBase.AbstractOptimizationCache, x::Symbol)
    (x === :u0 || x === :p) || return getfield(cache, x)
    return getfield(getfield(cache, :reinit_cache), x)
end
# Every optimization cache supports `SciMLBase.reinit!` (defined below for
# `SciMLBase.AbstractOptimizationCache`), so this trait is unconditionally `true`.
function SciMLBase.has_reinit(::SciMLBase.AbstractOptimizationCache)
    return true
end
# Replace the parameters and/or initial guess stored in an optimization cache so the
# same cache can be solved again.
#
# Keywords:
# - `p`:  new parameters; `missing` (the default) keeps the current `cache.p`.
# - `u0`: new initial guess; `missing` (the default) keeps the current `cache.u0`.
# Either keyword may be a plain value or a non-empty collection whose element type is a
# `Pair` (a symbolic `var => value` map). Symbolic maps are resolved through
# `SciMLBase.process_p_u0_symbolic`, which requires `cache.f` to carry a `sys` property.
#
# Returns `cache`, after writing the resolved values into `cache.reinit_cache`.
# Throws `ArgumentError` when `p` (resp. `u0`) is a symbolic map but `cache.f` has no
# `sys`, or the type of `cache.f.sys` has no `ps` (resp. `states`) field.
function SciMLBase.reinit!(cache::SciMLBase.AbstractOptimizationCache; p = missing,
        u0 = missing)
    if p === missing && u0 === missing
        # Nothing requested: store the current values back unchanged.
        p, u0 = cache.p, cache.u0
    else
        # At least one value was supplied; take the other one from the cache.
        if p === missing
            p = cache.p
        end
        if u0 === missing
            u0 = cache.u0
        end
        # A non-empty collection of `Pair`s is treated as a symbolic map.
        p_is_map = eltype(p) <: Pair && !isempty(p)
        u0_is_map = eltype(u0) <: Pair && !isempty(u0)
        if p_is_map || u0_is_map
            has_sys = hasproperty(cache.f, :sys)
            # Validate the symbolic system only for the argument that actually is a
            # map, so the error names the keyword the caller has to change.
            if p_is_map && !(has_sys && hasfield(typeof(cache.f.sys), :ps))
                throw(ArgumentError("This cache does not support symbolic maps with " *
                                    "`reinit!`, i.e. it does not have a symbolic origin." *
                                    " Please use `reinit!` with the `p` keyword argument " *
                                    "as a vector of values, paying attention to " *
                                    "parameter order."))
            end
            if u0_is_map && !(has_sys && hasfield(typeof(cache.f.sys), :states))
                throw(ArgumentError("This cache does not support symbolic maps with " *
                                    "`reinit!`, i.e. it does not have a symbolic origin." *
                                    " Please use `reinit!` with the `u0` keyword argument " *
                                    "as a vector of values, paying attention to " *
                                    "state order."))
            end
            p, u0 = SciMLBase.process_p_u0_symbolic(cache, p, u0)
        end
    end
    cache.reinit_cache.p = p
    cache.reinit_cache.u0 = u0
    return cache
end
"""
    OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C, SA}

Cache built from a `SciMLBase.OptimizationProblem` by the
`OptimizationCache(prob, opt, data; ...)` constructor (and hence by `SciMLBase.__init`).

`u0` and `p` are not fields of this struct. They live in `reinit_cache` and are read as
`cache.u0` / `cache.p` through the `getproperty` method defined for
`SciMLBase.AbstractOptimizationCache`.

The trailing type parameter `SA` is the concrete type of `solver_args`. Accesses such as
`cache.solver_args.maxiters` are therefore inferable instead of going through an
abstract `NamedTuple` field. The first eleven type parameters keep their original
positions, so signatures written as `OptimizationCache{F, RC, ..., C}` still match.
"""
struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C, SA <: NamedTuple} <:
       SciMLBase.AbstractOptimizationCache
    f::F               # objective as returned by `Optimization.instantiate_function`
    reinit_cache::RC   # holds `u0` and `p`; overwritten by `reinit!`
    lb::LB             # lower bounds, copied from `prob.lb`
    ub::UB             # upper bounds, copied from `prob.ub`
    lcons::LC          # constraint lower bounds, copied from `prob.lcons`
    ucons::UC          # constraint upper bounds, copied from `prob.ucons`
    sense::S           # optimization sense, copied from `prob.sense`
    opt::O             # optimizer passed to the constructor
    data::D            # data passed to the constructor, stored unchanged
    progress::P        # `progress` keyword passed to the constructor
    callback::C        # `callback` keyword passed to the constructor
    solver_args::SA    # (; maxiters, maxtime, abstol, reltol) merged with extra kwargs
end
"""
    OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data; callback,
                      maxiters, maxtime, abstol, reltol, progress, kwargs...)

Build an `OptimizationCache` for `prob` and optimizer `opt`. `data`, `progress` and
`callback` are stored as given.

The objective `prob.f` is instantiated once with `Optimization.instantiate_function`,
using the problem's `adtype`, a fresh `Optimization.ReInitCache(prob.u0, prob.p)`, and
a constraint count equal to `length(prob.ucons)`, or `0` when `prob.ucons === nothing`.

`maxiters`, `maxtime`, `abstol` and `reltol` (each `nothing` unless given) are merged
with all remaining keyword arguments into the cache's `solver_args` named tuple.
"""
function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data;
        callback = Optimization.DEFAULT_CALLBACK,
        maxiters::Union{Number, Nothing} = nothing,
        maxtime::Union{Number, Nothing} = nothing,
        abstol::Union{Number, Nothing} = nothing,
        reltol::Union{Number, Nothing} = nothing,
        progress = false,
        kwargs...)
    rc = Optimization.ReInitCache(prob.u0, prob.p)
    ncons = isnothing(prob.ucons) ? 0 : length(prob.ucons)
    instantiated = Optimization.instantiate_function(prob.f, rc, prob.f.adtype, ncons)

    common_args = (; maxiters, maxtime, abstol, reltol)
    all_args = merge(common_args, NamedTuple(kwargs))

    return OptimizationCache(instantiated, rc,
        prob.lb, prob.ub,
        prob.lcons, prob.ucons,
        prob.sense,
        opt, data, progress, callback,
        all_args)
end
# `SciMLBase.__init` method for optimization problems: constructs an
# `OptimizationCache` for `prob` and `opt`. `data` defaults to
# `Optimization.DEFAULT_DATA`. The named solver options and any extra keyword arguments
# are forwarded unchanged to the `OptimizationCache` constructor.
function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt,
        data = Optimization.DEFAULT_DATA;
        callback = Optimization.DEFAULT_CALLBACK,
        maxiters::Union{Number, Nothing} = nothing,
        maxtime::Union{Number, Nothing} = nothing,
        abstol::Union{Number, Nothing} = nothing,
        reltol::Union{Number, Nothing} = nothing,
        progress = false,
        kwargs...)
    return OptimizationCache(prob, opt, data;
        callback = callback,
        maxiters = maxiters,
        maxtime = maxtime,
        abstol = abstol,
        reltol = reltol,
        progress = progress,
        kwargs...)
end