Calls to pytensor.optimize.minimize slow down laplace_marginal_rv_logp #568

@Michal-Novomestsky

To implement a Laplace-approximated RV, we need an estimate of the optimal point in the latent field, $x^* (y, \theta)$, which we obtain by maximizing $\textnormal{logp}(x \mid y, \theta)$ over $x$. However, since $x^*$ is a function of $\theta$, it has to be recomputed for every value of $\theta$ visited during sampling.
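
For reference, this is the textbook form of the Laplace approximation written in the notation above. It is a sketch of the standard result, not taken from the current implementation; $d$ (the dimension of $x$) and $H$ (the negative Hessian at the mode) are symbols introduced here. Maximizing $\textnormal{logp}(x \mid y, \theta)$ over $x$ is equivalent to maximizing the joint, since $p(y \mid \theta)$ does not depend on $x$:

$$
x^*(y, \theta) = \arg\max_x \left[ \log p(y \mid x, \theta) + \log p(x \mid \theta) \right]
$$

$$
\log p(y \mid \theta) \approx \log p(y \mid x^*, \theta) + \log p(x^* \mid \theta) + \frac{d}{2} \log 2\pi - \frac{1}{2} \log \det H,
\qquad
H = -\nabla_x^2 \left[ \log p(y \mid x, \theta) + \log p(x \mid \theta) \right] \Big|_{x = x^*}
$$

Every evaluation of the right-hand side at a new $\theta$ needs a fresh $x^*$, which is where the inner optimization comes from.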

This is currently handled by calling optimize.minimize inside every call to laplace_marginal_rv_logp (the logp of MarginalLaplaceRV). Profiling shows that with L-BFGS-B and $x \in \mathbb{R}^3$, minimize accounts for ~81% of the logp call's runtime (1.724s out of 2.131s spent in thunks over 81 calls; see the full profiling trace below). Increasing the dimensionality of $x$ to anything practical (100+) pushes the sampling runtime to hours, compared to ~30s when sampling directly with pm.sample.
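
As a rough way to separate "many iterations" from "high per-call overhead", one could time the same kind of solve with SciPy directly. This is only a sketch under assumptions: the objective below is a made-up 3-dimensional Poisson/Gaussian toy problem (not the model from the trace), and it calls `scipy.optimize.minimize` rather than PyTensor's `MinimizeOp`, so the absolute numbers are not comparable to the profile. The names `neg_logp_and_grad` and `n_calls` are hypothetical.

```python
import time

import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
d = 3  # same latent dimension as in the profile
y = rng.poisson(2.0, size=d).astype(float)  # toy observations (assumption)


def neg_logp_and_grad(x):
    """Toy negative log posterior and its gradient (constants dropped).

    Poisson likelihood with log link plus a standard-normal prior on x.
    """
    value = np.sum(np.exp(x) - y * x) + 0.5 * np.sum(x**2)
    grad = np.exp(x) - y + x
    return value, grad


n_calls = 81  # matches the number of calls in the trace
start = time.perf_counter()
for _ in range(n_calls):
    # jac=True tells SciPy that the callable returns (value, gradient)
    res = minimize(neg_logp_and_grad, np.zeros(d), jac=True, method="L-BFGS-B")
elapsed = time.perf_counter() - start

print(f"iterations per solve: {res.nit}, objective evaluations per solve: {res.nfev}")
print(f"mean wall time per solve: {elapsed / n_calls:.2e}s")
```

If the mean time per solve here is far below the ~2.13e-02s per call reported for `MinimizeOp`, that would point at overhead around the optimizer (for example in how the inner objective is evaluated) rather than at L-BFGS-B itself, but that is a hypothesis to check, not a conclusion.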

There are several possible avenues to address this:

  • Reduce the number of calls to minimize, for example by making only one call per leapfrog step, or one call every few steps, at the cost of accuracy.
  • Investigate further where inside minimize the runtime is being spent.
  • Stop treating optimize.minimize as a black box and implement the bespoke algorithm described by Rasmussen & Williams, together with the adjoint method described in the Stan paper. Note that minimize should already have the adjoint method implemented natively, whereas a bespoke method would need it implemented manually.
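
To make the first and third options concrete, here is a minimal NumPy sketch of a hand-written Newton solve for the mode, warm-started from the previous solution, plus the derivative of the mode with respect to a hyperparameter via the implicit function theorem. It is deliberately simplified: it is a plain damped Newton iteration, not Rasmussen & Williams' numerically stabilised algorithm and not the adjoint method from the Stan paper. The toy model (Poisson likelihood, Gaussian prior with precision `theta * K_inv`) and all names (`newton_mode`, `dmode_dtheta`, `K_inv`) are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
A = rng.normal(size=(d, d))
K_inv = A @ A.T / d + np.eye(d)  # fixed SPD matrix; prior precision is theta * K_inv
y = rng.poisson(2.0, size=d).astype(float)  # toy observations


def objective(x, theta):
    """Negative log joint of the toy model, constants dropped."""
    return np.sum(np.exp(x) - y * x) + 0.5 * theta * x @ K_inv @ x


def grad_hess(x, theta):
    """Gradient and Hessian of the objective with respect to x."""
    g = np.exp(x) - y + theta * K_inv @ x
    H = np.diag(np.exp(x)) + theta * K_inv
    return g, H


def newton_mode(theta, x0, tol=1e-10, max_iter=100):
    """Damped Newton iteration; returns the mode and the iteration count."""
    x = x0.copy()
    for it in range(max_iter):
        g, H = grad_hess(x, theta)
        if np.linalg.norm(g) < tol:
            return x, it
        step = np.linalg.solve(H, g)
        f0 = objective(x, theta)
        slack = 1e-12 * (1.0 + abs(f0))  # tolerate round-off near the optimum
        t = 1.0
        # Halve the step until the objective no longer increases
        while objective(x - t * step, theta) > f0 + slack and t > 1e-8:
            t *= 0.5
        x = x - t * step
    raise RuntimeError("Newton iteration did not converge")


def dmode_dtheta(x_star, theta):
    """d x*/d theta from the implicit function theorem.

    At the mode, grad(x*, theta) = 0, so H dx* + (d grad / d theta) dtheta = 0,
    and for this toy model d grad / d theta = K_inv @ x*.
    """
    _, H = grad_hess(x_star, theta)
    return -np.linalg.solve(H, K_inv @ x_star)


theta = 1.0
x_star, _ = newton_mode(theta, np.zeros(d))

# Warm start: reuse the previous mode when theta moves a little
_, it_cold = newton_mode(theta + 0.05, np.zeros(d))
_, it_warm = newton_mode(theta + 0.05, x_star)
print(f"Newton iterations, cold start: {it_cold}, warm start: {it_warm}")

# Check the implicit derivative against central finite differences
h = 1e-4
x_plus, _ = newton_mode(theta + h, x_star)
x_minus, _ = newton_mode(theta - h, x_star)
dx_fd = (x_plus - x_minus) / (2 * h)
dx_implicit = dmode_dtheta(x_star, theta)
print("max abs difference:", np.max(np.abs(dx_implicit - dx_fd)))
```

The point of the sketch is that consecutive $\theta$ values along a trajectory are close, so reusing the previous $x^*$ as the starting point can cut the number of inner iterations, and that the gradient through the optimum only needs one extra linear solve with the Hessian rather than differentiating through the optimizer's iterations.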

Function profiling
==================
  Message: /home/michaln/git/pymc/pymc/pytensorf.py:942
  Time in 81 calls to Function.__call__: 2.142292e+00s
  Time in Function.vm.__call__: 2.1357465020005293s (99.694%)
  Time in thunks: 2.1306798458099365s (99.458%)
  Total compilation time: 2.125471e+00s
    Number of Apply nodes: 32
    PyTensor rewrite time: 5.579397e-01s
       PyTensor validate time: 9.109666e-03s
    PyTensor Linker time (includes C, CUDA code generation/compiling): 1.5629720789997918s
       C-cache preloading 1.097559e-02s
       Import time 2.067265e-03s
       Node make_thunk time 1.550541e+00s
           Node Composite{((-5.513631199228036 + i3) - ((0.5 * i2) + (0.5 * i0) + -2.756815599614018 + i1))}(Squeeze{axis=0}.0, CAReduce{Composite{(i0 + log(i1))}, axes=None}.0, Squeeze{axis=0}.0, Sum{axes=None}.0) time 1.109129e+00s
           Node MinimizeOp(method=L-BFGS-B, jac=True, hess=False, hessp=False)([0.7739560 ... .85859792], True, [0.], [[2]], [[[1. 0. 0 ... . 0. 1.]]], [0.5], True, 0.0, [[2]], ExpandDims{axis=0}.0, [[[1. 0. 0 ... . 0. 1.]]]) time 2.768553e-01s
           Node Scan{scan_fn, while_loop=False, inplace=none}(3, [0 1 2], 3, Neg.0, Neg.0) time 1.447895e-01s
           Node ExpandDims{axis=0}(mu) time 1.893509e-03s
           Node Squeeze{axis=0}(CAReduce{Composite{(i0 + sqr(i1))}, axis=1}.0) time 1.311117e-03s

Time in all call to pytensor.grad() 2.084509e+00s
Time since pytensor import 768.571s
Class
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
  80.9%    80.9%       1.724s       2.13e-02s     Py      81       1   pytensor.tensor.optimize.MinimizeOp
  12.1%    93.0%       0.258s       3.19e-03s     Py      81       1   pytensor.scan.op.Scan
   5.2%    98.2%       0.110s       3.39e-04s     Py     324       4   pytensor.tensor.slinalg.SolveTriangular
   0.6%    98.8%       0.014s       1.86e-05s     C      729       9   pytensor.tensor.elemwise.Elemwise
   0.6%    99.4%       0.012s       7.31e-05s     C      162       2   pytensor.tensor.math.Sum
   0.4%    99.8%       0.009s       2.64e-05s     C      324       4   pytensor.tensor.elemwise.CAReduce
   0.1%    99.9%       0.003s       3.43e-05s     Py      81       1   pytensor.tensor.slinalg.Cholesky
   0.1%   100.0%       0.001s       1.64e-06s     C      729       9   pytensor.tensor.elemwise.DimShuffle
   0.0%   100.0%       0.000s       2.79e-06s     C       81       1   pytensor.tensor.basic.ExtractDiag
   ... (remaining 0 Classes account for   0.00%(0.00s) of the runtime)

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  80.9%    80.9%       1.724s       2.13e-02s     Py      81        1   MinimizeOp(method=L-BFGS-B, jac=True, hess=False, hessp=False)
  12.1%    93.0%       0.258s       3.19e-03s     Py      81        1   Scan{scan_fn, while_loop=False, inplace=none}
   3.5%    96.5%       0.074s       3.06e-04s     Py     243        3   SolveTriangular{unit_diagonal=False, lower=True, check_finite=True, b_ndim=2, overwrite_b=True}
   1.7%    98.2%       0.036s       4.41e-04s     Py      81        1   SolveTriangular{unit_diagonal=False, lower=False, check_finite=True, b_ndim=2, overwrite_b=True}
   0.5%    98.7%       0.010s       1.21e-04s     C       81        1   Sum{axis=0}
   0.4%    99.0%       0.008s       3.27e-05s     C      243        3   CAReduce{Composite{(i0 + sqr(i1))}, axis=1}
   0.4%    99.4%       0.008s       3.19e-05s     C      243        3   Sub
   0.2%    99.6%       0.004s       1.80e-05s     C      243        3   Neg
   0.1%    99.7%       0.003s       3.43e-05s     Py      81        1   Cholesky{lower=True, check_finite=False, on_error='nan', overwrite_a=True}
   0.1%    99.8%       0.002s       2.48e-05s     C       81        1   Sum{axes=None}
   0.1%    99.9%       0.001s       1.33e-05s     C       81        1   Composite{(-2.756815599614018 - (0.5 * i0))}
   0.0%    99.9%       0.001s       7.44e-06s     C       81        1   CAReduce{Composite{(i0 + log(i1))}, axes=None}
   0.0%    99.9%       0.001s       1.69e-06s     C      324        4   Transpose{axes=[1, 0]}
   0.0%   100.0%       0.000s       1.65e-06s     C      162        2   ExpandDims{axis=1}
   0.0%   100.0%       0.000s       2.79e-06s     C       81        1   ExtractDiag{offset=0, axis1=0, axis2=1, view=True}
   0.0%   100.0%       0.000s       2.37e-06s     C       81        1   ExpandDims{axis=0}
   0.0%   100.0%       0.000s       1.15e-06s     C      162        2   Squeeze{axis=0}
   0.0%   100.0%       0.000s       2.26e-06s     C       81        1   Composite{((-5.513631199228036 + i3) - ((0.5 * i2) + (0.5 * i0) + -2.756815599614018 + i1))}
   0.0%   100.0%       0.000s       2.13e-06s     C       81        1   Sub
   ... (remaining 0 Ops account for   0.00%(0.00s) of the runtime)

Apply
------
<% time> <sum %> <apply time> <time per call> <#call> <id> <Apply name>
  80.9%    80.9%       1.724s       2.13e-02s     81     2   MinimizeOp(method=L-BFGS-B, jac=True, hess=False, hessp=False)([0.7739560 ... .85859792], True, [0.], [[2]], [[[1. 0. 0 ... . 0. 1.]]], [0.5], True, 0.0, [[2]], ExpandDims{axis=0}.0, [[[1. 0. 0 ... . 0. 1.]]])
  12.1%    93.0%       0.258s       3.19e-03s     81    26   Scan{scan_fn, while_loop=False, inplace=none}(3, [0 1 2], 3, Neg.0, Neg.0)
   3.0%    96.0%       0.063s       7.78e-04s     81     9   SolveTriangular{unit_diagonal=False, lower=True, check_finite=True, b_ndim=2, overwrite_b=True}([[1. 0. 0. ... 0. 0. 1.]], Sub.0)
   1.7%    97.7%       0.036s       4.41e-04s     81    20   SolveTriangular{unit_diagonal=False, lower=False, check_finite=True, b_ndim=2, overwrite_b=True}([[1. 0. 0. ... 0. 0. 1.]], Neg.0)
   0.5%    98.1%       0.010s       1.21e-04s     81    23   Sum{axis=0}(Transpose{axes=[1, 0]}.0)
   0.4%    98.5%       0.007s       9.24e-05s     81    15   CAReduce{Composite{(i0 + sqr(i1))}, axis=1}(Transpose{axes=[1, 0]}.0)
   0.3%    98.8%       0.007s       8.59e-05s     81     6   Sub([[-0.69153 ... 5232307 ]], ExpandDims{axis=1}.0)
   0.3%    99.1%       0.006s       7.67e-05s     81    11   SolveTriangular{unit_diagonal=False, lower=True, check_finite=True, b_ndim=2, overwrite_b=True}([[1. 0. 0. ... 0. 0. 1.]], Sub.0)
   0.2%    99.3%       0.005s       6.20e-05s     81     5   SolveTriangular{unit_diagonal=False, lower=True, check_finite=True, b_ndim=2, overwrite_b=True}([[1. 0. 0. ... 0. 0. 1.]], Sub.0)
   0.1%    99.5%       0.003s       3.43e-05s     81    28   Cholesky{lower=True, check_finite=False, on_error='nan', overwrite_a=True}(Sub.0)
   0.1%    99.6%       0.002s       2.67e-05s     81    17   Neg(SolveTriangular{unit_diagonal=False, lower=True, check_finite=True, b_ndim=2, overwrite_b=True}.0)
   0.1%    99.7%       0.002s       2.56e-05s     81    24   Neg(Transpose{axes=[1, 0]}.0)
   0.1%    99.8%       0.002s       2.48e-05s     81    21   Sum{axes=None}(posdef covariance)
   0.1%    99.8%       0.001s       1.33e-05s     81    18   Composite{(-2.756815599614018 - (0.5 * i0))}(CAReduce{Composite{(i0 + sqr(i1))}, axis=1}.0)
   0.0%    99.9%       0.001s       7.44e-06s     81    30   CAReduce{Composite{(i0 + log(i1))}, axes=None}(ExtractDiag{offset=0, axis1=0, axis2=1, view=True}.0)
   0.0%    99.9%       0.001s       6.21e-06s     81    27   Sub([[1. 0. 0. ... 0. 0. 1.]], Scan{scan_fn, while_loop=False, inplace=none}.0)
   0.0%    99.9%       0.000s       3.48e-06s     81     3   Sub(ExpandDims{axis=1}.0, [[2.273360 ... 97365457]])
   0.0%    99.9%       0.000s       3.32e-06s     81    16   CAReduce{Composite{(i0 + sqr(i1))}, axis=1}(Transpose{axes=[1, 0]}.0)
   0.0%    99.9%       0.000s       2.79e-06s     81    29   ExtractDiag{offset=0, axis1=0, axis2=1, view=True}(Cholesky{lower=True, check_finite=False, on_error='nan', overwrite_a=True}.0)
   0.0%    99.9%       0.000s       2.44e-06s     81    10   CAReduce{Composite{(i0 + sqr(i1))}, axis=1}(Transpose{axes=[1, 0]}.0)
   ... (remaining 12 Apply instances account for 0.08%%(0.00s) of the runtime)


Scan overhead:
<Scan op time(s)> <sub scan fct time(s)> <sub scan op time(s)> <sub scan fct time(% scan op time)> <sub scan op time(% scan op time)> <node>
  One scan node do not have its inner profile enabled. If you enable PyTensor profiler with 'pytensor.function(..., profile=True)', you must manually enable the profiling for each scan too: 'pytensor.scan(...,profile=True)'. Or use PyTensor flag 'profile=True'.
  No scan have its inner profile enabled.
Here are tips to potentially make your code run faster
                 (if you think of new ones, suggest them on the mailing list).
                 Test them first, as they are not guaranteed to always provide a speedup.
  - Try the PyTensor flag floatX=float32
