To implement a Laplace-approximated RV, we require an estimate of the optimal point in the latent field. This is currently handled by calling `optimize.minimize` during each call to `laplace_marginal_rv_logp` (the logp of `MarginalLaplaceRV`). Profiling reveals that the L-BFGS-B call inside `minimize` consumes ~80% of the logp call's runtime (see the full profiling trace below), and this cost grows with the dimensionality of the latent field when running `pm.sample`.
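In outline, each logp evaluation runs an inner optimization to find the mode of the latent field and then evaluates a Gaussian approximation around that mode. A minimal, self-contained sketch of this pattern (the function names and interface here are hypothetical, not PyMC's actual implementation):

```python
import numpy as np
from scipy import optimize


def laplace_logp_sketch(neg_log_joint, grad, hess, x0):
    """Sketch of a Laplace-approximated marginal logp.

    neg_log_joint(x): negative unnormalized log joint over the latent field x
    grad, hess: its gradient and Hessian
    """
    # Inner optimization: find the mode x_hat of the latent field.
    # This minimize call is what dominates the runtime in the profile below,
    # and it reruns on every single logp evaluation.
    res = optimize.minimize(neg_log_joint, x0, jac=grad, method="L-BFGS-B")
    x_hat = res.x

    # Curvature at the mode defines the Gaussian approximation.
    H = hess(x_hat)
    _, logdet = np.linalg.slogdet(H)
    n = x_hat.size

    # Laplace approximation to the log normalizing constant:
    # log Z ~= -f(x_hat) + (n/2) log(2*pi) - (1/2) log|H|
    return -res.fun + 0.5 * n * np.log(2 * np.pi) - 0.5 * logdet
```

For a quadratic log joint the approximation is exact, which makes the sketch easy to sanity-check against the closed-form Gaussian normalizer.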
There are several possible avenues to address this:

- Reduce the number of calls to `minimize` (e.g. by making only one call to `minimize` per leapfrog step, or once every few calls, at the cost of accuracy).
- Investigate further where in `minimize` the runtime is being bottlenecked.
- Forgo `optimize.minimize` as a black box and implement the bespoke algorithm described by Rasmussen & Williams, together with the adjoint method described in the Stan paper (note that `minimize` should have the adjoint method implemented natively, whereas a bespoke method will require this to be implemented manually).
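For reference, the bespoke mode-finding procedure in Rasmussen & Williams (GPML, Algorithm 3.1) replaces a generic optimizer with a Newton iteration that exploits the GP structure, working with the prior covariance `K` and the (diagonal) negative likelihood Hessian `W`. A sketch, assuming a factorizing likelihood supplied through hypothetical `grad_loglik`/`hess_diag_loglik` callables:

```python
import numpy as np
from scipy.linalg import cholesky, solve_triangular


def laplace_mode_gp(K, grad_loglik, hess_diag_loglik, max_iter=50, tol=1e-8):
    """Newton iteration for the GP Laplace mode (R&W GPML, Algorithm 3.1).

    K: (n, n) GP prior covariance
    grad_loglik(f): gradient of log p(y|f), shape (n,)
    hess_diag_loglik(f): diagonal of the Hessian of log p(y|f), shape (n,)
    Assumes the likelihood factorizes, so its Hessian is diagonal.
    """
    n = K.shape[0]
    f = np.zeros(n)
    for _ in range(max_iter):
        W = -hess_diag_loglik(f)  # nonnegative for log-concave likelihoods
        sqrt_W = np.sqrt(W)
        # B = I + W^(1/2) K W^(1/2) is well conditioned; factor once per step.
        B = np.eye(n) + sqrt_W[:, None] * K * sqrt_W[None, :]
        L = cholesky(B, lower=True)
        b = W * f + grad_loglik(f)
        # a = b - W^(1/2) B^(-1) W^(1/2) K b, via two triangular solves
        v = solve_triangular(L, sqrt_W * (K @ b), lower=True)
        a = b - sqrt_W * solve_triangular(L.T, v, lower=False)
        f_new = K @ a
        if np.max(np.abs(f_new - f)) < tol:
            return f_new
        f = f_new
    return f
```

With a Gaussian likelihood the objective is quadratic, so a single Newton step lands exactly on the known posterior mean `K (K + sigma^2 I)^(-1) y`, which gives a convenient correctness check.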
```
Function profiling
==================
Message: /home/michaln/git/pymc/pymc/pytensorf.py:942
Time in 81 calls to Function.__call__: 2.142292e+00s
Time in Function.vm.__call__: 2.1357465020005293s (99.694%)
Time in thunks: 2.1306798458099365s (99.458%)
Total compilation time: 2.125471e+00s
Number of Apply nodes: 32
PyTensor rewrite time: 5.579397e-01s
PyTensor validate time: 9.109666e-03s
PyTensor Linker time (includes C, CUDA code generation/compiling): 1.5629720789997918s
C-cache preloading 1.097559e-02s
Import time 2.067265e-03s
Node make_thunk time 1.550541e+00s
Node Composite{((-5.513631199228036 + i3) - ((0.5 * i2) + (0.5 * i0) + -2.756815599614018 + i1))}(Squeeze{axis=0}.0, CAReduce{Composite{(i0 + log(i1))}, axes=None}.0, Squeeze{axis=0}.0, Sum{axes=None}.0) time 1.109129e+00s
Node MinimizeOp(method=L-BFGS-B, jac=True, hess=False, hessp=False)([0.7739560 ... .85859792], True, [0.], [[2]], [[[1. 0. 0 ... . 0. 1.]]], [0.5], True, 0.0, [[2]], ExpandDims{axis=0}.0, [[[1. 0. 0 ... . 0. 1.]]]) time 2.768553e-01s
Node Scan{scan_fn, while_loop=False, inplace=none}(3, [0 1 2], 3, Neg.0, Neg.0) time 1.447895e-01s
Node ExpandDims{axis=0}(mu) time 1.893509e-03s
Node Squeeze{axis=0}(CAReduce{Composite{(i0 + sqr(i1))}, axis=1}.0) time 1.311117e-03s
Time in all call to pytensor.grad() 2.084509e+00s
Time since pytensor import 768.571s
Class
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
80.9% 80.9% 1.724s 2.13e-02s Py 81 1 pytensor.tensor.optimize.MinimizeOp
12.1% 93.0% 0.258s 3.19e-03s Py 81 1 pytensor.scan.op.Scan
5.2% 98.2% 0.110s 3.39e-04s Py 324 4 pytensor.tensor.slinalg.SolveTriangular
0.6% 98.8% 0.014s 1.86e-05s C 729 9 pytensor.tensor.elemwise.Elemwise
0.6% 99.4% 0.012s 7.31e-05s C 162 2 pytensor.tensor.math.Sum
0.4% 99.8% 0.009s 2.64e-05s C 324 4 pytensor.tensor.elemwise.CAReduce
0.1% 99.9% 0.003s 3.43e-05s Py 81 1 pytensor.tensor.slinalg.Cholesky
0.1% 100.0% 0.001s 1.64e-06s C 729 9 pytensor.tensor.elemwise.DimShuffle
0.0% 100.0% 0.000s 2.79e-06s C 81 1 pytensor.tensor.basic.ExtractDiag
... (remaining 0 Classes account for 0.00%(0.00s) of the runtime)
Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
80.9% 80.9% 1.724s 2.13e-02s Py 81 1 MinimizeOp(method=L-BFGS-B, jac=True, hess=False, hessp=False)
12.1% 93.0% 0.258s 3.19e-03s Py 81 1 Scan{scan_fn, while_loop=False, inplace=none}
3.5% 96.5% 0.074s 3.06e-04s Py 243 3 SolveTriangular{unit_diagonal=False, lower=True, check_finite=True, b_ndim=2, overwrite_b=True}
1.7% 98.2% 0.036s 4.41e-04s Py 81 1 SolveTriangular{unit_diagonal=False, lower=False, check_finite=True, b_ndim=2, overwrite_b=True}
0.5% 98.7% 0.010s 1.21e-04s C 81 1 Sum{axis=0}
0.4% 99.0% 0.008s 3.27e-05s C 243 3 CAReduce{Composite{(i0 + sqr(i1))}, axis=1}
0.4% 99.4% 0.008s 3.19e-05s C 243 3 Sub
0.2% 99.6% 0.004s 1.80e-05s C 243 3 Neg
0.1% 99.7% 0.003s 3.43e-05s Py 81 1 Cholesky{lower=True, check_finite=False, on_error='nan', overwrite_a=True}
0.1% 99.8% 0.002s 2.48e-05s C 81 1 Sum{axes=None}
0.1% 99.9% 0.001s 1.33e-05s C 81 1 Composite{(-2.756815599614018 - (0.5 * i0))}
0.0% 99.9% 0.001s 7.44e-06s C 81 1 CAReduce{Composite{(i0 + log(i1))}, axes=None}
0.0% 99.9% 0.001s 1.69e-06s C 324 4 Transpose{axes=[1, 0]}
0.0% 100.0% 0.000s 1.65e-06s C 162 2 ExpandDims{axis=1}
0.0% 100.0% 0.000s 2.79e-06s C 81 1 ExtractDiag{offset=0, axis1=0, axis2=1, view=True}
0.0% 100.0% 0.000s 2.37e-06s C 81 1 ExpandDims{axis=0}
0.0% 100.0% 0.000s 1.15e-06s C 162 2 Squeeze{axis=0}
0.0% 100.0% 0.000s 2.26e-06s C 81 1 Composite{((-5.513631199228036 + i3) - ((0.5 * i2) + (0.5 * i0) + -2.756815599614018 + i1))}
0.0% 100.0% 0.000s 2.13e-06s C 81 1 Sub
... (remaining 0 Ops account for 0.00%(0.00s) of the runtime)
Apply
------
<% time> <sum %> <apply time> <time per call> <#call> <id> <Apply name>
80.9% 80.9% 1.724s 2.13e-02s 81 2 MinimizeOp(method=L-BFGS-B, jac=True, hess=False, hessp=False)([0.7739560 ... .85859792], True, [0.], [[2]], [[[1. 0. 0 ... . 0. 1.]]], [0.5], True, 0.0, [[2]], ExpandDims{axis=0}.0, [[[1. 0. 0 ... . 0. 1.]]])
12.1% 93.0% 0.258s 3.19e-03s 81 26 Scan{scan_fn, while_loop=False, inplace=none}(3, [0 1 2], 3, Neg.0, Neg.0)
3.0% 96.0% 0.063s 7.78e-04s 81 9 SolveTriangular{unit_diagonal=False, lower=True, check_finite=True, b_ndim=2, overwrite_b=True}([[1. 0. 0. ... 0. 0. 1.]], Sub.0)
1.7% 97.7% 0.036s 4.41e-04s 81 20 SolveTriangular{unit_diagonal=False, lower=False, check_finite=True, b_ndim=2, overwrite_b=True}([[1. 0. 0. ... 0. 0. 1.]], Neg.0)
0.5% 98.1% 0.010s 1.21e-04s 81 23 Sum{axis=0}(Transpose{axes=[1, 0]}.0)
0.4% 98.5% 0.007s 9.24e-05s 81 15 CAReduce{Composite{(i0 + sqr(i1))}, axis=1}(Transpose{axes=[1, 0]}.0)
0.3% 98.8% 0.007s 8.59e-05s 81 6 Sub([[-0.69153 ... 5232307 ]], ExpandDims{axis=1}.0)
0.3% 99.1% 0.006s 7.67e-05s 81 11 SolveTriangular{unit_diagonal=False, lower=True, check_finite=True, b_ndim=2, overwrite_b=True}([[1. 0. 0. ... 0. 0. 1.]], Sub.0)
0.2% 99.3% 0.005s 6.20e-05s 81 5 SolveTriangular{unit_diagonal=False, lower=True, check_finite=True, b_ndim=2, overwrite_b=True}([[1. 0. 0. ... 0. 0. 1.]], Sub.0)
0.1% 99.5% 0.003s 3.43e-05s 81 28 Cholesky{lower=True, check_finite=False, on_error='nan', overwrite_a=True}(Sub.0)
0.1% 99.6% 0.002s 2.67e-05s 81 17 Neg(SolveTriangular{unit_diagonal=False, lower=True, check_finite=True, b_ndim=2, overwrite_b=True}.0)
0.1% 99.7% 0.002s 2.56e-05s 81 24 Neg(Transpose{axes=[1, 0]}.0)
0.1% 99.8% 0.002s 2.48e-05s 81 21 Sum{axes=None}(posdef covariance)
0.1% 99.8% 0.001s 1.33e-05s 81 18 Composite{(-2.756815599614018 - (0.5 * i0))}(CAReduce{Composite{(i0 + sqr(i1))}, axis=1}.0)
0.0% 99.9% 0.001s 7.44e-06s 81 30 CAReduce{Composite{(i0 + log(i1))}, axes=None}(ExtractDiag{offset=0, axis1=0, axis2=1, view=True}.0)
0.0% 99.9% 0.001s 6.21e-06s 81 27 Sub([[1. 0. 0. ... 0. 0. 1.]], Scan{scan_fn, while_loop=False, inplace=none}.0)
0.0% 99.9% 0.000s 3.48e-06s 81 3 Sub(ExpandDims{axis=1}.0, [[2.273360 ... 97365457]])
0.0% 99.9% 0.000s 3.32e-06s 81 16 CAReduce{Composite{(i0 + sqr(i1))}, axis=1}(Transpose{axes=[1, 0]}.0)
0.0% 99.9% 0.000s 2.79e-06s 81 29 ExtractDiag{offset=0, axis1=0, axis2=1, view=True}(Cholesky{lower=True, check_finite=False, on_error='nan', overwrite_a=True}.0)
0.0% 99.9% 0.000s 2.44e-06s 81 10 CAReduce{Composite{(i0 + sqr(i1))}, axis=1}(Transpose{axes=[1, 0]}.0)
... (remaining 12 Apply instances account for 0.08%%(0.00s) of the runtime)
Scan overhead:
<Scan op time(s)> <sub scan fct time(s)> <sub scan op time(s)> <sub scan fct time(% scan op time)> <sub scan op time(% scan op time)> <node>
One scan node do not have its inner profile enabled. If you enable PyTensor profiler with 'pytensor.function(..., profile=True)', you must manually enable the profiling for each scan too: 'pytensor.scan(...,profile=True)'. Or use PyTensor flag 'profile=True'.
No scan have its inner profile enabled.
Here are tips to potentially make your code run faster
(if you think of new ones, suggest them on the mailing list).
Test them first, as they are not guaranteed to always provide a speedup.
  - Try the PyTensor flag floatX=float32
```