From 52f7fe16df4adf1326cae5578a0d7beac4c2dd11 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> Date: Fri, 10 Jan 2025 04:18:10 +0800 Subject: [PATCH 01/25] Foward `compile_kwargs` to ADVI when `init = "advi+..."` (#7640) --- pymc/pytensorf.py | 26 ------------------- pymc/sampling/mcmc.py | 4 +++ pymc/variational/inference.py | 19 ++++++++++---- pymc/variational/opvi.py | 49 +++++++++++++++++++++++++++-------- 4 files changed, 56 insertions(+), 42 deletions(-) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 6fd44b0382..f665d5931c 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -22,7 +22,6 @@ import pytensor.tensor as pt import scipy.sparse as sps -from pytensor import scalar from pytensor.compile import Function, Mode, get_mode from pytensor.compile.builders import OpFromGraph from pytensor.gradient import grad @@ -415,31 +414,6 @@ def hessian_diag(f, vars=None, negate_output=True): return empty_gradient -class IdentityOp(scalar.UnaryScalarOp): - @staticmethod - def st_impl(x): - return x - - def impl(self, x): - return x - - def grad(self, inp, grads): - return grads - - def c_code(self, node, name, inp, out, sub): - return f"{out[0]} = {inp[0]};" - - def __eq__(self, other): - return isinstance(self, type(other)) - - def __hash__(self): - return hash(type(self)) - - -scalar_identity = IdentityOp(scalar.upgrade_to_float, name="scalar_identity") -identity = Elemwise(scalar_identity, name="identity") - - def make_shared_replacements(point, vars, model): """ Make shared replacements for all *other* variables than the ones passed. diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index fce64e3b38..b2d643a5f1 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1553,6 +1553,7 @@ def init_nuts( callbacks=cb, progressbar=progressbar, obj_optimizer=pm.adagrad_window, + compile_kwargs=compile_kwargs, ) approx_sample = approx.sample( draws=chains, random_seed=random_seed_list[0], return_inferencedata=False @@ -1566,6 +1567,7 @@ def init_nuts( potential = quadpotential.QuadPotentialDiagAdapt( n, mean, cov, weight, rng=random_seed_list[0] ) + elif init == "advi": approx = pm.fit( random_seed=random_seed_list[0], @@ -1575,6 +1577,7 @@ def init_nuts( callbacks=cb, progressbar=progressbar, obj_optimizer=pm.adagrad_window, + compile_kwargs=compile_kwargs, ) approx_sample = approx.sample( draws=chains, random_seed=random_seed_list[0], return_inferencedata=False @@ -1592,6 +1595,7 @@ def init_nuts( callbacks=cb, progressbar=progressbar, obj_optimizer=pm.adagrad_window, + compile_kwargs=compile_kwargs, ) approx_sample = approx.sample( draws=chains, random_seed=random_seed_list[0], return_inferencedata=False diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index 3e2c07788f..29800e0541 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -82,9 +82,18 @@ def _maybe_score(self, score): def run_profiling(self, n=1000, score=None, **kwargs): score = self._maybe_score(score) - fn_kwargs = kwargs.pop("fn_kwargs", {}) - fn_kwargs["profile"] = True - step_func = self.objective.step_function(score=score, fn_kwargs=fn_kwargs, **kwargs) + if "fn_kwargs" in kwargs: + warnings.warn( + "fn_kwargs is deprecated, please use compile_kwargs instead", DeprecationWarning + ) + compile_kwargs = kwargs.pop("fn_kwargs") + else: + compile_kwargs = kwargs.pop("compile_kwargs", {}) + + compile_kwargs["profile"] = True + step_func = self.objective.step_function( + score=score, compile_kwargs=compile_kwargs, **kwargs + ) try: for _ in track(range(n)): step_func() @@ -134,7 +143,7 @@ def fit( Add custom updates to resulting updates total_grad_norm_constraint: `float` Bounds gradient norm, prevents exploding gradient problem - fn_kwargs: `dict` + compile_kwargs: `dict` Add kwargs to pytensor.function (e.g. `{'profile': True}`) more_replacements: `dict` Apply custom replacements before calculating gradients @@ -729,7 +738,7 @@ def fit( Add custom updates to resulting updates total_grad_norm_constraint: `float` Bounds gradient norm, prevents exploding gradient problem - fn_kwargs: `dict` + compile_kwargs: `dict` Add kwargs to pytensor.function (e.g. `{'profile': True}`) more_replacements: `dict` Apply custom replacements before calculating gradients diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 9829ea2c35..034e2fed87 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -61,6 +61,8 @@ from pytensor.graph.basic import Variable from pytensor.graph.replace import graph_replace +from pytensor.scalar.basic import identity as scalar_identity +from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.shape import unbroadcast import pymc as pm @@ -74,7 +76,6 @@ SeedSequenceSeed, compile, find_rng_nodes, - identity, reseed_rngs, ) from pymc.util import ( @@ -332,6 +333,7 @@ def step_function( more_replacements=None, total_grad_norm_constraint=None, score=False, + compile_kwargs=None, fn_kwargs=None, ): R"""Step function that should be called on each optimization step. @@ -362,8 +364,13 @@ def step_function( Bounds gradient norm, prevents exploding gradient problem score: `bool` calculate loss on each step? Defaults to False for speed - fn_kwargs: `dict` + compile_kwargs: `dict` Add kwargs to pytensor.function (e.g. `{'profile': True}`) + fn_kwargs: dict + arbitrary kwargs passed to `pytensor.function` + + .. warning:: `fn_kwargs` is deprecated and will be removed in future versions + more_replacements: `dict` Apply custom replacements before calculating gradients @@ -371,8 +378,16 @@ def step_function( ------- `pytensor.function` """ - if fn_kwargs is None: - fn_kwargs = {} + if fn_kwargs is not None: + warnings.warn( + "`fn_kwargs` is deprecated and will be removed in future versions. Use " + "`compile_kwargs` instead.", + DeprecationWarning, + ) + compile_kwargs = fn_kwargs + + if compile_kwargs is None: + compile_kwargs = {} if score and not self.op.returns_loss: raise NotImplementedError(f"{self.op} does not have loss") updates = self.updates( @@ -388,14 +403,14 @@ def step_function( ) seed = self.approx.rng.randint(2**30, dtype=np.int64) if score: - step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **fn_kwargs) + step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **compile_kwargs) else: - step_fn = compile([], [], updates=updates, random_seed=seed, **fn_kwargs) + step_fn = compile([], [], updates=updates, random_seed=seed, **compile_kwargs) return step_fn @pytensor.config.change_flags(compute_test_value="off") def score_function( - self, sc_n_mc=None, more_replacements=None, fn_kwargs=None + self, sc_n_mc=None, more_replacements=None, compile_kwargs=None, fn_kwargs=None ): # pragma: no cover R"""Compile scoring function that operates which takes no inputs and returns Loss. @@ -405,22 +420,34 @@ def score_function( number of scoring MC samples more_replacements: Apply custom replacements before compiling a function + compile_kwargs: `dict` + arbitrary kwargs passed to `pytensor.function` fn_kwargs: `dict` arbitrary kwargs passed to `pytensor.function` + .. warning:: `fn_kwargs` is deprecated and will be removed in future versions + Returns ------- pytensor.function """ - if fn_kwargs is None: - fn_kwargs = {} + if fn_kwargs is not None: + warnings.warn( + "`fn_kwargs` is deprecated and will be removed in future versions. Use " + "`compile_kwargs` instead", + DeprecationWarning, + ) + compile_kwargs = fn_kwargs + + if compile_kwargs is None: + compile_kwargs = {} if not self.op.returns_loss: raise NotImplementedError(f"{self.op} does not have loss") if more_replacements is None: more_replacements = {} loss = self(sc_n_mc, more_replacements=more_replacements) seed = self.approx.rng.randint(2**30, dtype=np.int64) - return compile([], loss, random_seed=seed, **fn_kwargs) + return compile([], loss, random_seed=seed, **compile_kwargs) @pytensor.config.change_flags(compute_test_value="off") def __call__(self, nmc, **kwargs): @@ -451,7 +478,7 @@ class Operator: require_logq = True objective_class = ObjectiveFunction supports_aevb = property(lambda self: not self.approx.any_histograms) - T = identity + T = Elemwise(scalar_identity) def __init__(self, approx): self.approx = approx From 70cb73c5178c401668d27305508f11a7863ee2bb Mon Sep 17 00:00:00 2001 From: Demetri Pananos Date: Thu, 9 Jan 2025 15:28:00 -0500 Subject: [PATCH 02/25] Adds shape/rate info to Gamma docs (#7625) --- pymc/distributions/continuous.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index d5e69fc799..3746f90fac 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -2311,6 +2311,8 @@ class Gamma(PositiveContinuous): f(x \mid \alpha, \beta) = \frac{\beta^{\alpha}x^{\alpha-1}e^{-\beta x}}{\Gamma(\alpha)} + Here, the gamma distribution is parameterized by shape (alpha) and rate (beta). + .. plot:: :context: close-figs From 9eb60ebeceee9a95fcbadb631eca38120c528562 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Wed, 23 Oct 2024 11:24:10 +0200 Subject: [PATCH 03/25] Make drop_warning_stat work with flat stat names for compound steps --- pymc/util.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pymc/util.py b/pymc/util.py index 8a059d7e0d..63576676eb 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import re import warnings from collections import namedtuple @@ -276,7 +277,12 @@ def drop_warning_stat(idata: arviz.InferenceData) -> arviz.InferenceData: nidata = arviz.InferenceData(attrs=idata.attrs) for gname, group in idata.items(): if "sample_stat" in gname: - group = group.drop_vars(names=["warning", "warning_dim_0"], errors="ignore") + warning_vars = [ + name + for name in group.data_vars + if name == "warning" or re.match(r"sampler_\d+__warning", str(name)) + ] + group = group.drop_vars(names=[*warning_vars, "warning_dim_0"], errors="ignore") nidata.add_groups({gname: group}, coords=group.coords, dims=group.dims) return nidata From a773405d217e3b3cf4df74d37913123069204f52 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Wed, 16 Oct 2024 11:51:46 +0200 Subject: [PATCH 04/25] Add ZarrTrace --- .github/workflows/tests.yml | 2 + conda-envs/environment-dev.yml | 1 + conda-envs/environment-docs.yml | 1 + conda-envs/environment-jax.yml | 1 + conda-envs/environment-test.yml | 1 + conda-envs/windows-environment-dev.yml | 1 + conda-envs/windows-environment-test.yml | 1 + docs/source/api/backends.rst | 2 + docs/source/conf.py | 1 + pymc/backends/zarr.py | 846 ++++++++++++++++++++++++ requirements-dev.txt | 1 + tests/backends/test_zarr.py | 359 ++++++++++ 12 files changed, 1217 insertions(+) create mode 100644 pymc/backends/zarr.py create mode 100644 tests/backends/test_zarr.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 684ccb68a3..be2444921d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -115,6 +115,7 @@ jobs: - | tests/backends/test_mcbackend.py + tests/backends/test_zarr.py tests/distributions/test_truncated.py tests/logprob/test_abstract.py tests/logprob/test_basic.py @@ -240,6 +241,7 @@ jobs: - | tests/backends/test_arviz.py + tests/backends/test_zarr.py tests/variational/test_updates.py fail-fast: false runs-on: ${{ matrix.os }} diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 71b6c78ed4..de0572e0a2 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -19,6 +19,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for dev, testing and docs build - ipython>=7.16 - jax diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index f795fca078..c399a3e24a 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -17,6 +17,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for docs build - ipython>=7.16 - jax diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index 48649a617d..39deb8a41a 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -10,6 +10,7 @@ dependencies: - cachetools>=4.2.1 - cloudpickle - h5py>=2.7 +- zarr>=2.5.0,<3 # Jaxlib version must not be greater than jax version! - blackjax>=1.2.2 - jax>=0.4.28 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index e6fe9857e0..79c57a44c6 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -21,6 +21,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for testing - ipython>=7.16 - pre-commit>=2.8.0 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index ee5bd206f4..bbcba9149f 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -20,6 +20,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for dev, testing and docs build - ipython>=7.16 - myst-nb<=1.0.0 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index fa59852830..399fab811b 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -23,6 +23,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for testing - ipython>=7.16 - pre-commit>=2.8.0 diff --git a/docs/source/api/backends.rst b/docs/source/api/backends.rst index ca00a56d81..8f0c76f453 100644 --- a/docs/source/api/backends.rst +++ b/docs/source/api/backends.rst @@ -20,3 +20,5 @@ Internal structures NDArray base.BaseTrace base.MultiTrace + zarr.ZarrTrace + zarr.ZarrChain diff --git a/docs/source/conf.py b/docs/source/conf.py index 74ac0d9746..b9afc12e73 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -309,6 +309,7 @@ "python": ("https://docs.python.org/3/", None), "scipy": ("https://docs.scipy.org/doc/scipy/", None), "xarray": ("https://docs.xarray.dev/en/stable/", None), + "zarr": ("https://zarr.readthedocs.io/en/stable/", None), } diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py new file mode 100644 index 0000000000..1e3f4da883 --- /dev/null +++ b/pymc/backends/zarr.py @@ -0,0 +1,846 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any + +import arviz as az +import numcodecs +import numpy as np +import xarray as xr +import zarr + +from arviz.data.base import make_attrs +from arviz.data.inference_data import WARMUP_TAG +from numcodecs.abc import Codec +from pytensor.tensor.variable import TensorVariable + +import pymc + +from pymc.backends.arviz import ( + coords_and_dims_for_inferencedata, + find_constants, + find_observations, +) +from pymc.backends.base import BaseTrace +from pymc.blocking import StatDtype, StatShape +from pymc.model.core import Model, modelcontext +from pymc.step_methods.compound import ( + BlockedStep, + CompoundStep, + StatsBijection, + get_stats_dtypes_shapes_from_steps, +) +from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name + +try: + from zarr.storage import BaseStore, default_compressor + from zarr.sync import Synchronizer + + _zarr_available = True +except ImportError: + _zarr_available = False + + +class ZarrChain(BaseTrace): + """Interface object to interact with a single chain in a :class:`~.ZarrTrace`. + + Parameters + ---------- + store : zarr.storage.BaseStore | collections.abc.MutableMapping + The store object where the zarr groups and arrays will be stored and read from. + This store must exist before creating a ``ZarrChain`` object. ``ZarrChain`` are + only intended to be used as interfaces to the individual chains of + :class:`~.ZarrTrace` objects. This means that the :class:`~.ZarrTrace` should + be the one that creates the store that is then provided to a ``ZarrChain``. + stats_bijection : pymc.step_methods.compound.StatsBijection + An object that maps between a list of step method stats and a dictionary of + said stats with the accompanying stepper index. + synchronizer : zarr.sync.Synchronizer | None + The synchronizer to use for the underlying zarr arrays. + model : Model + If None, the model is taken from the `with` context. + vars : Sequence[TensorVariable] | None + Sampling values will be stored for these variables. If None, + `model.unobserved_RVs` is used. + test_point : dict[str, numpy.ndarray] | None + This is not used and is inherited from the signature of :class:`~.BaseTrace`, + which uses it to determine the shape and dtype of `vars`. + draws_per_chunk : int + The number of draws that make up a chunk in the variable's posterior array. + The interface only writes the samples to the store once a chunk is completely + filled. + """ + + def __init__( + self, + store: BaseStore | MutableMapping, + stats_bijection: StatsBijection, + synchronizer: Synchronizer | None = None, + model: Model | None = None, + vars: Sequence[TensorVariable] | None = None, + test_point: dict[str, np.ndarray] | None = None, + draws_per_chunk: int = 1, + ): + if not _zarr_available: + raise RuntimeError("You must install zarr to be able to create ZarrChain instances") + super().__init__(name="zarr", model=model, vars=vars, test_point=test_point) + self._step_method: BlockedStep | CompoundStep | None = None + self.unconstrained_variables = { + var.name for var in self.vars if is_transformed_name(var.name) + } + self.draw_idx = 0 + self._buffers: dict[str, dict[str, list]] = { + "posterior": {}, + "sample_stats": {}, + } + self._buffered_draws = 0 + self.draws_per_chunk = int(draws_per_chunk) + assert self.draws_per_chunk > 0 + self._posterior = zarr.open_group( + store, synchronizer=synchronizer, path="posterior", mode="a" + ) + if self.unconstrained_variables: + self._unconstrained_posterior = zarr.open_group( + store, synchronizer=synchronizer, path="unconstrained_posterior", mode="a" + ) + self._buffers["unconstrained_posterior"] = {} + self._sample_stats = zarr.open_group( + store, synchronizer=synchronizer, path="sample_stats", mode="a" + ) + self._sampling_state = zarr.open_group( + store, synchronizer=synchronizer, path="_sampling_state", mode="a" + ) + self.stats_bijection = stats_bijection + + def link_stepper(self, step_method: BlockedStep | CompoundStep): + """Provide a reference to the step method used during sampling. + + This reference can be used to facilite writing the stepper's sampling state + each time the samples are flushed into the storage. + """ + self._step_method = step_method + + def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None): # type: ignore[override] + self.chain = chain + self.total_draws = draws + self.draws_until_flush = min([self.draws_per_chunk, draws - self.draw_idx]) + self.clear_buffers() + + def clear_buffers(self): + for group in self._buffers: + self._buffers[group] = {} + self._buffered_draws = 0 + + def buffer(self, group, var_name, value): + buffer = self._buffers[group] + if var_name not in buffer: + buffer[var_name] = [] + buffer[var_name].append(value) + + def record( + self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]] + ) -> bool | None: + """Record the step method's returned draw and stats. + + The draws and stats are first stored in an internal buffer. Once the buffer is + filled, the samples and stats are written (flushed) onto the desired zarr store. + + Returns + ------- + flushed : bool | None + Returns ``True`` only if the data was written onto the desired zarr store. + Any other time that the recorded draw and stats are written into the + internal buffer, ``None`` is returned. + + See Also + -------- + :meth:`~ZarrChain.flush` + """ + unconstrained_variables = self.unconstrained_variables + for var_name, var_value in zip(self.varnames, self.fn(draw)): + if var_name in unconstrained_variables: + self.buffer(group="unconstrained_posterior", var_name=var_name, value=var_value) + else: + self.buffer(group="posterior", var_name=var_name, value=var_value) + for var_name, var_value in self.stats_bijection.map(stats).items(): + self.buffer(group="sample_stats", var_name=var_name, value=var_value) + self._buffered_draws += 1 + if self._buffered_draws == self.draws_until_flush: + self.flush() + return True + return None + + def record_sampling_state(self, step: BlockedStep | CompoundStep | None = None): + """Record the sampling state information to the store's ``_sampling_state`` group. + + The sampling state includes the number of draws taken so far (``draw_idx``) and + the step method's ``sampling_state``. + + Parameters + ---------- + step : BlockedStep | CompoundStep | None + The step method from which to take the ``sampling_state``. If ``None``, + the ``step`` is taken to be the step method that was linked to the + ``ZarrChain`` when calling :meth:`~ZarrChain.link_stepper`. If this method was never + called, no step method ``sampling_state`` information is stored in the + chain. + """ + if step is None: + step = self._step_method + if step is not None: + self.store_sampling_state(step.sampling_state) + self._sampling_state.draw_idx.set_coordinate_selection(self.chain, self.draw_idx) + + def store_sampling_state(self, sampling_state): + self._sampling_state.sampling_state.set_coordinate_selection( + self.chain, np.array([sampling_state], dtype="object") + ) + + def flush(self): + """Write the data stored in the internal buffer to the desired zarr store. + + After writing the draws and stats returned by each step of the step method, + the :meth:`~ZarrChain.record_sampling_state` is called, the internal buffer is cleared and + the number of steps until the next flush is determined. + """ + chain = self.chain + draw_slice = slice(self.draw_idx, self.draw_idx + self.draws_until_flush) + for group_name, buffer in self._buffers.items(): + group = getattr(self, f"_{group_name}") + for var_name, var_value in buffer.items(): + group[var_name].set_orthogonal_selection( + (chain, draw_slice), + np.stack(var_value), + ) + self.draw_idx += self.draws_until_flush + self.record_sampling_state() + self.clear_buffers() + self.draws_until_flush = min([self.draws_per_chunk, self.total_draws - self.draw_idx]) + + +FILL_VALUE_TYPE = float | int | bool | str | np.datetime64 | np.timedelta64 | None +DEFAULT_FILL_VALUES: dict[Any, FILL_VALUE_TYPE] = { + np.floating: np.nan, + np.integer: 0, + np.bool_: False, + np.str_: "", + np.datetime64: np.datetime64(0, "Y"), + np.timedelta64: np.timedelta64(0, "Y"), +} + + +def get_initial_fill_value_and_codec( + dtype: Any, +) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]: + _dtype = np.dtype(dtype) + fill_value: FILL_VALUE_TYPE = None + codec = None + try: + fill_value = DEFAULT_FILL_VALUES[_dtype] + except KeyError: + for key in DEFAULT_FILL_VALUES: + if np.issubdtype(_dtype, key): + fill_value = DEFAULT_FILL_VALUES[key] + break + else: + codec = numcodecs.Pickle() + return fill_value, _dtype, codec + + +class ZarrTrace: + """Object that stores and enables access to MCMC draws stored in a :class:`zarr.hierarchy.Group` objects. + + This class creats a zarr hierarchy to represent the sampling information which is + intended to mimic :class:`arviz.InferenceData`. The hierarchy looks like this: + + | root + | |--> constant_data + | |--> observed_data + | |--> posterior + | |--> unconstrained_posterior + | |--> sample_stats + | |--> warmup_posterior + | |--> warmup_unconstrained_posterior + | |--> warmup_sample_stats + | |--> _sampling_state + + The root group is created when the ``ZarrTrace`` object is initialized. The rest of + the groups are created once :meth:`~ZarrChain.init_trace` is called with a few exceptions: + unconstrained_posterior is only created if ``include_transformed = True``, and the + groups prefixed with ``warmup_`` are created only after calling + :meth:`~ZarrTrace.split_warmup_groups`. + + Since ``ZarrTrace`` objects are intended to be as close to + :class:`arviz.InferenceData` objects as possible, the groups store the dimension + and coordinate information following the `xarray zarr standard `_. + + Parameters + ---------- + store : zarr.storage.BaseStore | collections.abc.MutableMapping | None + The store object where the zarr groups and arrays will be stored and read from. + Any zarr compatible storage object works. Keep in mind that if ``None`` is + provided, a :class:`zarr.storage.MemoryStore` will be used, which means that + information won't be visible to other processes and won't persist after the + ``ZarrTrace`` life-cycle ends. If you want to have persistent storage, please + use one of the multiple disk backed zarr storage options, e.g. + :class:`~zarr.storage.DirectoryStore` or :class:`~zarr.storage.ZipStore`. + synchronizer : zarr.sync.Synchronizer | None + The synchronizer to use for the underlying zarr arrays. + compressor : numcodec.abc.Codec | None | pymc.util.UNSET + The compressor to use for the underlying zarr arrays. If ``None``, no compressor + is used. If ``UNSET``, zarr's default compressor is used. + draws_per_chunk : int + The number of draws that make up a chunk in the variable's posterior array. + Each variable's array shape is set to ``(n_chains, n_draws, *rv_shape)``, but + the chunks are set to ``(1, draws_per_chunk, *rv_shape)``. This means that each + chain will have it's own chunk to read or write to, allowing for concurrent + write operations of different chains not to interfere with each other, and that + multiple draws can belong to the same chunk. The variable's core dimension + however, will never be split across different chunks. + include_transformed : bool + If ``True``, the transformed, unconstrained value variables are included in the + storage group. + + Notes + ----- + ``ZarrTrace`` objects represent the storage information. If the underlying store + persists on disk or over the network (e.g. with a :class:`zarr.storage.FSStore`) + multiple process will be able to concurrently access the same storage and read or + write to it. + + The intended division of labour is for ``ZarrTrace`` to handle the creation and + management of the zarr group and storage objects and arrays, and for individual + :class:`~.ZarrChain` objects to handle recording MCMC samples to the trace. This + division was chosen to stay close to the existing `pymc.backends.base.MultiTrace` + and `pymc.backends.ndarray.NDArray` way of working with the existing samplers. + + One extra feature of ``ZarrTrace`` is that it enables direct access to any array's + metadata. ``ZarrTrace`` takes advantage of this to tag arrays as ``deterministic`` + or ``freeRV`` depending on what kind of variable they were in the defining model. + + See Also + -------- + :class:`~pymc.backends.zarr.ZarrChain` + """ + + def __init__( + self, + store: BaseStore | MutableMapping | None = None, + synchronizer: Synchronizer | None = None, + compressor: Codec | None | _UnsetType = UNSET, + draws_per_chunk: int = 1, + include_transformed: bool = False, + ): + if not _zarr_available: + raise RuntimeError("You must install zarr to be able to create ZarrTrace instances") + self.synchronizer = synchronizer + if compressor is UNSET: + compressor = default_compressor + self.compressor = compressor + self.root = zarr.group( + store=store, + overwrite=True, + synchronizer=synchronizer, + ) + + self.draws_per_chunk = int(draws_per_chunk) + assert self.draws_per_chunk >= 1 + + self.include_transformed = include_transformed + + self._is_base_setup = False + + def groups(self) -> list[str]: + return [str(group_name) for group_name, _ in self.root.groups()] + + @property + def posterior(self) -> zarr.Group: + return self.root.posterior + + @property + def unconstrained_posterior(self) -> zarr.Group: + return self.root.unconstrained_posterior + + @property + def sample_stats(self) -> zarr.Group: + return self.root.sample_stats + + @property + def constant_data(self) -> zarr.Group: + return self.root.constant_data + + @property + def observed_data(self) -> zarr.Group: + return self.root.observed_data + + @property + def _sampling_state(self) -> zarr.Group: + return self.root._sampling_state + + def init_trace( + self, + chains: int, + draws: int, + tune: int, + step: BlockedStep | CompoundStep, + model: Model | None = None, + vars: Sequence[TensorVariable] | None = None, + test_point: dict[str, np.ndarray] | None = None, + ): + """Initialize the trace groups and arrays. + + This function creates and fills with default values the groups below the + ``ZarrTrace.root`` group. It creates the ``constant_data``, ``observed_data``, + ``posterior``, ``unconstrained_posterior`` (if ``include_transformed = True``), + ``sample_stats``, and ``_sampling_state`` zarr groups, and all of the relevant + arrays that must be stored there. + + Every array in the posterior and sample stats groups will have the + (chains, tune + draws) batch dimensions to the left of the core dimensions of + the model's random variable or the step method's stat shape. The warmup (tuning + draws) and the posterior samples are split at a later stage, once + :meth:`~ZarrTrace.split_warmup_groups` is called. + + After the creation if the zarr hierarchies, it initializes the list of + :class:`~pymc.backends.zarr.Zarrchain` instances (one for each chain) under the + ``straces`` attribute. These objects serve as the interface to record draws and + samples generated by the step methods for each chain. + + Parameters + ---------- + chains : int + The number of chains to use to initialize the arrays. + draws : int + The number of posterior draws to use to initialize the arrays. + tune : int + The number of tuning steps to use to initialize the arrays. + step : pymc.step_methods.compound.BlockedStep | pymc.step_methods.compound.CompoundStep + The step method that will be used to generate the draws and stats. + model : pymc.model.core.Model | None + If None, the model is taken from the ``with`` context. + vars : Sequence[TensorVariable] | None + Sampling values will be stored for these variables. If ``None``, + ``model.unobserved_RVs`` is used. + test_point : dict[str, numpy.ndarray] | None + This is not used and is a product of the inheritance of :class:`ZarrChain` + from :class:`~.BaseTrace`, which uses it to determine the shape and dtype + of `vars`. + """ + if self._is_base_setup: + raise RuntimeError("The ZarrTrace has already been initialized") # pragma: no cover + model = modelcontext(model) + self.model = model + self.coords, self.vars_to_dims = coords_and_dims_for_inferencedata(model) + if vars is None: + vars = model.unobserved_value_vars + + unnamed_vars = {var for var in vars if var.name is None} + assert not unnamed_vars, f"Can't trace unnamed variables: {unnamed_vars}" + self.varnames = get_default_varnames( + [var.name for var in vars], include_transformed=self.include_transformed + ) + self.vars = [var for var in vars if var.name in self.varnames] + + self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore") + + # Get variable shapes. Most backends will need this + # information. + if test_point is None: + test_point = model.initial_point() + var_values = list(zip(self.varnames, self.fn(test_point))) + self.var_dtype_shapes = { + var: (value.dtype, value.shape) + for var, value in var_values + if not is_transformed_name(var) + } + extra_var_attrs = { + var: { + "kind": "freeRV" + if is_transformed_name(var) or model[var] in model.free_RVs + else "deterministic" + } + for var in self.var_dtype_shapes + } + self.unc_var_dtype_shapes = { + var: (value.dtype, value.shape) for var, value in var_values if is_transformed_name(var) + } + extra_unc_var_attrs = {var: {"kind": "freeRV"} for var in self.unc_var_dtype_shapes} + + self.create_group( + name="constant_data", + data_dict=find_constants(self.model), + ) + + self.create_group( + name="observed_data", + data_dict=find_observations(self.model), + ) + + # Create the posterior that includes warmup draws + self.init_group_with_empty( + group=self.root.create_group(name="posterior", overwrite=True), + var_dtype_and_shape=self.var_dtype_shapes, + chains=chains, + draws=tune + draws, + extra_var_attrs=extra_var_attrs, + ) + + # Create the unconstrained posterior group that includes warmup draws + if self.include_transformed and self.unc_var_dtype_shapes: + self.init_group_with_empty( + group=self.root.create_group(name="unconstrained_posterior", overwrite=True), + var_dtype_and_shape=self.unc_var_dtype_shapes, + chains=chains, + draws=tune + draws, + extra_var_attrs=extra_unc_var_attrs, + ) + + # Create the sample stats that include warmup draws + stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps( + [step] if isinstance(step, BlockedStep) else step.methods + ) + self.init_group_with_empty( + group=self.root.create_group(name="sample_stats", overwrite=True), + var_dtype_and_shape=stats_dtypes_shapes, + chains=chains, + draws=tune + draws, + ) + + self.init_sampling_state_group(tune=tune, chains=chains) + + self.straces = [ + ZarrChain( + store=self.root.store, + synchronizer=self.synchronizer, + model=self.model, + vars=self.vars, + test_point=test_point, + stats_bijection=StatsBijection(step.stats_dtypes), + draws_per_chunk=self.draws_per_chunk, + ) + for _ in range(chains) + ] + for chain, strace in enumerate(self.straces): + strace.setup(draws=tune + draws, chain=chain, sampler_vars=None) + + def split_warmup_groups(self): + """Split the warmup and standard groups. + + This method takes the entries in the arrays in the posterior, sample_stats + and unconstrained_posterior that happened in the tuning phase and moves them + into the warmup_ groups. If the ``warmup_posterior`` group already exists, then + nothing is done. + + See Also + -------- + :meth:`~ZarrTrace.split_warmup` + """ + if "warmup_posterior" not in self.groups(): + self.split_warmup("posterior", error_if_already_split=False) + self.split_warmup("sample_stats", error_if_already_split=False) + try: + self.split_warmup("unconstrained_posterior", error_if_already_split=False) + except KeyError: + pass + + @property + def tuning_steps(self): + try: + return int(self._sampling_state.tuning_steps.get_basic_selection()) + except AttributeError: # pragma: no cover + raise ValueError( + "ZarrTrace has not been initialized and there is no tuning step information available" + ) + + @property + def sampling_time(self): + try: + return float(self._sampling_state.sampling_time.get_basic_selection()) + except AttributeError: # pragma: no cover + raise ValueError( + "ZarrTrace has not been initialized and there is no sampling time information available" + ) + + @sampling_time.setter + def sampling_time(self, value): + self._sampling_state.sampling_time.set_basic_selection((), float(value)) + + def init_sampling_state_group(self, tune: int, chains: int): + state = self.root.create_group(name="_sampling_state", overwrite=True) + sampling_state = state.empty( + name="sampling_state", + overwrite=True, + shape=(chains,), + chunks=(1,), + dtype="object", + object_codec=numcodecs.Pickle(), + compressor=self.compressor, + ) + sampling_state.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]}) + draw_idx = state.array( + name="draw_idx", + overwrite=True, + data=np.zeros(chains, dtype="int"), + chunks=(1,), + dtype="int", + fill_value=-1, + compressor=self.compressor, + ) + draw_idx.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]}) + + state.array( + name="tuning_steps", + data=tune, + overwrite=True, + dtype="int", + fill_value=0, + compressor=self.compressor, + ) + state.array( + name="sampling_time", + data=0.0, + dtype="float", + fill_value=0.0, + compressor=self.compressor, + ) + state.array( + name="sampling_start_time", + data=0.0, + dtype="float", + fill_value=0.0, + compressor=self.compressor, + ) + + chain = state.array( + name="chain", + data=np.arange(chains), + compressor=self.compressor, + ) + + chain.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]}) + + state.empty( + name="global_warnings", + dtype="object", + object_codec=numcodecs.Pickle(), + shape=(0,), + ) + + def init_group_with_empty( + self, + group: zarr.Group, + var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]], + chains: int, + draws: int, + extra_var_attrs: dict | None = None, + ) -> zarr.Group: + group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)} + for name, (_dtype, shape) in var_dtype_and_shape.items(): + fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype) + shape = shape or () + array = group.full( + name=name, + dtype=dtype, + fill_value=fill_value, + object_codec=object_codec, + shape=(chains, draws, *shape), + chunks=(1, self.draws_per_chunk, *shape), + compressor=self.compressor, + ) + try: + dims = self.vars_to_dims[name] + for dim in dims: + group_coords[dim] = self.coords[dim] + except KeyError: + dims = [] + for i, shape_i in enumerate(shape): + dim = f"{name}_dim_{i}" + dims.append(dim) + group_coords[dim] = np.arange(shape_i, dtype="int") + dims = ("chain", "draw", *dims) + attrs = extra_var_attrs[name] if extra_var_attrs is not None else {} + attrs.update({"_ARRAY_DIMENSIONS": dims}) + array.attrs.update(attrs) + for dim, coord in group_coords.items(): + array = group.array( + name=dim, + data=coord, + fill_value=None, + compressor=self.compressor, + ) + array.attrs.update({"_ARRAY_DIMENSIONS": [dim]}) + return group + + def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None: + group: zarr.Group | None = None + if data_dict: + group_coords = {} + group = self.root.create_group(name=name, overwrite=True) + for var_name, var_value in data_dict.items(): + fill_value, dtype, object_codec = get_initial_fill_value_and_codec(var_value.dtype) + array = group.array( + name=var_name, + data=var_value, + fill_value=fill_value, + dtype=dtype, + object_codec=object_codec, + compressor=self.compressor, + ) + try: + dims = self.vars_to_dims[var_name] + for dim in dims: + group_coords[dim] = self.coords[dim] + except KeyError: + dims = [] + for i in range(var_value.ndim): + dim = f"{var_name}_dim_{i}" + dims.append(dim) + group_coords[dim] = np.arange(var_value.shape[i], dtype="int") + array.attrs.update({"_ARRAY_DIMENSIONS": dims}) + for dim, coord in group_coords.items(): + array = group.array( + name=dim, + data=coord, + fill_value=None, + compressor=self.compressor, + ) + array.attrs.update({"_ARRAY_DIMENSIONS": [dim]}) + return group + + def split_warmup(self, group_name: str, error_if_already_split: bool = True): + """Split the arrays of a group into the warmup and regular groups. + + This function takes the first ``self.tuning_steps`` draws of supplied + ``group_name`` and moves them into a new zarr group called + ``f"warmup_{group_name}"``. + + Parameters + ---------- + group_name : str + The name of the group that should be split. + error_if_already_split : bool + If ``True`` and if the ``f"warmup_{group_name}"`` group already exists in + the root hierarchy, a ``ValueError`` is raised. If this flag is ``False`` + but the warmup group already exists, the contents of that group are + overwritten. + """ + if error_if_already_split and f"{WARMUP_TAG}{group_name}" in { + group_name for group_name, _ in self.root.groups() + }: + raise RuntimeError(f"Warmup data for {group_name} has already been split") + posterior_group = self.root[group_name] + tune = self.tuning_steps + warmup_group = self.root.create_group(f"{WARMUP_TAG}{group_name}", overwrite=True) + if tune == 0: + try: + self.root.pop(f"{WARMUP_TAG}{group_name}") + except KeyError: + pass + return + for name, array in posterior_group.arrays(): + array_attrs = array.attrs.asdict() + if name == "draw": + warmup_array = warmup_group.array( + name="draw", + data=np.arange(tune), + dtype="int", + compressor=self.compressor, + ) + posterior_array = posterior_group.array( + name=name, + data=np.arange(len(array) - tune), + dtype="int", + overwrite=True, + compressor=self.compressor, + ) + posterior_array.attrs.update(array_attrs) + else: + dims = array.attrs["_ARRAY_DIMENSIONS"] + warmup_idx: slice | tuple[slice, slice] + if len(dims) >= 2 and dims[:2] == ["chain", "draw"]: + must_overwrite_posterior = True + warmup_idx = (slice(None), slice(None, tune, None)) + posterior_idx = (slice(None), slice(tune, None, None)) + else: + must_overwrite_posterior = False + warmup_idx = slice(None) + fill_value, dtype, object_codec = get_initial_fill_value_and_codec(array.dtype) + warmup_array = warmup_group.array( + name=name, + data=array[warmup_idx], + chunks=array.chunks, + dtype=dtype, + fill_value=fill_value, + object_codec=object_codec, + compressor=self.compressor, + ) + if must_overwrite_posterior: + posterior_array = posterior_group.array( + name=name, + data=array[posterior_idx], + chunks=array.chunks, + dtype=dtype, + fill_value=fill_value, + object_codec=object_codec, + overwrite=True, + compressor=self.compressor, + ) + posterior_array.attrs.update(array_attrs) + warmup_array.attrs.update(array_attrs) + + def to_inferencedata(self, save_warmup: bool = False) -> az.InferenceData: + """Convert ``ZarrTrace`` to :class:`~.arviz.InferenceData`. + + This converts all the groups in the ``ZarrTrace.root`` hierarchy into an + ``InferenceData`` object. The only exception is that ``_sampling_state`` is + excluded. + + Parameters + ---------- + save_warmup : bool + If ``True``, all of the warmup groups are stored in the inference data + object. + + Notes + ----- + ``xarray`` and in turn ``arviz`` require the zarr groups to have consolidated + metadata. To achieve this, a new consolidated store is constructed by calling + :func:`zarr.consolidate_metadata` on the root's store. This means that the + returned ``InferenceData`` object will operate on a different storage unit + than the calling ``ZarrTrace``, so future changes to the ``ZarrTrace`` won't be + automatically reflected in the returned ``InferenceData`` object. + """ + self.split_warmup_groups() + # Xarray complains if we try to open a zarr hierarchy that doesn't have consolidated metadata + consolidated_root = zarr.consolidate_metadata(self.root.store) + # The ConsolidatedMetadataStore looks like an empty store from xarray's point of view + # we need to actually grab the underlying store so that xarray doesn't produce completely + # empty arrays + store = consolidated_root.store.store + groups = {} + try: + global_attrs = { + "tuning_steps": self.tuning_steps, + "sampling_time": self.sampling_time, + } + except AttributeError: + global_attrs = {} # pragma: no cover + for name, _ in self.root.groups(): + if name.startswith("_") or (not save_warmup and name.startswith(WARMUP_TAG)): + continue + data = xr.open_zarr(store, group=name, mask_and_scale=False) + attrs = {**data.attrs, **global_attrs} + data.attrs = make_attrs(attrs=attrs, library=pymc) + groups[name] = data.load() if az.rcParams["data.load"] == "eager" else data + return az.InferenceData(**groups) diff --git a/requirements-dev.txt b/requirements-dev.txt index 56f7f964fc..e7e3644aae 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -32,3 +32,4 @@ threadpoolctl>=3.1.0 types-cachetools typing-extensions>=3.7.4 watermark +zarr>=2.5.0,<3 diff --git a/tests/backends/test_zarr.py b/tests/backends/test_zarr.py new file mode 100644 index 0000000000..7ef3c472f5 --- /dev/null +++ b/tests/backends/test_zarr.py @@ -0,0 +1,359 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools + +from dataclasses import asdict + +import numpy as np +import pytest +import zarr + +import pymc as pm + +from pymc.backends.zarr import ZarrTrace +from pymc.stats.convergence import SamplerWarning +from pymc.step_methods import NUTS, CompoundStep, Metropolis +from pymc.step_methods.state import equal_dataclass_values +from tests.helpers import equal_sampling_states + + +@pytest.fixture(scope="module") +def model(): + time_int = np.array([np.timedelta64(np.timedelta64(i, "h"), "ns") for i in range(25)]) + coords = { + "dim_int": range(3), + "dim_str": ["A", "B"], + "dim_time": np.datetime64("2024-10-16") + time_int, + "dim_interval": time_int, + } + rng = np.random.default_rng(42) + with pm.Model(coords=coords) as model: + data1 = pm.Data("data1", np.ones(3, dtype="bool"), dims=["dim_int"]) + data2 = pm.Data("data2", np.ones(3, dtype="bool")) + time = pm.Data("time", time_int / np.timedelta64(1, "h"), dims="dim_time") + + a = pm.Normal("a", shape=(len(coords["dim_int"]), len(coords["dim_str"]))) + b = pm.Normal("b", dims=["dim_int", "dim_str"]) + c = pm.Deterministic("c", a + b, dims=["dim_int", "dim_str"]) + + d = pm.LogNormal("d", dims="dim_time") + e = pm.Deterministic("e", (time + d)[:, None] + c[0], dims=["dim_interval", "dim_str"]) + + obs = pm.Normal( + "obs", + mu=e, + observed=rng.normal(size=(len(coords["dim_time"]), len(coords["dim_str"]))), + dims=["dim_time", "dim_str"], + ) + + return model + + +@pytest.fixture(params=[True, False]) +def include_transformed(request): + return request.param + + +@pytest.fixture(params=["frequent_writes", "sparse_writes"]) +def draws_per_chunk(request): + spec = { + "frequent_writes": 1, + "sparse_writes": 7, + } + return spec[request.param] + + +@pytest.fixture(params=["single_step", "compound_step"]) +def model_step(request, model): + rng = np.random.default_rng(42) + with model: + if request.param == "single_step": + step = NUTS(rng=rng) + else: + rngs = rng.spawn(2) + step = CompoundStep( + [ + Metropolis(vars=model["a"], rng=rngs[0]), + NUTS(vars=[rv for rv in model.value_vars if rv.name != "a"], rng=rngs[1]), + ] + ) + return step + + +def test_record(model, model_step, include_transformed, draws_per_chunk): + store = zarr.MemoryStore() + trace = ZarrTrace( + store=store, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk + ) + draws = 5 + tune = 5 + trace.init_trace(chains=1, draws=draws, tune=tune, model=model, step=model_step) + + # Assert that init was successful + expected_groups = { + "_sampling_state", + "sample_stats", + "posterior", + "constant_data", + "observed_data", + } + if include_transformed: + expected_groups.add("unconstrained_posterior") + assert {group_name for group_name, _ in trace.root.groups()} == expected_groups + + # Record samples from the ZarrChain + manually_collected_warmup_draws = [] + manually_collected_warmup_stats = [] + manually_collected_draws = [] + manually_collected_stats = [] + point = model.initial_point() + for draw in range(tune + draws): + tuning = draw < tune + if not tuning: + model_step.stop_tuning() + point, stats = model_step.step(point) + if tuning: + manually_collected_warmup_draws.append(point) + manually_collected_warmup_stats.append(stats) + else: + manually_collected_draws.append(point) + manually_collected_stats.append(stats) + trace.straces[0].record(point, stats) + trace.straces[0].record_sampling_state(model_step) + assert {group_name for group_name, _ in trace.root.groups()} == expected_groups + + # Assert split warmup + trace.split_warmup("posterior") + trace.split_warmup("sample_stats") + expected_groups = { + "_sampling_state", + "sample_stats", + "posterior", + "warmup_sample_stats", + "warmup_posterior", + "constant_data", + "observed_data", + } + if include_transformed: + trace.split_warmup("unconstrained_posterior") + expected_groups.add("unconstrained_posterior") + expected_groups.add("warmup_unconstrained_posterior") + assert {group_name for group_name, _ in trace.root.groups()} == expected_groups + # trace.consolidate() + + # Assert observed data is correct + assert set(dict(trace.observed_data.arrays())) == {"obs", "dim_time", "dim_str"} + assert list(trace.observed_data.obs.attrs["_ARRAY_DIMENSIONS"]) == ["dim_time", "dim_str"] + np.testing.assert_array_equal(trace.observed_data.dim_time[:], model.coords["dim_time"]) + np.testing.assert_array_equal(trace.observed_data.dim_str[:], model.coords["dim_str"]) + + # Assert constant data is correct + assert set(dict(trace.constant_data.arrays())) == { + "data1", + "data2", + "data2_dim_0", + "time", + "dim_time", + "dim_int", + } + assert list(trace.constant_data.data1.attrs["_ARRAY_DIMENSIONS"]) == ["dim_int"] + assert list(trace.constant_data.data2.attrs["_ARRAY_DIMENSIONS"]) == ["data2_dim_0"] + assert list(trace.constant_data.time.attrs["_ARRAY_DIMENSIONS"]) == ["dim_time"] + np.testing.assert_array_equal(trace.constant_data.dim_time[:], model.coords["dim_time"]) + np.testing.assert_array_equal(trace.constant_data.dim_int[:], model.coords["dim_int"]) + + # Assert unconstrained posterior has correct shapes and kinds + assert {rv.name for rv in model.free_RVs + model.deterministics} <= set( + dict(trace.posterior.arrays()) + ) + if include_transformed: + assert {"d_log__", "chain", "draw", "d_log___dim_0"} == set( + dict(trace.unconstrained_posterior.arrays()) + ) + assert list(trace.unconstrained_posterior.d_log__.attrs["_ARRAY_DIMENSIONS"]) == [ + "chain", + "draw", + "d_log___dim_0", + ] + assert trace.unconstrained_posterior.d_log__.attrs["kind"] == "freeRV" + np.testing.assert_array_equal(trace.unconstrained_posterior.chain, np.arange(1)) + np.testing.assert_array_equal(trace.unconstrained_posterior.draw, np.arange(draws)) + np.testing.assert_array_equal( + trace.unconstrained_posterior.d_log___dim_0, np.arange(len(model.coords["dim_time"])) + ) + + # Assert posterior has correct shapes and kinds + posterior_dims = set() + for kind, rv_name in [ + (kind, rv.name) + for kind, rv in itertools.chain( + itertools.zip_longest([], model.free_RVs, fillvalue="freeRV"), + itertools.zip_longest([], model.deterministics, fillvalue="deterministic"), + ) + ]: + if rv_name == "a": + expected_dims = ["a_dim_0", "a_dim_1"] + else: + expected_dims = model.named_vars_to_dims[rv_name] + posterior_dims |= set(expected_dims) + assert list(trace.posterior[rv_name].attrs["_ARRAY_DIMENSIONS"]) == [ + "chain", + "draw", + *expected_dims, + ] + assert trace.posterior[rv_name].attrs["kind"] == kind + for posterior_dim in posterior_dims: + try: + model_coord = model.coords[posterior_dim] + except KeyError: + model_coord = { + "a_dim_0": np.arange(len(model.coords["dim_int"])), + "a_dim_1": np.arange(len(model.coords["dim_str"])), + "chain": np.arange(1), + "draw": np.arange(draws), + }[posterior_dim] + np.testing.assert_array_equal(trace.posterior[posterior_dim][:], model_coord) + + # Assert sample stats have correct shape + stats_bijection = trace.straces[0].stats_bijection + for draw_idx, (draw, stat) in enumerate( + zip(manually_collected_draws, manually_collected_stats) + ): + stat = stats_bijection.map(stat) + for var, value in draw.items(): + if var in trace.posterior.arrays(): + assert np.array_equal(trace.posterior[var][0, draw_idx], value) + for var, value in stat.items(): + sample_stats = trace.root["sample_stats"] + stat_val = sample_stats[var][0, draw_idx] + if not isinstance(stat_val, SamplerWarning): + unequal_stats = stat_val != value + else: + unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value)) + if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)): + raise AssertionError(f"{var} value does not match: {stat_val} != {value}") + + # Assert manually collected warmup samples match + for draw_idx, (draw, stat) in enumerate( + zip(manually_collected_warmup_draws, manually_collected_warmup_stats) + ): + stat = stats_bijection.map(stat) + for var, value in draw.items(): + if var == "d_log__": + if not include_transformed: + continue + posterior = trace.root["warmup_unconstrained_posterior"] + else: + posterior = trace.root["warmup_posterior"] + if var in posterior.arrays(): + assert np.array_equal(posterior[var][0, draw_idx], value) + for var, value in stat.items(): + sample_stats = trace.root["warmup_sample_stats"] + stat_val = sample_stats[var][0, draw_idx] + if not isinstance(stat_val, SamplerWarning): + unequal_stats = stat_val != value + else: + unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value)) + if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)): + raise AssertionError(f"{var} value does not match: {stat_val} != {value}") + + # Assert manually collected posterior samples match + for draw_idx, (draw, stat) in enumerate( + zip(manually_collected_draws, manually_collected_stats) + ): + stat = stats_bijection.map(stat) + for var, value in draw.items(): + if var == "d_log__": + if not include_transformed: + continue + posterior = trace.root["unconstrained_posterior"] + else: + posterior = trace.root["posterior"] + if var in posterior.arrays(): + assert np.array_equal(posterior[var][0, draw_idx], value) + for var, value in stat.items(): + sample_stats = trace.root["sample_stats"] + stat_val = sample_stats[var][0, draw_idx] + if not isinstance(stat_val, SamplerWarning): + unequal_stats = stat_val != value + else: + unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value)) + if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)): + raise AssertionError(f"{var} value does not match: {stat_val} != {value}") + + # Assert sampling_state is correct + assert list(trace._sampling_state.draw_idx[:]) == [draws + tune] + assert equal_sampling_states( + trace._sampling_state.sampling_state[0], + model_step.sampling_state, + ) + + # Assert to inference data returns the expected groups + idata = trace.to_inferencedata(save_warmup=True) + expected_groups = { + "posterior", + "constant_data", + "observed_data", + "sample_stats", + "warmup_posterior", + "warmup_sample_stats", + } + if include_transformed: + expected_groups.add("unconstrained_posterior") + expected_groups.add("warmup_unconstrained_posterior") + assert set(idata.groups()) == expected_groups + for group in idata.groups(): + for name, value in itertools.chain( + idata[group].data_vars.items(), idata[group].coords.items() + ): + try: + array = getattr(trace, group)[name][:] + except AttributeError: + array = trace.root[group][name][:] + if "sample_stats" in group and "warning" in name: + continue + np.testing.assert_array_equal(array, value) + + +@pytest.mark.parametrize("tune", [0, 5, 10]) +def test_split_warmup(tune, model, model_step, include_transformed): + store = zarr.MemoryStore() + trace = ZarrTrace(store=store, include_transformed=include_transformed) + draws = 10 - tune + trace.init_trace(chains=1, draws=draws, tune=tune, model=model, step=model_step) + + trace.split_warmup("posterior") + trace.split_warmup("sample_stats") + assert len(trace.root.posterior.draw) == draws + assert len(trace.root.sample_stats.draw) == draws + if tune == 0: + with pytest.raises(KeyError): + trace.root["warmup_posterior"] + else: + assert len(trace.root["warmup_posterior"].draw) == tune + assert len(trace.root["warmup_sample_stats"].draw) == tune + + with pytest.raises(RuntimeError): + trace.split_warmup("posterior") + + for var_name, posterior_array in trace.posterior.arrays(): + dims = posterior_array.attrs["_ARRAY_DIMENSIONS"] + if len(dims) >= 2 and dims[1] == "draw": + assert posterior_array.shape[1] == draws + assert trace.root["warmup_posterior"][var_name].shape[1] == tune + for var_name, sample_stats_array in trace.sample_stats.arrays(): + dims = sample_stats_array.attrs["_ARRAY_DIMENSIONS"] + if len(dims) >= 2 and dims[1] == "draw": + assert sample_stats_array.shape[1] == draws + assert trace.root["warmup_sample_stats"][var_name].shape[1] == tune From 147b92eefa1604fe9872317f6525d7cb0f30c52a Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Wed, 23 Oct 2024 11:24:45 +0200 Subject: [PATCH 05/25] Integrate ZarrTrace into pymc.sample --- pymc/backends/__init__.py | 15 ++++- pymc/sampling/mcmc.py | 81 +++++++++++++++++++++++--- tests/backends/test_zarr.py | 111 ++++++++++++++++++++++++++++++++++++ 3 files changed, 197 insertions(+), 10 deletions(-) diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index cd007cf3c0..882412ce2d 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -72,6 +72,7 @@ from pymc.backends.arviz import predictions_to_inference_data, to_inference_data from pymc.backends.base import BaseTrace, IBaseTrace from pymc.backends.ndarray import NDArray +from pymc.backends.zarr import ZarrTrace from pymc.blocking import PointType from pymc.model import Model from pymc.step_methods.compound import BlockedStep, CompoundStep @@ -120,15 +121,27 @@ def _init_trace( def init_traces( *, - backend: TraceOrBackend | None, + backend: TraceOrBackend | ZarrTrace | None, chains: int, expected_length: int, step: BlockedStep | CompoundStep, initial_point: PointType, model: Model, trace_vars: list[TensorVariable] | None = None, + tune: int = 0, ) -> tuple[RunType | None, Sequence[IBaseTrace]]: """Initialize a trace recorder for each chain.""" + if isinstance(backend, ZarrTrace): + backend.init_trace( + chains=chains, + draws=expected_length - tune, + tune=tune, + step=step, + model=model, + vars=trace_vars, + test_point=initial_point, + ) + return None, backend.straces if HAS_MCB and isinstance(backend, Backend): return init_chain_adapters( backend=backend, diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index b2d643a5f1..85029c899b 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -50,6 +50,7 @@ find_observations, ) from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains +from pymc.backends.zarr import ZarrTrace from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain @@ -503,7 +504,7 @@ def sample( model: Model | None = None, compile_kwargs: dict | None = None, **kwargs, -) -> InferenceData | MultiTrace: +) -> InferenceData | MultiTrace | ZarrTrace: r"""Draw samples from the posterior using the given step methods. Multiple step methods are supported via compound step methods. @@ -570,7 +571,13 @@ def sample( Number of iterations of initializer. Only works for 'ADVI' init methods. trace : backend, optional A backend instance or None. - If None, the NDArray backend is used. + If ``None``, a ``MultiTrace`` object with underlying ``NDArray`` trace objects + is used. If ``trace`` is a :class:`~pymc.backends.zarr.ZarrTrace` instance, + the drawn samples will be written onto the desired storage while sampling is + on-going. This means sampling runs that, for whatever reason, die in the middle + of their execution will write the partial results onto the storage. If the + storage persist on disk, these results should be available even after a server + crash. See :class:`~pymc.backends.zarr.ZarrTrace` for more information. discard_tuned_samples : bool Whether to discard posterior samples of the tune interval. compute_convergence_checks : bool, default=True @@ -607,8 +614,12 @@ def sample( Returns ------- - trace : pymc.backends.base.MultiTrace or arviz.InferenceData - A ``MultiTrace`` or ArviZ ``InferenceData`` object that contains the samples. + trace : pymc.backends.base.MultiTrace | pymc.backends.zarr.ZarrTrace | arviz.InferenceData + A ``MultiTrace``, :class:`~arviz.InferenceData` or + :class:`~pymc.backends.zarr.ZarrTrace` object that contains the samples. A + ``ZarrTrace`` is only returned if the supplied ``trace`` argument is a + ``ZarrTrace`` instance. Refer to :class:`~pymc.backends.zarr.ZarrTrace` for + the benefits this backend provides. Notes ----- @@ -741,7 +752,7 @@ def joined_blas_limiter(): rngs = get_random_generator(random_seed).spawn(chains) random_seed_list = [rng.integers(2**30) for rng in rngs] - if not discard_tuned_samples and not return_inferencedata: + if not discard_tuned_samples and not return_inferencedata and not isinstance(trace, ZarrTrace): warnings.warn( "Tuning samples will be included in the returned `MultiTrace` object, which can lead to" " complications in your downstream analysis. Please consider to switch to `InferenceData`:\n" @@ -852,6 +863,7 @@ def joined_blas_limiter(): trace_vars=trace_vars, initial_point=initial_points[0], model=model, + tune=tune, ) sample_args = { @@ -934,7 +946,7 @@ def joined_blas_limiter(): # into a function to make it easier to test and refactor. return _sample_return( run=run, - traces=traces, + traces=trace if isinstance(trace, ZarrTrace) else traces, tune=tune, t_sampling=t_sampling, discard_tuned_samples=discard_tuned_samples, @@ -949,7 +961,7 @@ def joined_blas_limiter(): def _sample_return( *, run: RunType | None, - traces: Sequence[IBaseTrace], + traces: Sequence[IBaseTrace] | ZarrTrace, tune: int, t_sampling: float, discard_tuned_samples: bool, @@ -958,18 +970,69 @@ def _sample_return( keep_warning_stat: bool, idata_kwargs: dict[str, Any], model: Model, -) -> InferenceData | MultiTrace: +) -> InferenceData | MultiTrace | ZarrTrace: """Pick/slice chains, run diagnostics and convert to the desired return type. Final step of `pm.sampler`. """ + if isinstance(traces, ZarrTrace): + # Split warmup from posterior samples + traces.split_warmup_groups() + + # Set sampling time + traces.sampling_time = t_sampling + + # Compute number of actual draws per chain + total_draws_per_chain = traces._sampling_state.draw_idx[:] + n_chains = len(traces.straces) + desired_tune = traces.tuning_steps + desired_draw = len(traces.posterior.draw) + tuning_steps_per_chain = np.clip(total_draws_per_chain, 0, desired_tune) + draws_per_chain = total_draws_per_chain - tuning_steps_per_chain + + total_n_tune = tuning_steps_per_chain.sum() + total_draws = draws_per_chain.sum() + + _log.info( + f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations ' + f"(Actually sampled {total_n_tune:_d} tune and {total_draws:_d} draws total) " + f"took {t_sampling:.0f} seconds." + ) + + if compute_convergence_checks or return_inferencedata: + idata = traces.to_inferencedata(save_warmup=not discard_tuned_samples) + log_likelihood = idata_kwargs.pop("log_likelihood", False) + if log_likelihood: + from pymc.stats.log_density import compute_log_likelihood + + idata = compute_log_likelihood( + idata, + var_names=None if log_likelihood is True else log_likelihood, + extend_inferencedata=True, + model=model, + sample_dims=["chain", "draw"], + progressbar=False, + ) + if compute_convergence_checks: + warns = run_convergence_checks(idata, model) + for warn in warns: + traces._sampling_state.global_warnings.append(np.array([warn])) + log_warnings(warns) + + if return_inferencedata: + # By default we drop the "warning" stat which contains `SamplerWarning` + # objects that can not be stored with `.to_netcdf()`. + if not keep_warning_stat: + return drop_warning_stat(idata) + return idata + return traces + # Pick and slice chains to keep the maximum number of samples if discard_tuned_samples: traces, length = _choose_chains(traces, tune) else: traces, length = _choose_chains(traces, 0) mtrace = MultiTrace(traces)[:length] - # count the number of tune/draw iterations that happened # ideally via the "tune" statistic, but not all samplers record it! if "tune" in mtrace.stat_names: diff --git a/tests/backends/test_zarr.py b/tests/backends/test_zarr.py index 7ef3c472f5..9644dbfcd4 100644 --- a/tests/backends/test_zarr.py +++ b/tests/backends/test_zarr.py @@ -19,6 +19,8 @@ import pytest import zarr +from arviz import InferenceData + import pymc as pm from pymc.backends.zarr import ZarrTrace @@ -357,3 +359,112 @@ def test_split_warmup(tune, model, model_step, include_transformed): if len(dims) >= 2 and dims[1] == "draw": assert sample_stats_array.shape[1] == draws assert trace.root["warmup_sample_stats"][var_name].shape[1] == tune + + +@pytest.fixture(scope="function", params=[True, False]) +def discard_tuned_samples(request): + return request.param + + +@pytest.fixture(scope="function", params=[True, False]) +def return_inferencedata(request): + return request.param + + +@pytest.fixture(scope="function", params=[True, False]) +def keep_warning_stat(request): + return request.param + + +@pytest.fixture(scope="function", params=[True, False]) +def parallel(request): + return request.param + + +@pytest.fixture(scope="function", params=[True, False]) +def log_likelihood(request): + return request.param + + +def test_sample( + model, + model_step, + include_transformed, + discard_tuned_samples, + return_inferencedata, + keep_warning_stat, + parallel, + log_likelihood, + draws_per_chunk, +): + if not return_inferencedata and not log_likelihood: + pytest.skip( + reason="log_likelihood is only computed if an inference data object is returned" + ) + store = zarr.MemoryStore() + trace = ZarrTrace( + store=store, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk + ) + tune = 2 + draws = 3 + if parallel: + chains = 2 + cores = 2 + else: + chains = 1 + cores = 1 + with model: + out_trace = pm.sample( + draws=draws, + tune=tune, + chains=chains, + cores=cores, + trace=trace, + step=model_step, + discard_tuned_samples=discard_tuned_samples, + return_inferencedata=return_inferencedata, + keep_warning_stat=keep_warning_stat, + idata_kwargs={"log_likelihood": log_likelihood}, + ) + + if not return_inferencedata: + assert isinstance(out_trace, ZarrTrace) + assert out_trace.root.store is trace.root.store + else: + assert isinstance(out_trace, InferenceData) + + expected_groups = {"posterior", "constant_data", "observed_data", "sample_stats"} + if include_transformed: + expected_groups |= {"unconstrained_posterior"} + if not return_inferencedata or not discard_tuned_samples: + expected_groups |= {"warmup_posterior", "warmup_sample_stats"} + if include_transformed: + expected_groups |= {"warmup_unconstrained_posterior"} + if not return_inferencedata: + expected_groups |= {"_sampling_state"} + elif log_likelihood: + expected_groups |= {"log_likelihood"} + assert set(out_trace.groups()) == expected_groups + + if return_inferencedata: + warning_stat = ( + "sampler_1__warning" if isinstance(model_step, CompoundStep) else "sampler_0__warning" + ) + if keep_warning_stat: + assert warning_stat in out_trace.sample_stats + else: + assert warning_stat not in out_trace.sample_stats + + # Assert that all variables have non empty samples (not NaNs) + if return_inferencedata: + assert all( + (not np.any(np.isnan(v))) and v.shape[:2] == (chains, draws) + for v in out_trace.posterior.data_vars.values() + ) + else: + dimensions = {*model.coords, "a_dim_0", "a_dim_1", "chain", "draw"} + assert all( + (not np.any(np.isnan(v[:]))) and v.shape[:2] == (chains, draws) + for name, v in out_trace.posterior.arrays() + if name not in dimensions + ) From 35cdfa674e8d78a965516ac5c256f0c9d2f73c8d Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Wed, 6 Nov 2024 15:42:59 +0100 Subject: [PATCH 06/25] Write sampling state periodically --- pymc/sampling/mcmc.py | 30 ++++++++++++- pymc/sampling/parallel.py | 47 +++++++++++++++++++ pymc/sampling/population.py | 35 +++++++++++++-- tests/backends/test_zarr.py | 90 ++++++++++++++++++++++++++++++++----- 4 files changed, 186 insertions(+), 16 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 85029c899b..83a4835737 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -26,6 +26,7 @@ Any, Literal, TypeAlias, + cast, overload, ) @@ -40,6 +41,7 @@ from rich.theme import Theme from threadpoolctl import threadpool_limits from typing_extensions import Protocol +from zarr.storage import MemoryStore import pymc as pm @@ -50,7 +52,7 @@ find_observations, ) from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains -from pymc.backends.zarr import ZarrTrace +from pymc.backends.zarr import ZarrChain, ZarrTrace from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain @@ -1275,6 +1277,8 @@ def _iter_sample( step.set_rng(rng) point = start + if isinstance(trace, ZarrChain): + trace.link_stepper(step) try: step.tune = bool(tune) @@ -1297,12 +1301,18 @@ def _iter_sample( yield diverging except KeyboardInterrupt: + if isinstance(trace, ZarrChain): + trace.record_sampling_state(step=step) trace.close() raise except BaseException: + if isinstance(trace, ZarrChain): + trace.record_sampling_state(step=step) trace.close() raise else: + if isinstance(trace, ZarrChain): + trace.record_sampling_state(step=step) trace.close() @@ -1361,6 +1371,19 @@ def _mp_sample( # We did draws += tune in pm.sample draws -= tune + zarr_chains: list[ZarrChain] | None = None + zarr_recording = False + if all(isinstance(trace, ZarrChain) for trace in traces): + if isinstance(cast(ZarrChain, traces[0])._posterior.store, MemoryStore): + warnings.warn( + "Parallel sampling with MemoryStore zarr store wont write the processes " + "step method sampling state. If you wish to be able to access the step " + "method sampling state, please use a different storage backend, e.g. " + "DirectoryStore or ZipStore" + ) + else: + zarr_chains = cast(list[ZarrChain], traces) + zarr_recording = True sampler = ps.ParallelSampler( draws=draws, @@ -1374,13 +1397,16 @@ def _mp_sample( progressbar_theme=progressbar_theme, blas_cores=blas_cores, mp_ctx=mp_ctx, + zarr_chains=zarr_chains, ) try: try: with sampler: for draw in sampler: strace = traces[draw.chain] - strace.record(draw.point, draw.stats) + if not zarr_recording: + # Zarr recording happens in each process + strace.record(draw.point, draw.stats) log_warning_stats(draw.stats) if callback is not None: diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 67417e0d8f..794763e6e1 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -22,6 +22,7 @@ from collections import namedtuple from collections.abc import Sequence +from typing import cast import cloudpickle import numpy as np @@ -31,6 +32,7 @@ from rich.theme import Theme from threadpoolctl import threadpool_limits +from pymc.backends.zarr import ZarrChain from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.util import ( @@ -104,6 +106,9 @@ def __init__( tune: int, rng_state: RandomGeneratorState, blas_cores, + chain: int, + zarr_chains: list[ZarrChain] | bytes | None = None, + zarr_chains_is_pickled: bool = False, ): # For some strange reason, spawn multiprocessing doesn't copy the rng # seed sequence, so we have to rebuild it from scratch @@ -111,6 +116,15 @@ def __init__( self._msg_pipe = msg_pipe self._step_method = step_method self._step_method_is_pickled = step_method_is_pickled + self.chain = chain + self._zarr_recording = False + self._zarr_chain: ZarrChain | None = None + if zarr_chains_is_pickled: + self._zarr_chain = cloudpickle.loads(zarr_chains)[self.chain] + elif zarr_chains is not None: + self._zarr_chain = cast(list[ZarrChain], zarr_chains)[self.chain] + self._zarr_recording = self._zarr_chain is not None + self._shared_point = shared_point self._rng = rng self._draws = draws @@ -135,6 +149,7 @@ def run(self): # We do not create this in __init__, as pickling this # would destroy the shared memory. self._unpickle_step_method() + self._link_step_to_zarrchain() self._point = self._make_numpy_refs() self._start_loop() except KeyboardInterrupt: @@ -148,6 +163,10 @@ def run(self): finally: self._msg_pipe.close() + def _link_step_to_zarrchain(self): + if self._zarr_recording: + self._zarr_chain.link_stepper(self._step_method) + def _wait_for_abortion(self): while True: msg = self._recv_msg() @@ -170,6 +189,7 @@ def _recv_msg(self): return self._msg_pipe.recv() def _start_loop(self): + zarr_recording = self._zarr_recording self._step_method.set_rng(self._rng) draw = 0 @@ -199,6 +219,8 @@ def _start_loop(self): if msg[0] == "abort": raise KeyboardInterrupt() elif msg[0] == "write_next": + if zarr_recording: + self._zarr_chain.record(point, stats) self._write_point(point) is_last = draw + 1 == self._draws + self._tune self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats)) @@ -225,6 +247,8 @@ def __init__( start: dict[str, np.ndarray], blas_cores, mp_ctx, + zarr_chains: list[ZarrChain] | None = None, + zarr_chains_pickled: bytes | None = None, ): self.chain = chain process_name = f"worker_chain_{chain}" @@ -247,6 +271,16 @@ def __init__( self._readable = True self._num_samples = 0 + zarr_chains_send: list[ZarrChain] | bytes | None = None + if zarr_chains_pickled is not None: + zarr_chains_send = zarr_chains_pickled + elif zarr_chains is not None: + if mp_ctx.get_start_method() == "spawn": + raise ValueError( + "please provide a pre-pickled zarr_chains when multiprocessing start method is 'spawn'" + ) + zarr_chains_send = zarr_chains + if step_method_pickled is not None: step_method_send = step_method_pickled else: @@ -270,6 +304,9 @@ def __init__( tune, get_state_from_generator(rng), blas_cores, + self.chain, + zarr_chains_send, + zarr_chains_pickled is not None, ), ) self._process.start() @@ -392,6 +429,7 @@ def __init__( progressbar_theme: Theme | None = default_progress_theme, blas_cores: int | None = None, mp_ctx=None, + zarr_chains: list[ZarrChain] | None = None, ): if any(len(arg) != chains for arg in [rngs, start_points]): raise ValueError(f"Number of rngs and start_points must be {chains}.") @@ -412,8 +450,15 @@ def __init__( mp_ctx = multiprocessing.get_context(mp_ctx) step_method_pickled = None + zarr_chains_pickled = None + self.zarr_recording = False + if zarr_chains is not None: + assert all(isinstance(zarr_chain, ZarrChain) for zarr_chain in zarr_chains) + self.zarr_recording = True if mp_ctx.get_start_method() != "fork": step_method_pickled = cloudpickle.dumps(step_method, protocol=-1) + if zarr_chains is not None: + zarr_chains_pickled = cloudpickle.dumps(zarr_chains, protocol=-1) self._samplers = [ ProcessAdapter( @@ -426,6 +471,8 @@ def __init__( start, blas_cores, mp_ctx, + zarr_chains=zarr_chains, + zarr_chains_pickled=zarr_chains_pickled, ) for chain, rng, start in zip(range(chains), rngs, start_points) ] diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 4e5a229960..b8a7ba593a 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -27,6 +27,7 @@ from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from pymc.backends.base import BaseTrace +from pymc.backends.zarr import ZarrChain from pymc.initial_point import PointType from pymc.model import Model, modelcontext from pymc.stats.convergence import log_warning_stats @@ -36,6 +37,7 @@ PopulationArrayStepShared, StatsType, ) +from pymc.step_methods.compound import StepMethodState from pymc.step_methods.metropolis import DEMetropolis from pymc.util import CustomProgress @@ -81,6 +83,11 @@ def _sample_population( Show progress bars? (defaults to True) parallelize : bool Setting for multiprocess parallelization + traces : Sequence[BaseTrace] + A sequences of chain traces where the sampling results will be stored. Can be + a sequence of :py:class:`~pymc.backends.ndarray.NDArray`, + :py:class:`~pymc.backends.mcbackend.ChainRecordAdapter`, or + :py:class:`~pymc.backends.zarr.ZarrChain`. """ warn_population_size( step=step, @@ -263,6 +270,9 @@ def _run_secondary(c, stepper_dumps, secondary_end, task, progress): # receiving a None is the signal to exit if incoming is None: break + elif incoming == "sampling_state": + secondary_end.send((c, stepper.sampling_state)) + continue tune_stop, population = incoming if tune_stop: stepper.stop_tuning() @@ -307,6 +317,14 @@ def step(self, tune_stop: bool, population) -> list[tuple[PointType, StatsType]] updates.append(self._steppers[c].step(population[c])) return updates + def request_sampling_state(self, chain) -> StepMethodState: + if self.is_parallelized: + self._primary_ends[chain].send(("sampling_state",)) + _, sampling_state = self._primary_ends[chain].recv() + else: + sampling_state = self._steppers[chain].sampling_state + return sampling_state + def _prepare_iter_population( *, @@ -332,6 +350,11 @@ def _prepare_iter_population( Start points for each chain parallelize : bool Setting for multiprocess parallelization + traces : Sequence[BaseTrace] + A sequences of chain traces where the sampling results will be stored. Can be + a sequence of :py:class:`~pymc.backends.ndarray.NDArray`, + :py:class:`~pymc.backends.mcbackend.ChainRecordAdapter`, or + :py:class:`~pymc.backends.zarr.ZarrChain`. tune : int Number of iterations to tune. rngs: sequence of random Generators @@ -411,8 +434,11 @@ def _iter_population( the helper object for (parallelized) stepping of chains steppers : list The step methods for each chain - traces : list - Traces for each chain + traces : Sequence[BaseTrace] + A sequences of chain traces where the sampling results will be stored. Can be + a sequence of :py:class:`~pymc.backends.ndarray.NDArray`, + :py:class:`~pymc.backends.mcbackend.ChainRecordAdapter`, or + :py:class:`~pymc.backends.zarr.ZarrChain`. points : list population of chain states @@ -432,8 +458,11 @@ def _iter_population( # apply the update to the points and record to the traces for c, strace in enumerate(traces): points[c], stats = updates[c] - strace.record(points[c], stats) + flushed = strace.record(points[c], stats) log_warning_stats(stats) + if flushed and isinstance(strace, ZarrChain): + sampling_state = popstep.request_sampling_state(c) + strace.store_sampling_state(sampling_state) # yield the state of all chains in parallel yield i except KeyboardInterrupt: diff --git a/tests/backends/test_zarr.py b/tests/backends/test_zarr.py index 9644dbfcd4..32f508ef1a 100644 --- a/tests/backends/test_zarr.py +++ b/tests/backends/test_zarr.py @@ -17,6 +17,7 @@ import numpy as np import pytest +import xarray as xr import zarr from arviz import InferenceData @@ -62,9 +63,9 @@ def model(): return model -@pytest.fixture(params=[True, False]) +@pytest.fixture(params=["include_transformed", "discard_transformed"]) def include_transformed(request): - return request.param + return request.param == "include_transformed" @pytest.fixture(params=["frequent_writes", "sparse_writes"]) @@ -94,7 +95,7 @@ def model_step(request, model): def test_record(model, model_step, include_transformed, draws_per_chunk): - store = zarr.MemoryStore() + store = zarr.TempStore() trace = ZarrTrace( store=store, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk ) @@ -361,27 +362,31 @@ def test_split_warmup(tune, model, model_step, include_transformed): assert trace.root["warmup_sample_stats"][var_name].shape[1] == tune -@pytest.fixture(scope="function", params=[True, False]) +@pytest.fixture(scope="function", params=["discard_tuning", "keep_tuning"]) def discard_tuned_samples(request): - return request.param + return request.param == "discard_tuning" -@pytest.fixture(scope="function", params=[True, False]) +@pytest.fixture(scope="function", params=["return_idata", "return_zarr"]) def return_inferencedata(request): - return request.param + return request.param == "return_idata" -@pytest.fixture(scope="function", params=[True, False]) +@pytest.fixture( + scope="function", params=[True, False], ids=["keep_warning_stat", "discard_warning_stat"] +) def keep_warning_stat(request): return request.param -@pytest.fixture(scope="function", params=[True, False]) +@pytest.fixture( + scope="function", params=[True, False], ids=["parallel_sampling", "sequential_sampling"] +) def parallel(request): return request.param -@pytest.fixture(scope="function", params=[True, False]) +@pytest.fixture(scope="function", params=[True, False], ids=["compute_loglike", "no_loglike"]) def log_likelihood(request): return request.param @@ -401,7 +406,7 @@ def test_sample( pytest.skip( reason="log_likelihood is only computed if an inference data object is returned" ) - store = zarr.MemoryStore() + store = zarr.TempStore() trace = ZarrTrace( store=store, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk ) @@ -468,3 +473,66 @@ def test_sample( for name, v in out_trace.posterior.arrays() if name not in dimensions ) + + # Assert that the trace has valid sampling state stored for each chain + for step_method_state in trace._sampling_state.sampling_state[:]: + # We have no access to the actual step method that was using by each chain in pymc.sample + # The best way to see if the step method state is valid is by trying to set + # the model_step sampling state to the one stored in the trace. + model_step.sampling_state = step_method_state + + +def test_sampling_consistency( + model, + model_step, + draws_per_chunk, +): + # Test that pm.sample will generate the same posterior and sampling state + # regardless of whether sampling was done in parallel or not. + store1 = zarr.TempStore() + parallel_trace = ZarrTrace( + store=store1, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk + ) + store2 = zarr.TempStore() + sequential_trace = ZarrTrace( + store=store2, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk + ) + tune = 2 + draws = 3 + chains = 2 + random_seed = 12345 + initial_step_state = model_step.sampling_state + with model: + parallel_idata = pm.sample( + draws=draws, + tune=tune, + chains=chains, + cores=chains, + trace=parallel_trace, + step=model_step, + discard_tuned_samples=True, + return_inferencedata=True, + keep_warning_stat=False, + idata_kwargs={"log_likelihood": False}, + random_seed=random_seed, + ) + model_step.sampling_state = initial_step_state + sequential_idata = pm.sample( + draws=draws, + tune=tune, + chains=chains, + cores=1, + trace=sequential_trace, + step=model_step, + discard_tuned_samples=True, + return_inferencedata=True, + keep_warning_stat=False, + idata_kwargs={"log_likelihood": False}, + random_seed=random_seed, + ) + for chain in range(chains): + assert equal_sampling_states( + parallel_trace._sampling_state.sampling_state[chain], + sequential_trace._sampling_state.sampling_state[chain], + ) + xr.testing.assert_equal(parallel_idata.posterior, sequential_idata.posterior) From 97722ded26244e703f2db8856119d336d7c63d0d Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Thu, 7 Nov 2024 22:50:57 +0100 Subject: [PATCH 07/25] Make step method state keep track of var_names --- pymc/step_methods/compound.py | 2 ++ pymc/step_methods/state.py | 21 ++++++++++++++++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index d0393afd57..1fcb3d2673 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -22,6 +22,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence +from dataclasses import field from enum import IntEnum, unique from typing import Any @@ -96,6 +97,7 @@ def infer_warn_stats_info( @dataclass_state class StepMethodState(DataClassState): + var_names: list[str] = field(metadata={"tensor_name": True, "frozen": True}) rng: RandomGeneratorState diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py index e24276cf14..ec7bbbae48 100644 --- a/pymc/step_methods/state.py +++ b/pymc/step_methods/state.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy -from dataclasses import Field, dataclass, fields +from dataclasses import MISSING, Field, dataclass, fields from typing import Any, ClassVar import numpy as np @@ -67,7 +67,16 @@ def sampling_state(self) -> DataClassState: state_class = self._state_class kwargs = {} for field in fields(state_class): - val = getattr(self, field.name) + is_tensor_name = field.metadata.get("tensor_name", False) + val: Any + if is_tensor_name: + val = [var.name for var in getattr(self, "vars")] + else: + val = getattr(self, field.name, field.default) + if val is MISSING: + raise AttributeError( + f"{type(self).__name__!r} object has no attribute {field.name!r}" + ) _val: Any if isinstance(val, WithSamplingState): _val = val.sampling_state @@ -85,11 +94,17 @@ def sampling_state(self, state: DataClassState): state, state_class ), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'" for field in fields(state_class): + is_tensor_name = field.metadata.get("tensor_name", False) state_val = deepcopy(getattr(state, field.name)) if isinstance(state_val, RandomGeneratorState): state_val = random_generator_from_state(state_val) - self_val = getattr(self, field.name) is_frozen = field.metadata.get("frozen", False) + self_val: Any + if is_tensor_name: + self_val = [var.name for var in getattr(self, "vars")] + assert is_frozen + else: + self_val = getattr(self, field.name, field.default) if is_frozen: if not equal_dataclass_values(state_val, self_val): raise ValueError( From 9d9233eb3d03039387c080e7dd27d27f91267501 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Sat, 21 Dec 2024 11:19:28 +0100 Subject: [PATCH 08/25] Precompile fn in ZarrChain --- pymc/backends/zarr.py | 17 ++++++++++++----- pymc/sampling/parallel.py | 6 ++++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py index 1e3f4da883..e9aba5fe0d 100644 --- a/pymc/backends/zarr.py +++ b/pymc/backends/zarr.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Mapping, MutableMapping, Sequence +from collections.abc import Callable, Mapping, MutableMapping, Sequence from typing import Any import arviz as az @@ -91,10 +91,11 @@ def __init__( vars: Sequence[TensorVariable] | None = None, test_point: dict[str, np.ndarray] | None = None, draws_per_chunk: int = 1, + fn: Callable | None = None, ): if not _zarr_available: raise RuntimeError("You must install zarr to be able to create ZarrChain instances") - super().__init__(name="zarr", model=model, vars=vars, test_point=test_point) + super().__init__(name="zarr", model=model, vars=vars, test_point=test_point, fn=fn) self._step_method: BlockedStep | CompoundStep | None = None self.unconstrained_variables = { var.name for var in self.vars if is_transformed_name(var.name) @@ -168,7 +169,7 @@ def record( :meth:`~ZarrChain.flush` """ unconstrained_variables = self.unconstrained_variables - for var_name, var_value in zip(self.varnames, self.fn(draw)): + for var_name, var_value in zip(self.varnames, self.fn(**draw)): if var_name in unconstrained_variables: self.buffer(group="unconstrained_posterior", var_name=var_name, value=var_value) else: @@ -452,13 +453,18 @@ def init_trace( ) self.vars = [var for var in vars if var.name in self.varnames] - self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore") + self.fn = model.compile_fn( + self.vars, + inputs=model.value_vars, + on_unused_input="ignore", + point_fn=False, + ) # Get variable shapes. Most backends will need this # information. if test_point is None: test_point = model.initial_point() - var_values = list(zip(self.varnames, self.fn(test_point))) + var_values = list(zip(self.varnames, self.fn(**test_point))) self.var_dtype_shapes = { var: (value.dtype, value.shape) for var, value in var_values @@ -528,6 +534,7 @@ def init_trace( test_point=test_point, stats_bijection=StatsBijection(step.stats_dtypes), draws_per_chunk=self.draws_per_chunk, + fn=self.fn, ) for _ in range(chains) ] diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 794763e6e1..28e74d5e8a 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -110,8 +110,10 @@ def __init__( zarr_chains: list[ZarrChain] | bytes | None = None, zarr_chains_is_pickled: bool = False, ): - # For some strange reason, spawn multiprocessing doesn't copy the rng - # seed sequence, so we have to rebuild it from scratch + # Because of https://github.com/numpy/numpy/issues/27727, we can't send + # the rng instance to the child process because pickling (copying) looses + # the seed sequence state information. For this reason, we send a + # RandomGeneratorState instead. rng = random_generator_from_state(rng_state) self._msg_pipe = msg_pipe self._step_method = step_method From 671d7047ab580b6d94b6234450a47da701dcb51a Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Sun, 22 Dec 2024 23:26:38 +0100 Subject: [PATCH 09/25] Minor mcmc code cleanup --- pymc/sampling/mcmc.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 83a4835737..ca91325ff1 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1300,12 +1300,7 @@ def _iter_sample( ) yield diverging - except KeyboardInterrupt: - if isinstance(trace, ZarrChain): - trace.record_sampling_state(step=step) - trace.close() - raise - except BaseException: + except (KeyboardInterrupt, BaseException): if isinstance(trace, ZarrChain): trace.record_sampling_state(step=step) trace.close() From bd519d471a74d780315b04351dd045c19e7c3d41 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Tue, 14 Jan 2025 10:57:05 +0100 Subject: [PATCH 10/25] Fix conditional import of zarr --- pymc/backends/zarr.py | 37 +++++++++++++++++++++++-------------- pymc/sampling/mcmc.py | 6 +++++- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py index e9aba5fe0d..b9c1e49ea3 100644 --- a/pymc/backends/zarr.py +++ b/pymc/backends/zarr.py @@ -15,14 +15,11 @@ from typing import Any import arviz as az -import numcodecs import numpy as np import xarray as xr -import zarr from arviz.data.base import make_attrs from arviz.data.inference_data import WARMUP_TAG -from numcodecs.abc import Codec from pytensor.tensor.variable import TensorVariable import pymc @@ -44,11 +41,23 @@ from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name try: + import numcodecs + import zarr + + from numcodecs.abc import Codec + from zarr import Group from zarr.storage import BaseStore, default_compressor from zarr.sync import Synchronizer _zarr_available = True except ImportError: + from typing import TYPE_CHECKING, TypeVar + + if not TYPE_CHECKING: + Codec = TypeVar("Codec") + Group = TypeVar("Group") + BaseStore = TypeVar("BaseStore") + Synchronizer = TypeVar("Synchronizer") _zarr_available = False @@ -243,7 +252,7 @@ def flush(self): def get_initial_fill_value_and_codec( dtype: Any, -) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]: +) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, Codec | None]: _dtype = np.dtype(dtype) fill_value: FILL_VALUE_TYPE = None codec = None @@ -366,27 +375,27 @@ def groups(self) -> list[str]: return [str(group_name) for group_name, _ in self.root.groups()] @property - def posterior(self) -> zarr.Group: + def posterior(self) -> Group: return self.root.posterior @property - def unconstrained_posterior(self) -> zarr.Group: + def unconstrained_posterior(self) -> Group: return self.root.unconstrained_posterior @property - def sample_stats(self) -> zarr.Group: + def sample_stats(self) -> Group: return self.root.sample_stats @property - def constant_data(self) -> zarr.Group: + def constant_data(self) -> Group: return self.root.constant_data @property - def observed_data(self) -> zarr.Group: + def observed_data(self) -> Group: return self.root.observed_data @property - def _sampling_state(self) -> zarr.Group: + def _sampling_state(self) -> Group: return self.root._sampling_state def init_trace( @@ -646,12 +655,12 @@ def init_sampling_state_group(self, tune: int, chains: int): def init_group_with_empty( self, - group: zarr.Group, + group: Group, var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]], chains: int, draws: int, extra_var_attrs: dict | None = None, - ) -> zarr.Group: + ) -> Group: group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)} for name, (_dtype, shape) in var_dtype_and_shape.items(): fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype) @@ -689,8 +698,8 @@ def init_group_with_empty( array.attrs.update({"_ARRAY_DIMENSIONS": [dim]}) return group - def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None: - group: zarr.Group | None = None + def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> Group | None: + group: Group | None = None if data_dict: group_coords = {} group = self.root.create_group(name=name, overwrite=True) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index ca91325ff1..7cbb6df26e 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -41,7 +41,6 @@ from rich.theme import Theme from threadpoolctl import threadpool_limits from typing_extensions import Protocol -from zarr.storage import MemoryStore import pymc as pm @@ -80,6 +79,11 @@ ) from pymc.vartypes import discrete_types +try: + from zarr.storage import MemoryStore +except ImportError: + MemoryStore = type("MemoryStore", (), {}) + sys.setrecursionlimit(10000) __all__ = [ From e6767ab89dc2a49e726cc159590e3e53c93be914 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 17 Jan 2025 13:47:19 +0100 Subject: [PATCH 11/25] pre-commit update ruff 0.9.1 (#7648) Co-authored-by: Thomas Wiecki --- .pre-commit-config.yaml | 2 +- docs/source/learn/core_notebooks/pymc_pytensor.ipynb | 2 +- pymc/data.py | 2 +- pymc/distributions/continuous.py | 8 +++----- pymc/distributions/multivariate.py | 10 ++++++---- pymc/gp/cov.py | 3 +-- pymc/gp/util.py | 3 +-- pymc/sampling/jax.py | 5 ++--- pymc/sampling/mcmc.py | 6 +++--- pymc/sampling/population.py | 6 +++--- pymc/step_methods/compound.py | 6 +++--- pymc/step_methods/state.py | 6 +++--- pymc/testing.py | 6 +++--- pymc/variational/opvi.py | 8 ++++---- pymc/variational/updates.py | 2 +- tests/distributions/test_multivariate.py | 12 ++++++------ tests/gp/test_hsgp_approx.py | 12 ++++++------ tests/test_data.py | 6 +++--- 18 files changed, 51 insertions(+), 54 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 10fd36fd94..2ba656a365 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,7 +48,7 @@ repos: # - --exclude=binder/ # - --exclude=versioneer.py - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.4 + rev: v0.9.1 hooks: - id: ruff args: [--fix, --show-fixes] diff --git a/docs/source/learn/core_notebooks/pymc_pytensor.ipynb b/docs/source/learn/core_notebooks/pymc_pytensor.ipynb index aad72316a3..0260f960d1 100644 --- a/docs/source/learn/core_notebooks/pymc_pytensor.ipynb +++ b/docs/source/learn/core_notebooks/pymc_pytensor.ipynb @@ -1849,7 +1849,7 @@ "print(\n", " f\"\"\"\n", "mu_value -> {scipy.stats.norm.logpdf(x=0, loc=0, scale=2)}\n", - "sigma_log_value -> {- 10 + scipy.stats.halfnorm.logpdf(x=np.exp(-10), loc=0, scale=3)}\n", + "sigma_log_value -> {-10 + scipy.stats.halfnorm.logpdf(x=np.exp(-10), loc=0, scale=3)}\n", "x_value -> {scipy.stats.norm.logpdf(x=0, loc=0, scale=np.exp(-10))}\n", "\"\"\"\n", ")" diff --git a/pymc/data.py b/pymc/data.py index 997f0ccb3c..fd2ef8e82c 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -257,7 +257,7 @@ def determine_coords( if isinstance(value, np.ndarray) and dims is not None: if len(dims) != value.ndim: raise pm.exceptions.ShapeError( - "Invalid data shape. The rank of the dataset must match the " "length of `dims`.", + "Invalid data shape. The rank of the dataset must match the length of `dims`.", actual=value.shape, expected=value.ndim, ) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 3746f90fac..21a683ca99 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -992,8 +992,7 @@ def get_mu_lam_phi(mu, lam, phi): return mu, lam, lam / mu raise ValueError( - "Wald distribution must specify either mu only, " - "mu and lam, mu and phi, or lam and phi." + "Wald distribution must specify either mu only, mu and lam, mu and phi, or lam and phi." ) def logp(value, mu, lam, alpha): @@ -1603,8 +1602,7 @@ def dist(cls, kappa=None, mu=None, b=None, q=None, *args, **kwargs): def get_kappa(cls, kappa=None, q=None): if kappa is not None and q is not None: raise ValueError( - "Incompatible parameterization. Either use " - "kappa or q to specify the distribution." + "Incompatible parameterization. Either use kappa or q to specify the distribution." ) elif q is not None: if isinstance(q, Variable): @@ -3483,7 +3481,7 @@ def get_nu_b(cls, nu, b, sigma): elif nu is not None and b is None: b = nu / sigma return nu, b, sigma - raise ValueError("Rice distribution must specify either nu" " or b.") + raise ValueError("Rice distribution must specify either nu or b.") def support_point(rv, size, nu, sigma): nu_sigma_ratio = -(nu**2) / (2 * sigma**2) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index e44008fe65..949c592aba 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -247,7 +247,9 @@ class MvNormal(Continuous): data = np.random.multivariate_normal(mu, true_cov, 10) sd_dist = pm.Exponential.dist(1.0, shape=3) - chol, corr, stds = pm.LKJCholeskyCov("chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True) + chol, corr, stds = pm.LKJCholeskyCov( + "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True + ) vals = pm.MvNormal("vals", mu=mu, chol=chol, observed=data) For unobserved values it can be better to use a non-centered @@ -2793,9 +2795,9 @@ def dist(cls, sigma=1.0, n_zerosum_axes=None, support_shape=None, **kwargs): support_shape = pt.as_tensor(support_shape, dtype="int64", ndim=1) - assert n_zerosum_axes == pt.get_vector_length( - support_shape - ), "support_shape has to be as long as n_zerosum_axes" + assert n_zerosum_axes == pt.get_vector_length(support_shape), ( + "support_shape has to be as long as n_zerosum_axes" + ) return super().dist([sigma, support_shape], **kwargs) diff --git a/pymc/gp/cov.py b/pymc/gp/cov.py index d9f3577280..bc056be515 100644 --- a/pymc/gp/cov.py +++ b/pymc/gp/cov.py @@ -328,8 +328,7 @@ def power_spectral_density(self, omega: TensorLike) -> TensorVariable: check = Counter([isinstance(factor, Covariance) for factor in self._factor_list]) if check.get(True, 0) >= 2: raise NotImplementedError( - "The power spectral density of products of covariance " - "functions is not implemented." + "The power spectral density of products of covariance functions is not implemented." ) return reduce(mul, self._merge_factors_psd(omega)) diff --git a/pymc/gp/util.py b/pymc/gp/util.py index 3aaf85ab54..b2d0486a2b 100644 --- a/pymc/gp/util.py +++ b/pymc/gp/util.py @@ -211,8 +211,7 @@ def plot_gp_dist( samples_kwargs = {} if np.any(np.isnan(samples)): warnings.warn( - "There are `nan` entries in the [samples] arguments. " - "The plot will not contain a band!", + "There are `nan` entries in the [samples] arguments. The plot will not contain a band!", UserWarning, ) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 43e1baa87f..4f8ae2a5af 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -108,8 +108,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariabl if any(var.default_update is not None for var in shared_variables): raise ValueError( - "Graph contains shared variables with default_update which cannot " - "be safely replaced." + "Graph contains shared variables with default_update which cannot be safely replaced." ) replacements = {var: pt.constant(var.get_value(borrow=True)) for var in shared_variables} @@ -360,7 +359,7 @@ def _sample_blackjax_nuts( map_fn = jax.vmap else: raise ValueError( - "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"' + "Only supporting the following methods to draw chains: 'parallel' or 'vectorized'" ) if chains == 1: diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 7cbb6df26e..64d6829fc8 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1000,7 +1000,7 @@ def _sample_return( total_draws = draws_per_chain.sum() _log.info( - f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations ' + f"Sampling {n_chains} chain{'s' if n_chains > 1 else ''} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations " f"(Actually sampled {total_n_tune:_d} tune and {total_draws:_d} draws total) " f"took {t_sampling:.0f} seconds." ) @@ -1062,8 +1062,8 @@ def _sample_return( n_chains = len(mtrace.chains) _log.info( - f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations ' - f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) " + f"Sampling {n_chains} chain{'s' if n_chains > 1 else ''} for {n_tune:_d} tune and {n_draws:_d} draw iterations " + f"({n_tune * n_chains:_d} + {n_draws * n_chains:_d} draws total) " f"took {t_sampling:.0f} seconds." ) diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index b8a7ba593a..ab024f1e4f 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -386,9 +386,9 @@ def _prepare_iter_population( # 2. Set up the steppers steppers: list[Step] = [] - assert ( - len(rngs) == nchains - ), f"There must be one random Generator per chain. Got {len(rngs)} instead of {nchains}" + assert len(rngs) == nchains, ( + f"There must be one random Generator per chain. Got {len(rngs)} instead of {nchains}" + ) for c, rng in enumerate(rngs): # need independent samplers for each chain # it is important to copy the actual steppers (but not the delta_logp) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 1fcb3d2673..b823a00be8 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -282,9 +282,9 @@ def sampling_state(self) -> DataClassState: @sampling_state.setter def sampling_state(self, state: DataClassState): - assert isinstance( - state, self._state_class - ), f"Invalid sampling state class {type(state)}. Expected {self._state_class}" + assert isinstance(state, self._state_class), ( + f"Invalid sampling state class {type(state)}. Expected {self._state_class}" + ) for method, state_method in zip(self.methods, state.methods): method.sampling_state = state_method diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py index ec7bbbae48..db62ffda91 100644 --- a/pymc/step_methods/state.py +++ b/pymc/step_methods/state.py @@ -90,9 +90,9 @@ def sampling_state(self) -> DataClassState: @sampling_state.setter def sampling_state(self, state: DataClassState): state_class = self._state_class - assert isinstance( - state, state_class - ), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'" + assert isinstance(state, state_class), ( + f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'" + ) for field in fields(state_class): is_tensor_name = field.metadata.get("tensor_name", False) state_val = deepcopy(getattr(state, field.name)) diff --git a/pymc/testing.py b/pymc/testing.py index cc7433980c..5e0fa1ab0c 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -964,9 +964,9 @@ def check_rv_size(self): assert actual == expected_symbolic == expected def validate_tests_list(self): - assert len(self.checks_to_run) == len( - set(self.checks_to_run) - ), "There are duplicates in the list of checks_to_run" + assert len(self.checks_to_run) == len(set(self.checks_to_run)), ( + "There are duplicates in the list of checks_to_run" + ) def seeded_scipy_distribution_builder(dist_name: str) -> Callable: diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 034e2fed87..a054f51e62 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -710,9 +710,9 @@ class Group(WithMemoization): @classmethod def register(cls, sbcls): - assert ( - frozenset(sbcls.__param_spec__) not in cls.__param_registry - ), "Duplicate __param_spec__" + assert frozenset(sbcls.__param_spec__) not in cls.__param_registry, ( + "Duplicate __param_spec__" + ) cls.__param_registry[frozenset(sbcls.__param_spec__)] = sbcls assert sbcls.short_name not in cls.__name_registry, "Duplicate short_name" cls.__name_registry[sbcls.short_name] = sbcls @@ -1234,7 +1234,7 @@ def __init__(self, groups, model=None): for g in groups: if g.group is None: if rest is not None: - raise GroupError("More than one group is specified for " "the rest variables") + raise GroupError("More than one group is specified for the rest variables") else: rest = g else: diff --git a/pymc/variational/updates.py b/pymc/variational/updates.py index 234d307500..07c241beca 100644 --- a/pymc/variational/updates.py +++ b/pymc/variational/updates.py @@ -1006,7 +1006,7 @@ def norm_constraint(tensor_var, max_norm, norm_axes=None, epsilon=1e-7): elif ndim in [3, 4, 5]: # Conv{1,2,3}DLayer sum_over = tuple(range(1, ndim)) else: - raise ValueError(f"Unsupported tensor dimensionality {ndim}." "Must specify `norm_axes`") + raise ValueError(f"Unsupported tensor dimensionality {ndim}. Must specify `norm_axes`") dtype = np.dtype(pytensor.config.floatX).type norms = pt.sqrt(pt.sum(pt.sqr(tensor_var), axis=sum_over, keepdims=True)) diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index cfd50fdd71..d988718fed 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -1531,14 +1531,14 @@ class TestZeroSumNormal: def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=True): if check_zerosum_axes: for ax in axes_to_check: - assert np.isclose( - random_samples.mean(axis=ax), 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + assert np.isclose(random_samples.mean(axis=ax), 0).all(), ( + f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + ) else: for ax in axes_to_check: - assert not np.isclose( - random_samples.mean(axis=ax), 0 - ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." + assert not np.isclose(random_samples.mean(axis=ax), 0).all(), ( + f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." + ) @pytest.mark.parametrize( "dims, n_zerosum_axes", diff --git a/tests/gp/test_hsgp_approx.py b/tests/gp/test_hsgp_approx.py index db03c8b8bc..b18577cde5 100644 --- a/tests/gp/test_hsgp_approx.py +++ b/tests/gp/test_hsgp_approx.py @@ -135,9 +135,9 @@ def test_mean_invariance(self): with model: pm.set_data({"X": x_new}) - assert np.allclose( - gp._X_center, original_center - ), "gp._X_center should not change after updating data for out-of-sample predictions." + assert np.allclose(gp._X_center, original_center), ( + "gp._X_center should not change after updating data for out-of-sample predictions." + ) def test_parametrization(self): err_msg = ( @@ -188,9 +188,9 @@ def test_parametrization_drop_first(self, model, cov_func, X1, drop_first): n_coeffs = model.f1_hsgp_coeffs.type.shape[0] if drop_first: - assert ( - n_coeffs == n_basis - 1 - ), f"one basis vector should have been dropped, {n_coeffs}" + assert n_coeffs == n_basis - 1, ( + f"one basis vector should have been dropped, {n_coeffs}" + ) else: assert n_coeffs == n_basis, "one was dropped when it shouldn't have been" diff --git a/tests/test_data.py b/tests/test_data.py index 2ba66dc744..695058c87e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -318,8 +318,8 @@ def test_explicit_coords(self, seeded_test): N_cols = 7 data = np.random.uniform(size=(N_rows, N_cols)) coords = { - "rows": [f"R{r+1}" for r in range(N_rows)], - "columns": [f"C{c+1}" for c in range(N_cols)], + "rows": [f"R{r + 1}" for r in range(N_rows)], + "columns": [f"C{c + 1}" for c in range(N_cols)], } # pass coordinates explicitly, use numpy array in Data container with pm.Model(coords=coords) as pmodel: @@ -391,7 +391,7 @@ def test_implicit_coords_dataframe(self, seeded_test): N_cols = 7 df_data = pd.DataFrame() for c in range(N_cols): - df_data[f"Column {c+1}"] = np.random.normal(size=(N_rows,)) + df_data[f"Column {c + 1}"] = np.random.normal(size=(N_rows,)) df_data.index.name = "rows" df_data.columns.name = "columns" From aad5400997b3723129252d42c6f9f72966087cfc Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Wed, 15 Jan 2025 15:12:52 +0100 Subject: [PATCH 12/25] Use updated head_of_apache pre-commit step --- .pre-commit-config.yaml | 21 ++++++++++---------- benchmarks/benchmarks/__init__.py | 2 +- benchmarks/benchmarks/benchmarks.py | 2 +- pymc/__init__.py | 2 +- pymc/_version.py | 2 +- pymc/backends/__init__.py | 2 +- pymc/backends/arviz.py | 2 +- pymc/backends/base.py | 2 +- pymc/backends/mcbackend.py | 2 +- pymc/backends/ndarray.py | 2 +- pymc/backends/report.py | 2 +- pymc/backends/zarr.py | 2 +- pymc/blocking.py | 2 +- pymc/data.py | 2 +- pymc/distributions/__init__.py | 2 +- pymc/distributions/censored.py | 2 +- pymc/distributions/continuous.py | 2 +- pymc/distributions/custom.py | 2 +- pymc/distributions/discrete.py | 2 +- pymc/distributions/dist_math.py | 2 +- pymc/distributions/distribution.py | 2 +- pymc/distributions/mixture.py | 2 +- pymc/distributions/moments/__init__.py | 2 +- pymc/distributions/moments/means.py | 2 +- pymc/distributions/multivariate.py | 2 +- pymc/distributions/shape_utils.py | 2 +- pymc/distributions/simulator.py | 2 +- pymc/distributions/timeseries.py | 2 +- pymc/distributions/transforms.py | 2 +- pymc/distributions/truncated.py | 2 +- pymc/exceptions.py | 2 +- pymc/func_utils.py | 2 +- pymc/gp/__init__.py | 2 +- pymc/gp/cov.py | 2 +- pymc/gp/gp.py | 2 +- pymc/gp/hsgp_approx.py | 2 +- pymc/gp/mean.py | 2 +- pymc/gp/util.py | 2 +- pymc/initial_point.py | 2 +- pymc/logprob/__init__.py | 2 +- pymc/logprob/abstract.py | 2 +- pymc/logprob/basic.py | 2 +- pymc/logprob/binary.py | 2 +- pymc/logprob/censoring.py | 2 +- pymc/logprob/checks.py | 2 +- pymc/logprob/cumsum.py | 2 +- pymc/logprob/linalg.py | 2 +- pymc/logprob/mixture.py | 2 +- pymc/logprob/order.py | 2 +- pymc/logprob/rewriting.py | 2 +- pymc/logprob/scan.py | 2 +- pymc/logprob/tensor.py | 2 +- pymc/logprob/transform_value.py | 2 +- pymc/logprob/transforms.py | 2 +- pymc/logprob/utils.py | 2 +- pymc/math.py | 2 +- pymc/model/__init__.py | 2 +- pymc/model/core.py | 2 +- pymc/model/fgraph.py | 2 +- pymc/model/transform/__init__.py | 2 +- pymc/model/transform/basic.py | 2 +- pymc/model/transform/conditioning.py | 2 +- pymc/model/transform/optimization.py | 2 +- pymc/model_graph.py | 2 +- pymc/ode/__init__.py | 2 +- pymc/ode/ode.py | 2 +- pymc/ode/utils.py | 2 +- pymc/plots/__init__.py | 2 +- pymc/printing.py | 2 +- pymc/pytensorf.py | 2 +- pymc/sampling/__init__.py | 2 +- pymc/sampling/deterministic.py | 2 +- pymc/sampling/forward.py | 2 +- pymc/sampling/jax.py | 2 +- pymc/sampling/mcmc.py | 2 +- pymc/sampling/parallel.py | 2 +- pymc/sampling/population.py | 2 +- pymc/smc/__init__.py | 2 +- pymc/smc/kernels.py | 2 +- pymc/smc/sampling.py | 2 +- pymc/stats/__init__.py | 2 +- pymc/stats/convergence.py | 2 +- pymc/stats/log_density.py | 2 +- pymc/step_methods/__init__.py | 2 +- pymc/step_methods/arraystep.py | 2 +- pymc/step_methods/compound.py | 2 +- pymc/step_methods/hmc/__init__.py | 2 +- pymc/step_methods/hmc/base_hmc.py | 2 +- pymc/step_methods/hmc/hmc.py | 2 +- pymc/step_methods/hmc/integration.py | 2 +- pymc/step_methods/hmc/nuts.py | 2 +- pymc/step_methods/hmc/quadpotential.py | 2 +- pymc/step_methods/metropolis.py | 2 +- pymc/step_methods/slicer.py | 2 +- pymc/step_methods/state.py | 2 +- pymc/step_methods/step_sizes.py | 2 +- pymc/testing.py | 2 +- pymc/tuning/__init__.py | 2 +- pymc/tuning/scaling.py | 2 +- pymc/tuning/starting.py | 2 +- pymc/util.py | 2 +- pymc/variational/__init__.py | 2 +- pymc/variational/approximations.py | 2 +- pymc/variational/callbacks.py | 2 +- pymc/variational/inference.py | 2 +- pymc/variational/minibatch_rv.py | 2 +- pymc/variational/operators.py | 2 +- pymc/variational/opvi.py | 2 +- pymc/variational/stein.py | 2 +- pymc/variational/test_functions.py | 2 +- pymc/variational/updates.py | 2 +- pymc/vartypes.py | 2 +- setup.py | 2 +- setupegg.py | 2 +- tests/__init__.py | 2 +- tests/backends/__init__.py | 2 +- tests/backends/fixtures.py | 2 +- tests/backends/test_arviz.py | 2 +- tests/backends/test_base.py | 2 +- tests/backends/test_mcbackend.py | 2 +- tests/backends/test_ndarray.py | 2 +- tests/backends/test_zarr.py | 2 +- tests/conftest.py | 2 +- tests/distributions/__init__.py | 2 +- tests/distributions/moments/__init__.py | 2 +- tests/distributions/moments/test_means.py | 2 +- tests/distributions/test_censored.py | 2 +- tests/distributions/test_continuous.py | 2 +- tests/distributions/test_custom.py | 2 +- tests/distributions/test_discrete.py | 2 +- tests/distributions/test_dist_math.py | 2 +- tests/distributions/test_distribution.py | 2 +- tests/distributions/test_mixture.py | 2 +- tests/distributions/test_multivariate.py | 2 +- tests/distributions/test_shape_utils.py | 2 +- tests/distributions/test_simulator.py | 2 +- tests/distributions/test_timeseries.py | 2 +- tests/distributions/test_transform.py | 2 +- tests/distributions/test_truncated.py | 2 +- tests/gp/__init__.py | 2 +- tests/gp/test_cov.py | 2 +- tests/gp/test_gp.py | 2 +- tests/gp/test_hsgp_approx.py | 2 +- tests/gp/test_mean.py | 2 +- tests/gp/test_util.py | 2 +- tests/helpers.py | 2 +- tests/logprob/__init__.py | 2 +- tests/logprob/test_abstract.py | 2 +- tests/logprob/test_basic.py | 2 +- tests/logprob/test_binary.py | 2 +- tests/logprob/test_censoring.py | 2 +- tests/logprob/test_checks.py | 2 +- tests/logprob/test_composite_logprob.py | 2 +- tests/logprob/test_cumsum.py | 2 +- tests/logprob/test_linalg.py | 2 +- tests/logprob/test_mixture.py | 2 +- tests/logprob/test_order.py | 2 +- tests/logprob/test_rewriting.py | 2 +- tests/logprob/test_scan.py | 2 +- tests/logprob/test_tensor.py | 2 +- tests/logprob/test_transform_value.py | 2 +- tests/logprob/test_transforms.py | 2 +- tests/logprob/test_utils.py | 2 +- tests/logprob/utils.py | 2 +- tests/model/__init__.py | 2 +- tests/model/test_core.py | 2 +- tests/model/test_fgraph.py | 2 +- tests/model/transform/__init__.py | 2 +- tests/model/transform/test_basic.py | 2 +- tests/model/transform/test_conditioning.py | 2 +- tests/model/transform/test_optimization.py | 2 +- tests/models.py | 2 +- tests/ode/__init__.py | 2 +- tests/ode/test_ode.py | 2 +- tests/ode/test_utils.py | 2 +- tests/sampler_fixtures.py | 2 +- tests/sampling/__init__.py | 2 +- tests/sampling/test_deterministic.py | 2 +- tests/sampling/test_forward.py | 2 +- tests/sampling/test_jax.py | 2 +- tests/sampling/test_mcmc.py | 2 +- tests/sampling/test_mcmc_external.py | 2 +- tests/sampling/test_parallel.py | 2 +- tests/sampling/test_population.py | 2 +- tests/smc/__init__.py | 2 +- tests/smc/test_smc.py | 2 +- tests/stats/__init__.py | 2 +- tests/stats/test_convergence.py | 2 +- tests/stats/test_log_density.py | 2 +- tests/step_methods/__init__.py | 2 +- tests/step_methods/hmc/__init__.py | 2 +- tests/step_methods/hmc/test_hmc.py | 2 +- tests/step_methods/hmc/test_nuts.py | 2 +- tests/step_methods/hmc/test_quadpotential.py | 2 +- tests/step_methods/test_compound.py | 2 +- tests/step_methods/test_metropolis.py | 2 +- tests/step_methods/test_slicer.py | 2 +- tests/step_methods/test_state.py | 2 +- tests/test_data.py | 2 +- tests/test_func_utils.py | 2 +- tests/test_initial_point.py | 2 +- tests/test_math.py | 2 +- tests/test_model_graph.py | 2 +- tests/test_printing.py | 2 +- tests/test_pytensorf.py | 2 +- tests/test_testing.py | 2 +- tests/test_util.py | 2 +- tests/tuning/__init__.py | 2 +- tests/tuning/test_scaling.py | 2 +- tests/tuning/test_starting.py | 2 +- tests/variational/__init__.py | 2 +- tests/variational/test_approximations.py | 2 +- tests/variational/test_callbacks.py | 2 +- tests/variational/test_inference.py | 2 +- tests/variational/test_minibatch_rv.py | 2 +- tests/variational/test_opvi.py | 2 +- tests/variational/test_updates.py | 2 +- 217 files changed, 227 insertions(+), 226 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2ba656a365..7a3b1a7df6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,16 +37,17 @@ repos: hooks: - id: sphinx-lint args: ["."] -#- repo: https://github.com/lucianopaz/head_of_apache -# rev: "0.0.3" -# hooks: -# - id: head_of_apache -# args: -# - --author=The PyMC Developers -# - --exclude=docs/ -# - --exclude=scripts/ -# - --exclude=binder/ -# - --exclude=versioneer.py +- repo: https://github.com/lucianopaz/head_of_apache + rev: "0.1.0" + hooks: + - id: head_of_apache + args: + - --author=The PyMC Developers + - --exclude=docs/ + - --exclude=scripts/ + - --exclude=binder/ + - --exclude=versioneer.py + - --last-year-present - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.9.1 hooks: diff --git a/benchmarks/benchmarks/__init__.py b/benchmarks/benchmarks/__init__.py index 1217c81ed2..7443280d34 100644 --- a/benchmarks/benchmarks/__init__.py +++ b/benchmarks/benchmarks/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/benchmarks/benchmarks/benchmarks.py b/benchmarks/benchmarks/benchmarks.py index 7485cef65e..381dd7ee5e 100644 --- a/benchmarks/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks/benchmarks.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/__init__.py b/pymc/__init__.py index a828b72827..684feac11f 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/_version.py b/pymc/_version.py index 2f7f80bfad..4de3063d04 100644 --- a/pymc/_version.py +++ b/pymc/_version.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index 882412ce2d..d3f7620882 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index d1c27b787b..f0f0eec963 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/backends/base.py b/pymc/backends/base.py index 5a2a043a39..993acc0df4 100644 --- a/pymc/backends/base.py +++ b/pymc/backends/base.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py index 3d2c8fd9e7..9bc8ff1043 100644 --- a/pymc/backends/mcbackend.py +++ b/pymc/backends/mcbackend.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index 70ca60879c..a08fc8f47e 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/backends/report.py b/pymc/backends/report.py index 9a630ee242..d2ab860bd3 100644 --- a/pymc/backends/report.py +++ b/pymc/backends/report.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py index b9c1e49ea3..9b7664c504 100644 --- a/pymc/backends/zarr.py +++ b/pymc/backends/zarr.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/blocking.py b/pymc/blocking.py index 2aad656128..9f9e27cebf 100644 --- a/pymc/blocking.py +++ b/pymc/blocking.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/data.py b/pymc/data.py index fd2ef8e82c..9373eb5775 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py index 442ebddc71..c578267091 100644 --- a/pymc/distributions/__init__.py +++ b/pymc/distributions/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py index 4be21b1c9d..77c52023b9 100644 --- a/pymc/distributions/censored.py +++ b/pymc/distributions/censored.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 21a683ca99..082be31d5c 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/custom.py b/pymc/distributions/custom.py index 3238680bb3..86aba12043 100644 --- a/pymc/distributions/custom.py +++ b/pymc/distributions/custom.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 979f81dba0..d2f35c8007 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/dist_math.py b/pymc/distributions/dist_math.py index 1cdb3b2945..3f675406f4 100644 --- a/pymc/distributions/dist_math.py +++ b/pymc/distributions/dist_math.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 0d1c58cf17..b2ec6fb79b 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index dc704e5121..303f2793db 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/moments/__init__.py b/pymc/distributions/moments/__init__.py index 8aafdb37a2..b61e8b6400 100644 --- a/pymc/distributions/moments/__init__.py +++ b/pymc/distributions/moments/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/moments/means.py b/pymc/distributions/moments/means.py index f025733726..f183ace5db 100644 --- a/pymc/distributions/moments/means.py +++ b/pymc/distributions/moments/means.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 949c592aba..32f9e30f06 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index f2b21763c0..7dd0d94414 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/simulator.py b/pymc/distributions/simulator.py index aeacf8346a..9e7fcd08d7 100644 --- a/pymc/distributions/simulator.py +++ b/pymc/distributions/simulator.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 3ec863d7a5..fb5e5e7420 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index fe036c2bc4..c8ca8d0554 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py index 36b4395263..4b984e4c41 100644 --- a/pymc/distributions/truncated.py +++ b/pymc/distributions/truncated.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/exceptions.py b/pymc/exceptions.py index 652c2ae5a4..b07c0c2887 100644 --- a/pymc/exceptions.py +++ b/pymc/exceptions.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/func_utils.py b/pymc/func_utils.py index 72dc3b1a96..e7acb3dff0 100644 --- a/pymc/func_utils.py +++ b/pymc/func_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/gp/__init__.py b/pymc/gp/__init__.py index 15a49efeb6..4ddf1c9e42 100644 --- a/pymc/gp/__init__.py +++ b/pymc/gp/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/gp/cov.py b/pymc/gp/cov.py index bc056be515..16e51cdde9 100644 --- a/pymc/gp/cov.py +++ b/pymc/gp/cov.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/gp/gp.py b/pymc/gp/gp.py index 3a4b453839..c31a6a613e 100644 --- a/pymc/gp/gp.py +++ b/pymc/gp/gp.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/gp/hsgp_approx.py b/pymc/gp/hsgp_approx.py index e331aab8f5..9bad17ce87 100644 --- a/pymc/gp/hsgp_approx.py +++ b/pymc/gp/hsgp_approx.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/gp/mean.py b/pymc/gp/mean.py index 827b5db6e4..01871afa8e 100644 --- a/pymc/gp/mean.py +++ b/pymc/gp/mean.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/gp/util.py b/pymc/gp/util.py index b2d0486a2b..da713b886b 100644 --- a/pymc/gp/util.py +++ b/pymc/gp/util.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/initial_point.py b/pymc/initial_point.py index 241409f683..ba3a0ea85c 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/__init__.py b/pymc/logprob/__init__.py index 4dea34312f..2e67a6c55b 100644 --- a/pymc/logprob/__init__.py +++ b/pymc/logprob/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index 281b4fb184..5c7f28e661 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index 7753678d2e..6fd4a5489e 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index 9d0985a2cf..27449e2d2c 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index e17d30a43d..411b8162a8 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py index c8c21ef61c..9ac12a1e10 100644 --- a/pymc/logprob/checks.py +++ b/pymc/logprob/checks.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/cumsum.py b/pymc/logprob/cumsum.py index 4fd5a6eaeb..9dd92611ad 100644 --- a/pymc/logprob/cumsum.py +++ b/pymc/logprob/cumsum.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/linalg.py b/pymc/logprob/linalg.py index 226b24a07d..bd5ac33261 100644 --- a/pymc/logprob/linalg.py +++ b/pymc/logprob/linalg.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 1ebb29638e..ce6a11d208 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 6eceb819dd..30f6a32565 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py index 76baf31dfa..b5a6b23a09 100644 --- a/pymc/logprob/rewriting.py +++ b/pymc/logprob/rewriting.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index 8f6942458e..27d562523a 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index abb5df2ab5..3d577aac9a 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/transform_value.py b/pymc/logprob/transform_value.py index 1b5d4cd817..4a28d5cd4a 100644 --- a/pymc/logprob/transform_value.py +++ b/pymc/logprob/transform_value.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 930bf1f4eb..8446874996 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index 9865226e42..e1fdc903ee 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/math.py b/pymc/math.py index 2f7527e113..1845dd5111 100644 --- a/pymc/math.py +++ b/pymc/math.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/model/__init__.py b/pymc/model/__init__.py index 4caa701378..f824e50b4f 100644 --- a/pymc/model/__init__.py +++ b/pymc/model/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/model/core.py b/pymc/model/core.py index 99711e566e..cef6cd6d1b 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py index fa5e34a7d1..5dc47fe0ee 100644 --- a/pymc/model/fgraph.py +++ b/pymc/model/fgraph.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/model/transform/__init__.py b/pymc/model/transform/__init__.py index 008e6f8ff0..6c6610b707 100644 --- a/pymc/model/transform/__init__.py +++ b/pymc/model/transform/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index 3d756785a5..fcf42fdf8c 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/model/transform/conditioning.py b/pymc/model/transform/conditioning.py index 23e0175503..4d9f0553a1 100644 --- a/pymc/model/transform/conditioning.py +++ b/pymc/model/transform/conditioning.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/model/transform/optimization.py b/pymc/model/transform/optimization.py index 187e4ee444..9848d89d7f 100644 --- a/pymc/model/transform/optimization.py +++ b/pymc/model/transform/optimization.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/model_graph.py b/pymc/model_graph.py index b3b9847727..f31a33770d 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/ode/__init__.py b/pymc/ode/__init__.py index d10034e3f4..e91751ec8b 100644 --- a/pymc/ode/__init__.py +++ b/pymc/ode/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/ode/ode.py b/pymc/ode/ode.py index ca01af13b6..7c7e858ad0 100644 --- a/pymc/ode/ode.py +++ b/pymc/ode/ode.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/ode/utils.py b/pymc/ode/utils.py index 3ad05b1e14..1ba99feab0 100644 --- a/pymc/ode/utils.py +++ b/pymc/ode/utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/plots/__init__.py b/pymc/plots/__init__.py index cc938faa94..49068d5369 100644 --- a/pymc/plots/__init__.py +++ b/pymc/plots/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/printing.py b/pymc/printing.py index 946a8a213b..c4376d306e 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index f665d5931c..1f390b1771 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/sampling/__init__.py b/pymc/sampling/__init__.py index bb5206ecc8..362f7a8327 100644 --- a/pymc/sampling/__init__.py +++ b/pymc/sampling/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/sampling/deterministic.py b/pymc/sampling/deterministic.py index 3d8398c3a7..85ce0a1164 100644 --- a/pymc/sampling/deterministic.py +++ b/pymc/sampling/deterministic.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index c07683555a..0abe0d66b3 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 4f8ae2a5af..03413667df 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 64d6829fc8..758cb86448 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 28e74d5e8a..3c2a8c9a36 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index ab024f1e4f..92de63d0c2 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/smc/__init__.py b/pymc/smc/__init__.py index 4d6f90eab3..d3a0803305 100644 --- a/pymc/smc/__init__.py +++ b/pymc/smc/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/smc/kernels.py b/pymc/smc/kernels.py index db1b0cf5bb..a5c86b5609 100644 --- a/pymc/smc/kernels.py +++ b/pymc/smc/kernels.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index b1ae52e030..a4e8248814 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/stats/__init__.py b/pymc/stats/__init__.py index 4b94e3e064..1ded55a5e2 100644 --- a/pymc/stats/__init__.py +++ b/pymc/stats/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/stats/convergence.py b/pymc/stats/convergence.py index eee6677825..d32831c8be 100644 --- a/pymc/stats/convergence.py +++ b/pymc/stats/convergence.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/stats/log_density.py b/pymc/stats/log_density.py index 266ceaac1f..c4b1048d9c 100644 --- a/pymc/stats/log_density.py +++ b/pymc/stats/log_density.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/step_methods/__init__.py b/pymc/step_methods/__init__.py index 47fabc10dd..733eed5ed6 100644 --- a/pymc/step_methods/__init__.py +++ b/pymc/step_methods/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py index d15b14499c..0c20e09a47 100644 --- a/pymc/step_methods/arraystep.py +++ b/pymc/step_methods/arraystep.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index b823a00be8..ff3f9c66a5 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/step_methods/hmc/__init__.py b/pymc/step_methods/hmc/__init__.py index 8ec9f91ace..e51cef7784 100644 --- a/pymc/step_methods/hmc/__init__.py +++ b/pymc/step_methods/hmc/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index 564daebed4..e8c96e8c4b 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index a5ebbd7a8c..565c1fd78b 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/step_methods/hmc/integration.py b/pymc/step_methods/hmc/integration.py index 067cd239f8..4eb7a15d8f 100644 --- a/pymc/step_methods/hmc/integration.py +++ b/pymc/step_methods/hmc/integration.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 770605f4b7..bbda728e80 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/step_methods/hmc/quadpotential.py b/pymc/step_methods/hmc/quadpotential.py index 2c1b500cc6..dd7ad6922b 100644 --- a/pymc/step_methods/hmc/quadpotential.py +++ b/pymc/step_methods/hmc/quadpotential.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 64455c8930..8e22218a13 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 73574c025b..ecc7967614 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py index db62ffda91..98e177aa03 100644 --- a/pymc/step_methods/state.py +++ b/pymc/step_methods/state.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/step_methods/step_sizes.py b/pymc/step_methods/step_sizes.py index c0fdb934a3..6cfd79c19a 100644 --- a/pymc/step_methods/step_sizes.py +++ b/pymc/step_methods/step_sizes.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/testing.py b/pymc/testing.py index 5e0fa1ab0c..7ef6751892 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/tuning/__init__.py b/pymc/tuning/__init__.py index f2920849b9..ac6f66a5ce 100644 --- a/pymc/tuning/__init__.py +++ b/pymc/tuning/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/tuning/scaling.py b/pymc/tuning/scaling.py index d07f8c8645..3fbbff61d7 100644 --- a/pymc/tuning/scaling.py +++ b/pymc/tuning/scaling.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 22d3ffb415..2fbbba6339 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/util.py b/pymc/util.py index 63576676eb..8dc7d16804 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/variational/__init__.py b/pymc/variational/__init__.py index 785fb11cb0..23fd0212a3 100644 --- a/pymc/variational/__init__.py +++ b/pymc/variational/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 6fad3c10b1..29b7093108 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/variational/callbacks.py b/pymc/variational/callbacks.py index caf6c89cce..8beba7bd92 100644 --- a/pymc/variational/callbacks.py +++ b/pymc/variational/callbacks.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index 29800e0541..d9da7fb786 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/variational/minibatch_rv.py b/pymc/variational/minibatch_rv.py index f9227f8131..163cec4727 100644 --- a/pymc/variational/minibatch_rv.py +++ b/pymc/variational/minibatch_rv.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/variational/operators.py b/pymc/variational/operators.py index fc1226be1f..502fe13ab9 100644 --- a/pymc/variational/operators.py +++ b/pymc/variational/operators.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index a054f51e62..f5ef6a4205 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/variational/stein.py b/pymc/variational/stein.py index bf3a41ca0f..0534bb6fa4 100644 --- a/pymc/variational/stein.py +++ b/pymc/variational/stein.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/variational/test_functions.py b/pymc/variational/test_functions.py index 26ad061931..f8e80f5b86 100644 --- a/pymc/variational/test_functions.py +++ b/pymc/variational/test_functions.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/variational/updates.py b/pymc/variational/updates.py index 07c241beca..fa48a3b3f3 100644 --- a/pymc/variational/updates.py +++ b/pymc/variational/updates.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc/vartypes.py b/pymc/vartypes.py index 2f145aa9b7..955c86a4a4 100644 --- a/pymc/vartypes.py +++ b/pymc/vartypes.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/setup.py b/setup.py index 8482d00d19..99bcadd86a 100755 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/setupegg.py b/setupegg.py index c263f95845..135a9344ff 100755 --- a/setupegg.py +++ b/setupegg.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/__init__.py b/tests/__init__.py index 997d5084cb..9bc5330969 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/backends/__init__.py b/tests/backends/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/backends/__init__.py +++ b/tests/backends/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/backends/fixtures.py b/tests/backends/fixtures.py index c7a3bdcec3..a4f28a1262 100644 --- a/tests/backends/fixtures.py +++ b/tests/backends/fixtures.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py index 18599738ae..3c06288b35 100644 --- a/tests/backends/test_arviz.py +++ b/tests/backends/test_arviz.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/backends/test_base.py b/tests/backends/test_base.py index 3dd6c6e644..0f450119a7 100644 --- a/tests/backends/test_base.py +++ b/tests/backends/test_base.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/backends/test_mcbackend.py b/tests/backends/test_mcbackend.py index 23240af377..e72731af6b 100644 --- a/tests/backends/test_mcbackend.py +++ b/tests/backends/test_mcbackend.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/backends/test_ndarray.py b/tests/backends/test_ndarray.py index f1050ef0f1..a74e9f3ad7 100644 --- a/tests/backends/test_ndarray.py +++ b/tests/backends/test_ndarray.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/backends/test_zarr.py b/tests/backends/test_zarr.py index 32f508ef1a..af9c9e0a06 100644 --- a/tests/backends/test_zarr.py +++ b/tests/backends/test_zarr.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/conftest.py b/tests/conftest.py index 306102e3bb..24175ee0db 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/__init__.py b/tests/distributions/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/distributions/__init__.py +++ b/tests/distributions/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/moments/__init__.py b/tests/distributions/moments/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/distributions/moments/__init__.py +++ b/tests/distributions/moments/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/moments/test_means.py b/tests/distributions/moments/test_means.py index f3a9ebe73c..abfa9ee376 100644 --- a/tests/distributions/moments/test_means.py +++ b/tests/distributions/moments/test_means.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/test_censored.py b/tests/distributions/test_censored.py index 9ce836cfc8..6e8b0f9dcd 100644 --- a/tests/distributions/test_censored.py +++ b/tests/distributions/test_censored.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 2864335e34..cfdd0b3d60 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/test_custom.py b/tests/distributions/test_custom.py index 5f2cef54f0..d3de7cf4f7 100644 --- a/tests/distributions/test_custom.py +++ b/tests/distributions/test_custom.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/test_discrete.py b/tests/distributions/test_discrete.py index e9be2ceded..24eeb504c9 100644 --- a/tests/distributions/test_discrete.py +++ b/tests/distributions/test_discrete.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/test_dist_math.py b/tests/distributions/test_dist_math.py index be98fccb3d..39b9cfdd04 100644 --- a/tests/distributions/test_dist_math.py +++ b/tests/distributions/test_dist_math.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index cd45b54d49..df97905073 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/test_mixture.py b/tests/distributions/test_mixture.py index 0e247e5e56..7fd00bcb5a 100644 --- a/tests/distributions/test_mixture.py +++ b/tests/distributions/test_mixture.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index d988718fed..39b6c562e1 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/test_shape_utils.py b/tests/distributions/test_shape_utils.py index f381d6db48..8579bfd8e1 100644 --- a/tests/distributions/test_shape_utils.py +++ b/tests/distributions/test_shape_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/test_simulator.py b/tests/distributions/test_simulator.py index 8732613864..29f2f3c229 100644 --- a/tests/distributions/test_simulator.py +++ b/tests/distributions/test_simulator.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/test_timeseries.py b/tests/distributions/test_timeseries.py index 580b783d04..197296a5a4 100644 --- a/tests/distributions/test_timeseries.py +++ b/tests/distributions/test_timeseries.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index f1d71504ce..12d9b438b5 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/distributions/test_truncated.py b/tests/distributions/test_truncated.py index 5ef28791d9..0be89abd8f 100644 --- a/tests/distributions/test_truncated.py +++ b/tests/distributions/test_truncated.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/gp/__init__.py b/tests/gp/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/gp/__init__.py +++ b/tests/gp/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/gp/test_cov.py b/tests/gp/test_cov.py index 5a0d962747..9334d05831 100644 --- a/tests/gp/test_cov.py +++ b/tests/gp/test_cov.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/gp/test_gp.py b/tests/gp/test_gp.py index 3d620610fd..1b86103d3b 100644 --- a/tests/gp/test_gp.py +++ b/tests/gp/test_gp.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/gp/test_hsgp_approx.py b/tests/gp/test_hsgp_approx.py index b18577cde5..84ad396b1c 100644 --- a/tests/gp/test_hsgp_approx.py +++ b/tests/gp/test_hsgp_approx.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/gp/test_mean.py b/tests/gp/test_mean.py index af7943397e..83422b0a72 100644 --- a/tests/gp/test_mean.py +++ b/tests/gp/test_mean.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/gp/test_util.py b/tests/gp/test_util.py index 5abbdd6fd5..561535c266 100644 --- a/tests/gp/test_util.py +++ b/tests/gp/test_util.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/helpers.py b/tests/helpers.py index e4b6248930..d2638e7427 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/__init__.py b/tests/logprob/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/logprob/__init__.py +++ b/tests/logprob/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_abstract.py b/tests/logprob/test_abstract.py index 3976066e60..5d8024cdca 100644 --- a/tests/logprob/test_abstract.py +++ b/tests/logprob/test_abstract.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_basic.py b/tests/logprob/test_basic.py index cfbd70b504..64cbf63b3e 100644 --- a/tests/logprob/test_basic.py +++ b/tests/logprob/test_basic.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py index e56a248741..b8069517e3 100644 --- a/tests/logprob/test_binary.py +++ b/tests/logprob/test_binary.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index ccbbb38bc2..c778f7a9b4 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_checks.py b/tests/logprob/test_checks.py index db60c573e1..3a13f2a52a 100644 --- a/tests/logprob/test_checks.py +++ b/tests/logprob/test_checks.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_composite_logprob.py b/tests/logprob/test_composite_logprob.py index b249a167fe..7155904e45 100644 --- a/tests/logprob/test_composite_logprob.py +++ b/tests/logprob/test_composite_logprob.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_cumsum.py b/tests/logprob/test_cumsum.py index 552cea92d0..a1b7b8d900 100644 --- a/tests/logprob/test_cumsum.py +++ b/tests/logprob/test_cumsum.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_linalg.py b/tests/logprob/test_linalg.py index 047a0312b9..ca1741f11c 100644 --- a/tests/logprob/test_linalg.py +++ b/tests/logprob/test_linalg.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index 61a78bf4db..ffb2bf07c0 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index e08bbf4571..1b1fbac636 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_rewriting.py b/tests/logprob/test_rewriting.py index 5f1fe55586..e754a7d3f3 100644 --- a/tests/logprob/test_rewriting.py +++ b/tests/logprob/test_rewriting.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py index 17fb198ca2..6190808da7 100644 --- a/tests/logprob/test_scan.py +++ b/tests/logprob/test_scan.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py index e118ed69f2..df5c7052f8 100644 --- a/tests/logprob/test_tensor.py +++ b/tests/logprob/test_tensor.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_transform_value.py b/tests/logprob/test_transform_value.py index 491a38086c..19dcf840e9 100644 --- a/tests/logprob/test_transform_value.py +++ b/tests/logprob/test_transform_value.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 17fe096e92..2ab30235bd 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/test_utils.py b/tests/logprob/test_utils.py index a982076db7..e47bd58142 100644 --- a/tests/logprob/test_utils.py +++ b/tests/logprob/test_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/logprob/utils.py b/tests/logprob/utils.py index e5aa36b830..a9f2ae92c7 100644 --- a/tests/logprob/utils.py +++ b/tests/logprob/utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/model/__init__.py b/tests/model/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/model/__init__.py +++ b/tests/model/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/model/test_core.py b/tests/model/test_core.py index 2d3786637f..9b9c673f30 100644 --- a/tests/model/test_core.py +++ b/tests/model/test_core.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/model/test_fgraph.py b/tests/model/test_fgraph.py index 9a65be36b7..178eb39683 100644 --- a/tests/model/test_fgraph.py +++ b/tests/model/test_fgraph.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/model/transform/__init__.py b/tests/model/transform/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/model/transform/__init__.py +++ b/tests/model/transform/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/model/transform/test_basic.py b/tests/model/transform/test_basic.py index b62edaafc6..25bf2324ec 100644 --- a/tests/model/transform/test_basic.py +++ b/tests/model/transform/test_basic.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/model/transform/test_conditioning.py b/tests/model/transform/test_conditioning.py index 2aba88b99d..c3dee552b8 100644 --- a/tests/model/transform/test_conditioning.py +++ b/tests/model/transform/test_conditioning.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/model/transform/test_optimization.py b/tests/model/transform/test_optimization.py index 9b697f6305..428bc868ce 100644 --- a/tests/model/transform/test_optimization.py +++ b/tests/model/transform/test_optimization.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/models.py b/tests/models.py index abf461fa90..fd3c52aa65 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/ode/__init__.py b/tests/ode/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/ode/__init__.py +++ b/tests/ode/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/ode/test_ode.py b/tests/ode/test_ode.py index 42f4d35fb7..affe5380fd 100644 --- a/tests/ode/test_ode.py +++ b/tests/ode/test_ode.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/ode/test_utils.py b/tests/ode/test_utils.py index d6edafbd43..4b938b1202 100644 --- a/tests/ode/test_utils.py +++ b/tests/ode/test_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/sampler_fixtures.py b/tests/sampler_fixtures.py index bd3269364c..d6fb6255b9 100644 --- a/tests/sampler_fixtures.py +++ b/tests/sampler_fixtures.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/sampling/__init__.py b/tests/sampling/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/sampling/__init__.py +++ b/tests/sampling/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/sampling/test_deterministic.py b/tests/sampling/test_deterministic.py index f42e1d7eba..dcfb548cf8 100644 --- a/tests/sampling/test_deterministic.py +++ b/tests/sampling/test_deterministic.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 404f74a961..bee0d2f792 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index d6a8d1021b..ddec60e539 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 41b068e042..409ab6ff86 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index 3305d018f1..4ab3ed5e87 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/sampling/test_parallel.py b/tests/sampling/test_parallel.py index c16489610f..795c08ca0c 100644 --- a/tests/sampling/test_parallel.py +++ b/tests/sampling/test_parallel.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/sampling/test_population.py b/tests/sampling/test_population.py index 4e3d91bcbb..f9ea506d22 100644 --- a/tests/sampling/test_population.py +++ b/tests/sampling/test_population.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/smc/__init__.py b/tests/smc/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/smc/__init__.py +++ b/tests/smc/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 09ba48dc7d..493c0c8daa 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/stats/__init__.py b/tests/stats/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/stats/__init__.py +++ b/tests/stats/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/stats/test_convergence.py b/tests/stats/test_convergence.py index f468fc5e5b..1f7ba44791 100644 --- a/tests/stats/test_convergence.py +++ b/tests/stats/test_convergence.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/stats/test_log_density.py b/tests/stats/test_log_density.py index 00ee5d4995..7b2eb3774e 100644 --- a/tests/stats/test_log_density.py +++ b/tests/stats/test_log_density.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/step_methods/__init__.py b/tests/step_methods/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/step_methods/__init__.py +++ b/tests/step_methods/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/step_methods/hmc/__init__.py b/tests/step_methods/hmc/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/step_methods/hmc/__init__.py +++ b/tests/step_methods/hmc/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/step_methods/hmc/test_hmc.py b/tests/step_methods/hmc/test_hmc.py index d228820328..774bb40927 100644 --- a/tests/step_methods/hmc/test_hmc.py +++ b/tests/step_methods/hmc/test_hmc.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/step_methods/hmc/test_nuts.py b/tests/step_methods/hmc/test_nuts.py index d37782fb78..432418a33a 100644 --- a/tests/step_methods/hmc/test_nuts.py +++ b/tests/step_methods/hmc/test_nuts.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/step_methods/hmc/test_quadpotential.py b/tests/step_methods/hmc/test_quadpotential.py index cafb788ccf..e9cf3fa5d6 100644 --- a/tests/step_methods/hmc/test_quadpotential.py +++ b/tests/step_methods/hmc/test_quadpotential.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/step_methods/test_compound.py b/tests/step_methods/test_compound.py index 556deffe32..6c8957f9b3 100644 --- a/tests/step_methods/test_compound.py +++ b/tests/step_methods/test_compound.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/step_methods/test_metropolis.py b/tests/step_methods/test_metropolis.py index 0a81797b3c..234dabb5a4 100644 --- a/tests/step_methods/test_metropolis.py +++ b/tests/step_methods/test_metropolis.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/step_methods/test_slicer.py b/tests/step_methods/test_slicer.py index 899d4ec9ec..4f1ccb1110 100644 --- a/tests/step_methods/test_slicer.py +++ b/tests/step_methods/test_slicer.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/step_methods/test_state.py b/tests/step_methods/test_state.py index e6a39264db..dd351bb555 100644 --- a/tests/step_methods/test_state.py +++ b/tests/step_methods/test_state.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_data.py b/tests/test_data.py index 695058c87e..5d370d02c0 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_func_utils.py b/tests/test_func_utils.py index ff4de87096..e0815c86c7 100644 --- a/tests/test_func_utils.py +++ b/tests/test_func_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_initial_point.py b/tests/test_initial_point.py index 9138f37b3e..bda39bd1f6 100644 --- a/tests/test_initial_point.py +++ b/tests/test_initial_point.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_math.py b/tests/test_math.py index 40c3b70db5..3f811fc2b7 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index 866253f4e7..70c10d4106 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_printing.py b/tests/test_printing.py index 832699c20d..917a2d1ee2 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index f8353ce9cc..0ea18dabe3 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_testing.py b/tests/test_testing.py index cdaf5b41e6..c8caf063c2 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_util.py b/tests/test_util.py index 8771bb0515..98cc168f0e 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/tuning/__init__.py b/tests/tuning/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/tuning/__init__.py +++ b/tests/tuning/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/tuning/test_scaling.py b/tests/tuning/test_scaling.py index d8f5428664..d3e0607149 100644 --- a/tests/tuning/test_scaling.py +++ b/tests/tuning/test_scaling.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/tuning/test_starting.py b/tests/tuning/test_starting.py index 8fd41a0522..0b19aa57ff 100644 --- a/tests/tuning/test_starting.py +++ b/tests/tuning/test_starting.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/variational/__init__.py b/tests/variational/__init__.py index ae0da7db23..1cfcbd166f 100644 --- a/tests/variational/__init__.py +++ b/tests/variational/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/variational/test_approximations.py b/tests/variational/test_approximations.py index 7f088766fe..ab30e9bbe3 100644 --- a/tests/variational/test_approximations.py +++ b/tests/variational/test_approximations.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/variational/test_callbacks.py b/tests/variational/test_callbacks.py index 2c314ee296..dd8c66737f 100644 --- a/tests/variational/test_callbacks.py +++ b/tests/variational/test_callbacks.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/variational/test_inference.py b/tests/variational/test_inference.py index 5fb4237e9a..10b824179d 100644 --- a/tests/variational/test_inference.py +++ b/tests/variational/test_inference.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/variational/test_minibatch_rv.py b/tests/variational/test_minibatch_rv.py index 6f3e715af7..84d118c581 100644 --- a/tests/variational/test_minibatch_rv.py +++ b/tests/variational/test_minibatch_rv.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/variational/test_opvi.py b/tests/variational/test_opvi.py index 43ba772216..d692b30014 100644 --- a/tests/variational/test_opvi.py +++ b/tests/variational/test_opvi.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/variational/test_updates.py b/tests/variational/test_updates.py index 9f591675bb..1ddd78b5b6 100644 --- a/tests/variational/test_updates.py +++ b/tests/variational/test_updates.py @@ -1,4 +1,4 @@ -# Copyright 2024 The PyMC Developers +# Copyright 2024 - present The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 5d6e5601d101254d116e273e7f16027e4a6f89a3 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sun, 22 Dec 2024 15:37:20 +0100 Subject: [PATCH 13/25] Don't persist credentials This is an insecure default on GitHub that increases the chances of credential leakage. --- .github/workflows/devcontainer-docker-image.yml | 2 ++ .github/workflows/docker-image.yml | 2 ++ .github/workflows/mypy.yml | 2 ++ .github/workflows/pr-auto-label.yml | 2 ++ .github/workflows/tests.yml | 11 +++++++++++ 5 files changed, 19 insertions(+) diff --git a/.github/workflows/devcontainer-docker-image.yml b/.github/workflows/devcontainer-docker-image.yml index 5ed3559592..c9dc6bd937 100644 --- a/.github/workflows/devcontainer-docker-image.yml +++ b/.github/workflows/devcontainer-docker-image.yml @@ -24,6 +24,8 @@ jobs: steps: - name: Checkout source uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 + with: + persist-credentials: false - name: Setup Docker buildx uses: docker/setup-buildx-action@v3.7.1 diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index bbbbd27ae9..5e66fe6f2d 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -14,6 +14,8 @@ jobs: steps: - name: Checkout code uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 + with: + persist-credentials: false - name: Login to Docker Hub uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index afa32a443f..e6ea6826f8 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -13,6 +13,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 + with: + persist-credentials: false - uses: mamba-org/setup-micromamba@v2 with: environment-file: conda-envs/environment-test.yml diff --git a/.github/workflows/pr-auto-label.yml b/.github/workflows/pr-auto-label.yml index 2dcb2dd3d2..252687b777 100644 --- a/.github/workflows/pr-auto-label.yml +++ b/.github/workflows/pr-auto-label.yml @@ -11,6 +11,8 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v2 + with: + persist-credentials: false - name: Sync labels with closing issues uses: wd60622/closing-labels@v0.0.3 with: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index be2444921d..268656f68b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -34,6 +34,7 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 with: fetch-depth: 0 + persist-credentials: false - uses: dorny/paths-filter@v3 id: changes with: @@ -144,6 +145,8 @@ jobs: shell: bash -leo pipefail {0} steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 + with: + persist-credentials: false - uses: mamba-org/setup-micromamba@v2 with: environment-file: conda-envs/environment-test.yml @@ -194,6 +197,8 @@ jobs: shell: cmd /C call {0} steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 + with: + persist-credentials: false - uses: mamba-org/setup-micromamba@v2 with: environment-file: conda-envs/windows-environment-test.yml @@ -253,6 +258,8 @@ jobs: shell: bash -leo pipefail {0} steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 + with: + persist-credentials: false - uses: mamba-org/setup-micromamba@v2 with: environment-file: conda-envs/environment-test.yml @@ -297,6 +304,8 @@ jobs: shell: bash -leo pipefail {0} steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 + with: + persist-credentials: false - uses: mamba-org/setup-micromamba@v2 with: environment-file: conda-envs/environment-jax.yml @@ -341,6 +350,8 @@ jobs: shell: cmd /C call {0} steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 + with: + persist-credentials: false - uses: mamba-org/setup-micromamba@v2 with: environment-file: conda-envs/windows-environment-test.yml From 0c044c04b21f8a504c39c6c268f876cf914e7de4 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sun, 22 Dec 2024 18:05:45 +0100 Subject: [PATCH 14/25] Remove unnecessary checkout and permissions from pr-auto-label --- .github/workflows/pr-auto-label.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/workflows/pr-auto-label.yml b/.github/workflows/pr-auto-label.yml index 252687b777..21adc89470 100644 --- a/.github/workflows/pr-auto-label.yml +++ b/.github/workflows/pr-auto-label.yml @@ -5,14 +5,9 @@ on: jobs: sync: permissions: - contents: read pull-requests: write runs-on: ubuntu-latest steps: - - name: Checkout repository - uses: actions/checkout@v2 - with: - persist-credentials: false - name: Sync labels with closing issues uses: wd60622/closing-labels@v0.0.3 with: From d7a5f94e2fa1941c8ab9db76708a7d70bfd93845 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sun, 22 Dec 2024 18:59:18 +0100 Subject: [PATCH 15/25] Add exceptions for audited "dangerous triggers" --- .github/workflows/pr-auto-label.yml | 3 ++- .github/workflows/rtd-link-preview.yml | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pr-auto-label.yml b/.github/workflows/pr-auto-label.yml index 21adc89470..9879a8e55f 100644 --- a/.github/workflows/pr-auto-label.yml +++ b/.github/workflows/pr-auto-label.yml @@ -1,6 +1,7 @@ name: "Pull Request Labeler" on: -- pull_request_target +# The labeler doesn't execute any contributed code, so it should be fairly safe. +- pull_request_target # zizmor: ignore[dangerous-triggers] jobs: sync: diff --git a/.github/workflows/rtd-link-preview.yml b/.github/workflows/rtd-link-preview.yml index 626b410c38..62ba591070 100644 --- a/.github/workflows/rtd-link-preview.yml +++ b/.github/workflows/rtd-link-preview.yml @@ -1,15 +1,15 @@ name: Read the Docs Pull Request Preview on: - pull_request_target: + # See + pull_request_target: # zizmor: ignore[dangerous-triggers] types: - opened -permissions: - pull-requests: write - jobs: documentation-links: runs-on: ubuntu-latest + permissions: + pull-requests: write steps: - uses: readthedocs/actions/preview@v1 with: From 2012262a30c8b72c64effe5fe4cf26046e16e966 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sun, 22 Dec 2024 19:11:18 +0100 Subject: [PATCH 16/25] Create zizmor workflow --- .github/workflows/zizmor.yml | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 .github/workflows/zizmor.yml diff --git a/.github/workflows/zizmor.yml b/.github/workflows/zizmor.yml new file mode 100644 index 0000000000..b747897eb8 --- /dev/null +++ b/.github/workflows/zizmor.yml @@ -0,0 +1,36 @@ +# https://github.com/woodruffw/zizmor +name: zizmor GHA analysis + +on: + push: + branches: ["main"] + pull_request: + branches: ["**"] + +jobs: + zizmor: + name: zizmor latest via PyPI + runs-on: ubuntu-latest + permissions: + security-events: write + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + persist-credentials: false + + - uses: hynek/setup-cached-uv@v2 + + - name: Run zizmor 🌈 + run: uvx zizmor --format sarif . > results.sarif + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload SARIF file + uses: github/codeql-action/upload-sarif@v3 + with: + # Path to SARIF file relative to the root of the repository + sarif_file: results.sarif + # Optional category for the results + # Used to differentiate multiple results for one commit + category: zizmor From 892c37ae89da2d7e5b3f49cb1431e83d49cb4f80 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 20 Jan 2025 17:48:47 +0000 Subject: [PATCH 17/25] Check for observed variables in the trace (#7641) --- pymc/sampling/forward.py | 16 +++++++++++++--- tests/sampling/test_forward.py | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 0abe0d66b3..19957e0540 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -345,10 +345,13 @@ def draw( return [np.stack(v) for v in drawn_values] -def observed_dependent_deterministics(model: Model): +def observed_dependent_deterministics(model: Model, extra_observeds=None): """Find deterministics that depend directly on observed variables.""" + if extra_observeds is None: + extra_observeds = [] + deterministics = model.deterministics - observed_rvs = set(model.observed_RVs) + observed_rvs = set(model.observed_RVs + extra_observeds) blockers = model.basic_RVs return [ deterministic @@ -767,6 +770,7 @@ def sample_posterior_predictive( if "coords" not in idata_kwargs: idata_kwargs["coords"] = {} idata: InferenceData | None = None + observed_data = None stacked_dims = None if isinstance(trace, InferenceData): _constant_data = getattr(trace, "constant_data", None) @@ -774,6 +778,7 @@ def sample_posterior_predictive( trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()}) constant_data.update({str(k): v.data for k, v in _constant_data.items()}) idata = trace + observed_data = trace.get("observed_data", None) trace = trace["posterior"] if isinstance(trace, xarray.Dataset): trace_coords.update({str(k): v.data for k, v in trace.coords.items()}) @@ -816,7 +821,12 @@ def sample_posterior_predictive( if var_names is not None: vars_ = [model[x] for x in var_names] else: - vars_ = model.observed_RVs + observed_dependent_deterministics(model) + observed_vars = model.observed_RVs + if observed_data is not None: + observed_vars += [ + model[x] for x in observed_data if x in model and x not in observed_vars + ] + vars_ = observed_vars + observed_dependent_deterministics(model, observed_vars) vars_to_sample = list(get_default_varnames(vars_, include_transformed=False)) diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index bee0d2f792..7bbcdc42b1 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -540,6 +540,24 @@ def test_normal_scalar_idata(self): ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False) assert ppc["a"].shape == (nchains, ndraws) + def test_external_trace_det(self): + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1, observed=0.0) + b = pm.Deterministic("b", a + 1) + trace = pm.sample(tune=50, draws=50, chains=1, compute_convergence_checks=False) + + # test that trace is used in ppc + with pm.Model() as model_ppc: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1) + c = pm.Deterministic("c", a + 1) + + ppc = pm.sample_posterior_predictive( + trace=trace, model=model_ppc, return_inferencedata=False + ) + assert list(ppc.keys()) == ["a", "c"] + def test_normal_vector(self): with pm.Model() as model: mu = pm.Normal("mu", 0.0, 1.0) From 7a995a01b50dd876a3636663e930dc12769a7b39 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Jan 2025 22:44:45 +0100 Subject: [PATCH 18/25] [pre-commit.ci] pre-commit autoupdate (#7653) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.9.1 → v0.9.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.9.1...v0.9.2) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a3b1a7df6..1100f0e2fd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -49,7 +49,7 @@ repos: - --exclude=versioneer.py - --last-year-present - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.1 + rev: v0.9.2 hooks: - id: ruff args: [--fix, --show-fixes] From fa43eba8d682fbf4039bbd1c228c940c872e2d5d Mon Sep 17 00:00:00 2001 From: nataziel <114114079+nataziel@users.noreply.github.com> Date: Thu, 23 Jan 2025 13:08:20 +1000 Subject: [PATCH 19/25] Use jaxified logp for initial point evaluation when sampling via Jax (#7610) * use jaxified logp for initial point evaluation when sampling via Jax * correcting initial point type hinting * refactor init_jitter inputs --------- Co-authored-by: Goose --- pymc/initial_point.py | 14 +++- pymc/sampling/jax.py | 189 +++++++++++++++++++++++++++--------------- pymc/sampling/mcmc.py | 26 +++--- 3 files changed, 147 insertions(+), 82 deletions(-) diff --git a/pymc/initial_point.py b/pymc/initial_point.py index ba3a0ea85c..c276a5c496 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -26,6 +26,7 @@ from pymc.logprob.transforms import Transform from pymc.pytensorf import ( + SeedSequenceSeed, compile, find_rng_nodes, replace_rng_nodes, @@ -67,7 +68,7 @@ def make_initial_point_fns_per_chain( overrides: StartDict | Sequence[StartDict | None] | None, jitter_rvs: set[TensorVariable] | None = None, chains: int, -) -> list[Callable]: +) -> list[Callable[[SeedSequenceSeed], PointType]]: """Create an initial point function for each chain, as defined by initvals. If a single initval dictionary is passed, the function is replicated for each @@ -82,6 +83,11 @@ def make_initial_point_fns_per_chain( Random variable tensors for which U(-1, 1) jitter shall be applied. (To the transformed space if applicable.) + Returns + ------- + ipfns : list[Callable[[SeedSequenceSeed], dict[str, np.ndarray]]] + list of functions that return initial points for each chain. + Raises ------ ValueError @@ -124,7 +130,7 @@ def make_initial_point_fn( jitter_rvs: set[TensorVariable] | None = None, default_strategy: str = "support_point", return_transformed: bool = True, -) -> Callable: +) -> Callable[[SeedSequenceSeed], PointType]: """Create seeded function that computes initial values for all free model variables. Parameters @@ -138,6 +144,10 @@ def make_initial_point_fn( Initial value (strategies) to use instead of what's specified in `Model.initial_values`. return_transformed : bool If `True` the returned variables will correspond to transformed initial values. + + Returns + ------- + initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]] """ sdict_overrides = convert_str_to_rv_dict(model, overrides or {}) initval_strats = { diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 03413667df..b2cbff9b68 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -18,6 +18,7 @@ from collections.abc import Callable, Sequence from datetime import datetime from functools import partial +from types import ModuleType from typing import Any, Literal import arviz as az @@ -28,6 +29,7 @@ from arviz.data.base import make_attrs from jax.lax import scan +from numpy.typing import ArrayLike from pytensor.compile import SharedVariable, Supervisor, mode from pytensor.graph.basic import graph_inputs from pytensor.graph.fg import FunctionGraph @@ -120,7 +122,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariabl def get_jaxified_graph( inputs: list[TensorVariable] | None = None, outputs: list[TensorVariable] | None = None, -) -> list[TensorVariable]: +) -> Callable[[list[TensorVariable]], list[TensorVariable]]: """Compile a PyTensor graph into an optimized JAX function.""" graph = _replace_shared_variables(outputs) if outputs is not None else None @@ -143,13 +145,13 @@ def get_jaxified_graph( return jax_funcify(fgraph) -def get_jaxified_logp(model: Model, negative_logp=True) -> Callable: +def get_jaxified_logp(model: Model, negative_logp: bool = True) -> Callable[[ArrayLike], jax.Array]: model_logp = model.logp() if not negative_logp: model_logp = -model_logp logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp]) - def logp_fn_wrap(x): + def logp_fn_wrap(x: ArrayLike) -> jax.Array: return logp_fn(*x)[0] return logp_fn_wrap @@ -210,10 +212,16 @@ def _get_batched_jittered_initial_points( chains: int, initvals: StartDict | Sequence[StartDict | None] | None, random_seed: RandomSeed, + logp_fn: Callable[[ArrayLike], jax.Array] | None = None, jitter: bool = True, jitter_max_retries: int = 10, ) -> np.ndarray | list[np.ndarray]: - """Get jittered initial point in format expected by NumPyro MCMC kernel. + """Get jittered initial point in format expected by Jax MCMC kernel. + + Parameters + ---------- + logp_fn : Callable[Sequence[np.ndarray]], np.ndarray] + Jaxified logp function Returns ------- @@ -221,12 +229,26 @@ def _get_batched_jittered_initial_points( list with one item per variable and number of chains as batch dimension. Each item has shape `(chains, *var.shape)` """ + if logp_fn is None: + eval_logp_initial_point = None + + else: + + def eval_logp_initial_point(point: dict[str, np.ndarray]) -> jax.Array: + """Wrap logp_fn to conform to _init_jitter logic. + + Wraps jaxified logp function to accept a dict of + {model_variable: np.array} key:value pairs. + """ + return logp_fn(point.values()) + initial_points = _init_jitter( model, initvals, seeds=_get_seeds_per_chain(random_seed, chains), jitter=jitter, jitter_max_retries=jitter_max_retries, + logp_fn=eval_logp_initial_point, ) initial_points_values = [list(initial_point.values()) for initial_point in initial_points] if chains == 1: @@ -235,7 +257,7 @@ def _get_batched_jittered_initial_points( def _blackjax_inference_loop( - seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs + seed, init_position, logp_fn, draws, tune, target_accept, **adaptation_kwargs ): import blackjax @@ -251,13 +273,13 @@ def _blackjax_inference_loop( adapt = blackjax.window_adaptation( algorithm=algorithm, - logdensity_fn=logprob_fn, + logdensity_fn=logp_fn, target_acceptance_rate=target_accept, adaptation_info_fn=get_filter_adapt_info_fn(), **adaptation_kwargs, ) (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune) - kernel = algorithm(logprob_fn, **tuned_params).step + kernel = algorithm(logp_fn, **tuned_params).step def _one_step(state, xs): _, rng_key = xs @@ -288,67 +310,51 @@ def _sample_blackjax_nuts( tune: int, draws: int, chains: int, - chain_method: str | None, + chain_method: Literal["parallel", "vectorized"], progressbar: bool, random_seed: int, - initial_points, - nuts_kwargs, -) -> az.InferenceData: + initial_points: np.ndarray | list[np.ndarray], + nuts_kwargs: dict[str, Any], + logp_fn: Callable[[ArrayLike], jax.Array] | None = None, +) -> tuple[Any, dict[str, Any], ModuleType]: """ Draw samples from the posterior using the NUTS method from the ``blackjax`` library. Parameters ---------- - draws : int, default 1000 - The number of samples to draw. The number of tuned samples are discarded by - default. - tune : int, default 1000 + model : Model + Model to sample from. The model needs to have free random variables. + target_accept : float in [0, 1]. + The step size is tuned such that we approximate this acceptance rate. Higher + values like 0.9 or 0.95 often work better for problematic posteriors. + tune : int Number of iterations to tune. Samplers adjust the step sizes, scalings or similar during tuning. Tuning samples will be drawn in addition to the number specified in the ``draws`` argument. - chains : int, default 4 + draws : int + The number of samples to draw. The number of tuned samples are discarded by default. + chains : int The number of chains to sample. - target_accept : float in [0, 1]. - The step size is tuned such that we approximate this acceptance rate. Higher - values like 0.9 or 0.95 often work better for problematic posteriors. - random_seed : int, RandomState or Generator, optional + chain_method : "parallel" or "vectorized" + Specify how samples should be drawn. + progressbar : bool + Whether to show progressbar or not during sampling. + random_seed : int, RandomState or Generator Random seed used by the sampling steps. - initvals: StartDict or Sequence[Optional[StartDict]], optional - Initial values for random variables provided as a dictionary (or sequence of - dictionaries) mapping the random variable (by name or reference) to desired - starting values. - jitter: bool, default True - If True, add jitter to initial points. - model : Model, optional - Model to sample from. The model needs to have free random variables. When inside - a ``with`` model context, it defaults to that model, otherwise the model must be - passed explicitly. - var_names : sequence of str, optional - Names of variables for which to compute the posterior samples. Defaults to all - variables in the posterior. - keep_untransformed : bool, default False - Include untransformed variables in the posterior samples. Defaults to False. - chain_method : str, default "parallel" - Specify how samples should be drawn. The choices include "parallel", and - "vectorized". - postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None, - Specify how postprocessing should be computed. gpu or cpu - postprocessing_vectorize: Literal["vmap", "scan"], default "scan" - How to vectorize the postprocessing: vmap or sequential scan - idata_kwargs : dict, optional - Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as - value for the ``log_likelihood`` key to indicate that the pointwise log - likelihood should not be included in the returned object. Values for - ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from - the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and - ``dims`` are provided, they are used to update the inferred dictionaries. + initial_points : np.ndarray or list[np.ndarray] + Initial point(s) for sampler to begin sampling from. + nuts_kwargs : dict + Keyword arguments for the blackjax nuts sampler + logp_fn : Callable[[ArrayLike], jax.Array], optional, default None + jaxified logp function. If not passed in it will be created anew. Returns ------- - InferenceData - ArviZ ``InferenceData`` object that contains the posterior samples, together - with their respective sample stats and pointwise log likeihood values (unless - skipped with ``idata_kwargs``). + raw_mcmc_samples + Datastructure containing raw mcmc samples + sample_stats : dict[str, Any] + Dictionary containing sample stats + blackjax : ModuleType["blackjax"] """ import blackjax @@ -365,7 +371,8 @@ def _sample_blackjax_nuts( if chains == 1: initial_points = [np.stack(init_state) for init_state in zip(initial_points)] - logprob_fn = get_jaxified_logp(model) + if logp_fn is None: + logp_fn = get_jaxified_logp(model) seed = jax.random.PRNGKey(random_seed) keys = jax.random.split(seed, chains) @@ -373,7 +380,7 @@ def _sample_blackjax_nuts( nuts_kwargs["progress_bar"] = progressbar get_posterior_samples = partial( _blackjax_inference_loop, - logprob_fn=logprob_fn, + logp_fn=logp_fn, tune=tune, draws=draws, target_accept=target_accept, @@ -385,7 +392,7 @@ def _sample_blackjax_nuts( # Adopted from arviz numpyro extractor -def _numpyro_stats_to_dict(posterior): +def _numpyro_stats_to_dict(posterior) -> dict[str, Any]: """Extract sample_stats from NumPyro posterior.""" rename_key = { "potential_energy": "lp", @@ -411,17 +418,58 @@ def _sample_numpyro_nuts( tune: int, draws: int, chains: int, - chain_method: str | None, + chain_method: Literal["parallel", "vectorized"], progressbar: bool, random_seed: int, - initial_points, + initial_points: np.ndarray | list[np.ndarray], nuts_kwargs: dict[str, Any], -): + logp_fn: Callable[[ArrayLike], jax.Array] | None = None, +) -> tuple[Any, dict[str, Any], ModuleType]: + """ + Draw samples from the posterior using the NUTS method from the ``numpyro`` library. + + Parameters + ---------- + model : Model + Model to sample from. The model needs to have free random variables. + target_accept : float in [0, 1]. + The step size is tuned such that we approximate this acceptance rate. Higher + values like 0.9 or 0.95 often work better for problematic posteriors. + tune : int + Number of iterations to tune. Samplers adjust the step sizes, scalings or + similar during tuning. Tuning samples will be drawn in addition to the number + specified in the ``draws`` argument. + draws : int + The number of samples to draw. The number of tuned samples are discarded by default. + chains : int + The number of chains to sample. + chain_method : "parallel" or "vectorized" + Specify how samples should be drawn. + progressbar : bool + Whether to show progressbar or not during sampling. + random_seed : int, RandomState or Generator + Random seed used by the sampling steps. + initial_points : np.ndarray or list[np.ndarray] + Initial point(s) for sampler to begin sampling from. + nuts_kwargs : dict + Keyword arguments for the underlying numpyro nuts sampler + logp_fn : Callable[[ArrayLike], jax.Array], optional, default None + jaxified logp function. If not passed in it will be created anew. + + Returns + ------- + raw_mcmc_samples + Datastructure containing raw mcmc samples + sample_stats : dict[str, Any] + Dictionary containing sample stats + numpyro : ModuleType["numpyro"] + """ import numpyro from numpyro.infer import MCMC, NUTS - logp_fn = get_jaxified_logp(model, negative_logp=False) + if logp_fn is None: + logp_fn = get_jaxified_logp(model, negative_logp=False) nuts_kwargs.setdefault("adapt_step_size", True) nuts_kwargs.setdefault("adapt_mass_matrix", True) @@ -479,7 +527,7 @@ def sample_jax_nuts( nuts_kwargs: dict | None = None, progressbar: bool = True, keep_untransformed: bool = False, - chain_method: str = "parallel", + chain_method: Literal["parallel", "vectorized"] = "parallel", postprocessing_backend: Literal["cpu", "gpu"] | None = None, postprocessing_vectorize: Literal["vmap", "scan"] | None = None, postprocessing_chunks=None, @@ -525,7 +573,7 @@ def sample_jax_nuts( If True, display a progressbar while sampling keep_untransformed : bool, default False Include untransformed variables in the posterior samples. - chain_method : str, default "parallel" + chain_method : Literal["parallel", "vectorized"], default "parallel" Specify how samples should be drawn. The choices include "parallel", and "vectorized". postprocessing_backend : Optional[Literal["cpu", "gpu"]], default None, @@ -589,6 +637,15 @@ def sample_jax_nuts( get_default_varnames(filtered_var_names, include_transformed=keep_untransformed) ) + if nuts_sampler == "numpyro": + sampler_fn = _sample_numpyro_nuts + logp_fn = get_jaxified_logp(model, negative_logp=False) + elif nuts_sampler == "blackjax": + sampler_fn = _sample_blackjax_nuts + logp_fn = get_jaxified_logp(model) + else: + raise ValueError(f"{nuts_sampler=} not recognized") + (random_seed,) = _get_seeds_per_chain(random_seed, 1) initial_points = _get_batched_jittered_initial_points( @@ -597,15 +654,9 @@ def sample_jax_nuts( initvals=initvals, random_seed=random_seed, jitter=jitter, + logp_fn=logp_fn, ) - if nuts_sampler == "numpyro": - sampler_fn = _sample_numpyro_nuts - elif nuts_sampler == "blackjax": - sampler_fn = _sample_blackjax_nuts - else: - raise ValueError(f"{nuts_sampler=} not recognized") - tic1 = datetime.now() raw_mcmc_samples, sample_stats, library = sampler_fn( model=model, diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 758cb86448..6fb80284fd 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1428,7 +1428,7 @@ def _init_jitter( seeds: Sequence[int] | np.ndarray, jitter: bool, jitter_max_retries: int, - logp_dlogp_func=None, + logp_fn: Callable[[PointType], np.ndarray] | None = None, ) -> list[PointType]: """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain. @@ -1443,11 +1443,14 @@ def _init_jitter( Whether to apply jitter or not. jitter_max_retries : int Maximum number of repeated attempts at initializing values (per chain). + logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray | jax.Array] | None + logp function that takes the output of initial point functions as input. + If None, will use the results of model.compile_logp(). Returns ------- - start : ``pymc.model.Point`` - Starting point for sampler + initial_points : list[dict[str, np.ndarray]] + List of starting points for the sampler """ ipfns = make_initial_point_fns_per_chain( model=model, @@ -1459,14 +1462,10 @@ def _init_jitter( if not jitter: return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)] - model_logp_fn: Callable - if logp_dlogp_func is None: - model_logp_fn = model.compile_logp() + if logp_fn is None: + model_logp_fn: Callable[[PointType], np.ndarray] = model.compile_logp() else: - - def model_logp_fn(ip): - q, _ = DictToArrayBijection.map(ip) - return logp_dlogp_func([q], extra_vars={})[0] + model_logp_fn = logp_fn initial_points = [] for ipfn, seed in zip(ipfns, seeds): @@ -1591,13 +1590,18 @@ def init_nuts( logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True, **compile_kwargs) logp_dlogp_func.trust_input = True + + def model_logp_fn(ip: PointType) -> np.ndarray: + q, _ = DictToArrayBijection.map(ip) + return logp_dlogp_func([q], extra_vars={})[0] + initial_points = _init_jitter( model, initvals, seeds=random_seed_list, jitter="jitter" in init, jitter_max_retries=jitter_max_retries, - logp_dlogp_func=logp_dlogp_func, + logp_fn=model_logp_fn, ) apoints = [DictToArrayBijection.map(point) for point in initial_points] From 472da97098d5f42fdbe5267704af43b07b20ebd1 Mon Sep 17 00:00:00 2001 From: Adarsh Dubey <84132532+inclinedadarsh@users.noreply.github.com> Date: Mon, 27 Jan 2025 00:12:59 +0530 Subject: [PATCH 20/25] fix: deep copy nuts_sampler_kwarg to prevent `.pop` side effects (#7652) --- pymc/sampling/mcmc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 6fb80284fd..4d7b9ff15a 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -338,6 +338,7 @@ def _sample_external_nuts( UserWarning, ) compile_kwargs = {} + nuts_sampler_kwargs = nuts_sampler_kwargs.copy() for kwarg in ("backend", "gradient_backend"): if kwarg in nuts_sampler_kwargs: compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg) From 0db176c9d7773fac5d682d9c75f0fd7cb9459ec1 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> Date: Mon, 27 Jan 2025 15:07:23 +0800 Subject: [PATCH 21/25] Show one progress bar per chain when sampling (#7634) * One progress bar per chain when samplings * Add guard against divide by zero when computing draws per second * No more purple * Step samplers are responsible for setting up progress bars * Fix typos * Add progressbar defaults to BlockedStep ABC * pre-commit * Only update NUTS divergence stats after tuning * Add `Elapsed` and `Remaining` columns * Remove green color when chain finishes * Create `ProgressManager` class to handle progress bars * Yield `stats` from `_iter_sample` * Use `ProgressManager` in `_sample_many` * pre-commit * Explicit case handling for `progressbar` argument * Allow all permutations of arguments to progressbar * Appease mypy * Add True case * Fix final count when `progress = "combined"` * Update docstrings * mypy + cleanup * Syntax error in typehint * Simplify progressbar choices, update docstring * Incorporate feedback * Be verbose with progressbar settings --- pymc/backends/__init__.py | 2 +- pymc/sampling/mcmc.py | 146 ++++++------- pymc/sampling/parallel.py | 47 ++--- pymc/step_methods/compound.py | 46 +++++ pymc/step_methods/hmc/nuts.py | 33 +++ pymc/step_methods/metropolis.py | 34 ++++ pymc/step_methods/slicer.py | 29 +++ pymc/util.py | 350 +++++++++++++++++++++++++++++++- 8 files changed, 576 insertions(+), 111 deletions(-) diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index d3f7620882..eaa484a13f 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -87,7 +87,7 @@ RunType: TypeAlias = Run HAS_MCB = True except ImportError: - TraceOrBackend = BaseTrace # type: ignore[misc] + TraceOrBackend = BaseTrace # type: ignore[assignment, misc] RunType = type(None) # type: ignore[assignment, misc] diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 4d7b9ff15a..f2dfa6e9c2 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -36,8 +36,6 @@ from arviz import InferenceData, dict_to_dataset from arviz.data.base import make_attrs from pytensor.graph.basic import Variable -from rich.console import Console -from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.theme import Theme from threadpoolctl import threadpool_limits from typing_extensions import Protocol @@ -67,7 +65,8 @@ from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared from pymc.step_methods.hmc import quadpotential from pymc.util import ( - CustomProgress, + ProgressBarManager, + ProgressBarType, RandomSeed, RandomState, _get_seeds_per_chain, @@ -278,7 +277,7 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None: else: varnames = ", ".join( [ - get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name + get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name # type: ignore[misc] for v in s.vars ] ) @@ -425,7 +424,7 @@ def sample( chains: int | None = None, cores: int | None = None, random_seed: RandomState = None, - progressbar: bool = True, + progressbar: bool | ProgressBarType = True, progressbar_theme: Theme | None = default_progress_theme, step=None, var_names: Sequence[str] | None = None, @@ -457,7 +456,7 @@ def sample( chains: int | None = None, cores: int | None = None, random_seed: RandomState = None, - progressbar: bool = True, + progressbar: bool | ProgressBarType = True, progressbar_theme: Theme | None = default_progress_theme, step=None, var_names: Sequence[str] | None = None, @@ -489,8 +488,8 @@ def sample( chains: int | None = None, cores: int | None = None, random_seed: RandomState = None, - progressbar: bool = True, - progressbar_theme: Theme | None = default_progress_theme, + progressbar: bool | ProgressBarType = True, + progressbar_theme: Theme | None = None, step=None, var_names: Sequence[str] | None = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", @@ -540,11 +539,18 @@ def sample( A ``TypeError`` will be raised if a legacy :py:class:`~numpy.random.RandomState` object is passed. We no longer support ``RandomState`` objects because their seeding mechanism does not allow easy spawning of new independent random streams that are needed by the step methods. - progressbar : bool, optional default=True - Whether or not to display a progress bar in the command line. The bar shows the percentage - of completion, the sampling speed in samples per second (SPS), and the estimated remaining - time until completion ("expected time of arrival"; ETA). - Only applicable to the pymc nuts sampler. + progressbar: bool or ProgressType, optional + How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask + for one of the following: + - "combined": A single progress bar that displays the total progress across all chains. Only timing + information is shown. + - "split": A separate progress bar for each chain. Only timing information is shown. + - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all + chains. Aggregate sample statistics are also displayed. + - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain + are also displayed. + + If True, the default is "split+stats" is used. step : function or iterable of functions A step function or collection of functions. If there are variables without step methods, step methods for those variables will be assigned automatically. By default the NUTS step @@ -710,6 +716,10 @@ def sample( if isinstance(trace, list): raise ValueError("Please use `var_names` keyword argument for partial traces.") + # progressbar might be a string, which is used by the ProgressManager in the pymc samplers. External samplers and + # ADVI initialization expect just a bool. + progress_bool = bool(progressbar) + model = modelcontext(model) if not model.free_RVs: raise SamplingError( @@ -806,7 +816,7 @@ def joined_blas_limiter(): initvals=initvals, model=model, var_names=var_names, - progressbar=progressbar, + progressbar=progress_bool, idata_kwargs=idata_kwargs, compute_convergence_checks=compute_convergence_checks, nuts_sampler_kwargs=nuts_sampler_kwargs, @@ -825,7 +835,7 @@ def joined_blas_limiter(): n_init=n_init, model=model, random_seed=random_seed_list, - progressbar=progressbar, + progressbar=progress_bool, jitter_max_retries=jitter_max_retries, tune=tune, initvals=initvals, @@ -1139,25 +1149,35 @@ def _sample_many( Step function """ initial_step_state = step.sampling_state - for i in range(chains): - step.sampling_state = initial_step_state - _sample( - draws=draws, - chain=i, - start=start[i], - step=step, - trace=traces[i], - rng=rngs[i], - callback=callback, - **kwargs, - ) + progress_manager = ProgressBarManager( + step_method=step, + chains=chains, + draws=draws - kwargs.get("tune", 0), + tune=kwargs.get("tune", 0), + progressbar=kwargs.get("progressbar", True), + progressbar_theme=kwargs.get("progressbar_theme", default_progress_theme), + ) + + with progress_manager: + for i in range(chains): + step.sampling_state = initial_step_state + _sample( + draws=draws, + chain=i, + start=start[i], + step=step, + trace=traces[i], + rng=rngs[i], + callback=callback, + progress_manager=progress_manager, + **kwargs, + ) return def _sample( *, chain: int, - progressbar: bool, rng: np.random.Generator, start: PointType, draws: int, @@ -1165,8 +1185,8 @@ def _sample( trace: IBaseTrace, tune: int, model: Model | None = None, - progressbar_theme: Theme | None = default_progress_theme, callback=None, + progress_manager: ProgressBarManager, **kwargs, ) -> None: """Sample one chain (singleprocess). @@ -1177,27 +1197,23 @@ def _sample( ---------- chain : int Number of the chain that the samples will belong to. - progressbar : bool - Whether or not to display a progress bar in the command line. The bar shows the percentage - of completion, the sampling speed in samples per second (SPS), and the estimated remaining - time until completion ("expected time of arrival"; ETA). - random_seed : single random seed + random_seed : Generator + Single random seed start : dict Starting point in parameter space (or partial point) draws : int The number of samples to draw - step : function - Step function + step : Step + Step class instance used to generate samples. trace A chain backend to record draws and stats. tune : int Number of iterations to tune. - model : Model (optional if in ``with`` context) - progressbar_theme : Theme - Optional custom theme for the progress bar. + model : Model, optional + PyMC model. If None, the model is taken from the current context. + progress_manager: ProgressBarManager + Helper class used to handle progress bar styling and updates """ - skip_first = kwargs.get("skip_first", 0) - sampling_gen = _iter_sample( draws=draws, step=step, @@ -1209,32 +1225,19 @@ def _sample( rng=rng, callback=callback, ) - _pbar_data = {"chain": chain, "divergences": 0} - _desc = "Sampling chain {chain:d}, {divergences:,d} divergences" - - progress = CustomProgress( - "[progress.description]{task.description}", - BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - TimeRemainingColumn(), - TextColumn("/"), - TimeElapsedColumn(), - console=Console(theme=progressbar_theme), - disable=not progressbar, - ) + try: + for it, stats in enumerate(sampling_gen): + progress_manager.update( + chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune + ) - with progress: - try: - task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws) - for it, diverging in enumerate(sampling_gen): - if it >= skip_first and diverging: - _pbar_data["divergences"] += 1 - progress.update(task, description=_desc.format(**_pbar_data), completed=it) - progress.update( - task, description=_desc.format(**_pbar_data), completed=draws, refresh=True + if not progress_manager.combined_progress or chain == progress_manager.chains - 1: + progress_manager.update( + chain_idx=chain, is_last=True, draw=it, stats=stats, tuning=False ) - except KeyboardInterrupt: - pass + + except KeyboardInterrupt: + pass def _iter_sample( @@ -1248,7 +1251,7 @@ def _iter_sample( rng: np.random.Generator, model: Model | None = None, callback: SamplingIteratorCallback | None = None, -) -> Iterator[bool]: +) -> Iterator[list[dict[str, Any]]]: """Sample one chain with a generator (singleprocess). Parameters @@ -1271,8 +1274,8 @@ def _iter_sample( Yields ------ - diverging : bool - Indicates if the draw is divergent. Only available with some samplers. + stats : list of dict + Dictionary of statistics returned by step sampler """ draws = int(draws) @@ -1294,22 +1297,25 @@ def _iter_sample( step.iter_count = 0 if i == tune: step.stop_tuning() + point, stats = step.step(point) trace.record(point, stats) log_warning_stats(stats) - diverging = i > tune and len(stats) > 0 and (stats[0].get("diverging") is True) + if callback is not None: callback( trace=trace, draw=Draw(chain, i == draws, i, i < tune, stats, point), ) - yield diverging + yield stats + except (KeyboardInterrupt, BaseException): if isinstance(trace, ZarrChain): trace.record_sampling_state(step=step) trace.close() raise + else: if isinstance(trace, ZarrChain): trace.record_sampling_state(step=step) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 3c2a8c9a36..af2106ce6f 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -27,8 +27,6 @@ import cloudpickle import numpy as np -from rich.console import Console -from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.theme import Theme from threadpoolctl import threadpool_limits @@ -36,7 +34,7 @@ from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.util import ( - CustomProgress, + ProgressBarManager, RandomGeneratorState, default_progress_theme, get_state_from_generator, @@ -485,23 +483,14 @@ def __init__( self._max_active = cores self._in_context = False - - self._progress = CustomProgress( - "[progress.description]{task.description}", - BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - TimeRemainingColumn(), - TextColumn("/"), - TimeElapsedColumn(), - console=Console(theme=progressbar_theme), - disable=not progressbar, + self._progress = ProgressBarManager( + step_method=step_method, + chains=chains, + draws=draws, + tune=tune, + progressbar=progressbar, + progressbar_theme=progressbar_theme, ) - self._show_progress = progressbar - self._divergences = 0 - self._completed_draws = 0 - self._total_draws = chains * (draws + tune) - self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences" - self._chains = chains def _make_active(self): while self._inactive and len(self._active) < self._max_active: @@ -516,24 +505,13 @@ def __iter__(self): raise ValueError("Use ParallelSampler as context manager.") self._make_active() - with self._progress as progress: - task = progress.add_task( - self._desc.format(self), - completed=self._completed_draws, - total=self._total_draws, - ) - + with self._progress: while self._active: draw = ProcessAdapter.recv_draw(self._active) proc, is_last, draw, tuning, stats = draw - self._completed_draws += 1 - if not tuning and stats and stats[0].get("diverging"): - self._divergences += 1 - progress.update( - task, - completed=self._completed_draws, - total=self._total_draws, - description=self._desc.format(self), + + self._progress.update( + chain_idx=proc.chain, is_last=is_last, draw=draw, tuning=tuning, stats=stats ) if is_last: @@ -541,7 +519,6 @@ def __iter__(self): self._active.remove(proc) self._finished.append(proc) self._make_active() - progress.update(task, description=self._desc.format(self), refresh=True) # We could also yield proc.shared_point_view directly, # and only call proc.write_next() after the yield returns. diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index ff3f9c66a5..d07b070f0f 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -181,6 +181,20 @@ def __new__(cls, *args, **kwargs): step.__newargs = (vars, *args), kwargs return step + @staticmethod + def _progressbar_config(n_chains=1): + columns = [] + stats = {} + + return columns, stats + + @staticmethod + def _make_update_stats_function(): + def update_stats(stats, step_stats, chain_idx): + return stats + + return update_stats + # Hack for creating the class correctly when unpickling. def __getnewargs_ex__(self): return self.__newargs @@ -297,6 +311,38 @@ def set_rng(self, rng: RandomGenerator): for method, _rng in zip(self.methods, _rngs): method.set_rng(_rng) + def _progressbar_config(self, n_chains=1): + from functools import reduce + + column_lists, stat_dict_list = zip( + *[method._progressbar_config(n_chains) for method in self.methods] + ) + flat_list = reduce(lambda left_list, right_list: left_list + right_list, column_lists) + + columns = [] + headers = [] + + for col in flat_list: + name = col.get_table_column().header + if name not in headers: + headers.append(name) + columns.append(col) + + stats = reduce(lambda left_dict, right_dict: left_dict | right_dict, stat_dict_list) + + return columns, stats + + def _make_update_stats_function(self): + update_fns = [method._make_update_stats_function() for method in self.methods] + + def update_stats(stats, step_stats, chain_idx): + for step_stat, update_fn in zip(step_stats, update_fns): + stats = update_fn(stats, step_stat, chain_idx) + + return stats + + return update_stats + def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]: """Flatten a hierarchy of step methods to a list.""" diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index bbda728e80..18707c3592 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -20,6 +20,8 @@ import numpy as np from pytensor import config +from rich.progress import TextColumn +from rich.table import Column from pymc.stats.convergence import SamplerWarning from pymc.step_methods.compound import Competence @@ -229,6 +231,37 @@ def competence(var, has_grad): return Competence.PREFERRED return Competence.INCOMPATIBLE + @staticmethod + def _progressbar_config(n_chains=1): + columns = [ + TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)), + TextColumn("{task.fields[step_size]:0.2f}", table_column=Column("Step size", ratio=1)), + TextColumn("{task.fields[tree_size]}", table_column=Column("Grad evals", ratio=1)), + ] + + stats = { + "divergences": [0] * n_chains, + "step_size": [0] * n_chains, + "tree_size": [0] * n_chains, + } + + return columns, stats + + @staticmethod + def _make_update_stats_function(): + def update_stats(stats, step_stats, chain_idx): + if isinstance(step_stats, list): + step_stats = step_stats[0] + + if not step_stats["tune"]: + stats["divergences"][chain_idx] += step_stats["diverging"] + + stats["step_size"][chain_idx] = step_stats["step_size"] + stats["tree_size"][chain_idx] = step_stats["tree_size"] + return stats + + return update_stats + # A proposal for the next position Proposal = namedtuple("Proposal", "q, q_grad, energy, logp, index_in_trajectory") diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 8e22218a13..70c650653d 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -24,6 +24,8 @@ from pytensor import tensor as pt from pytensor.graph.fg import MissingInputError from pytensor.tensor.random.basic import BernoulliRV, CategoricalRV +from rich.progress import TextColumn +from rich.table import Column import pymc as pm @@ -325,6 +327,38 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: def competence(var, has_grad): return Competence.COMPATIBLE + @staticmethod + def _progressbar_config(n_chains=1): + columns = [ + TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), + TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)), + TextColumn( + "{task.fields[accept_rate]:0.2f}", table_column=Column("Accept Rate", ratio=1) + ), + ] + + stats = { + "tune": [True] * n_chains, + "scaling": [0] * n_chains, + "accept_rate": [0.0] * n_chains, + } + + return columns, stats + + @staticmethod + def _make_update_stats_function(): + def update_stats(stats, step_stats, chain_idx): + if isinstance(step_stats, list): + step_stats = step_stats[0] + + stats["tune"][chain_idx] = step_stats["tune"] + stats["accept_rate"][chain_idx] = step_stats["accept"] + stats["scaling"][chain_idx] = step_stats["scaling"] + + return stats + + return update_stats + def tune(scale, acc_rate): """ diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index ecc7967614..9c10acfdf4 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -17,6 +17,9 @@ import numpy as np +from rich.progress import TextColumn +from rich.table import Column + from pymc.blocking import RaveledVars, StatsType from pymc.initial_point import PointType from pymc.model import modelcontext @@ -195,3 +198,29 @@ def competence(var, has_grad): return Competence.PREFERRED return Competence.COMPATIBLE return Competence.INCOMPATIBLE + + @staticmethod + def _progressbar_config(n_chains=1): + columns = [ + TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), + TextColumn("{task.fields[nstep_out]}", table_column=Column("Steps out", ratio=1)), + TextColumn("{task.fields[nstep_in]}", table_column=Column("Steps in", ratio=1)), + ] + + stats = {"tune": [True] * n_chains, "nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains} + + return columns, stats + + @staticmethod + def _make_update_stats_function(): + def update_stats(stats, step_stats, chain_idx): + if isinstance(step_stats, list): + step_stats = step_stats[0] + + stats["tune"][chain_idx] = step_stats["tune"] + stats["nstep_out"][chain_idx] = step_stats["nstep_out"] + stats["nstep_in"][chain_idx] = step_stats["nstep_in"] + + return stats + + return update_stats diff --git a/pymc/util.py b/pymc/util.py index 8dc7d16804..979b3beebf 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -17,9 +17,9 @@ import warnings from collections import namedtuple -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from copy import deepcopy -from typing import NewType, cast +from typing import TYPE_CHECKING, Literal, NewType, cast import arviz import cloudpickle @@ -30,11 +30,35 @@ from pytensor import Variable from pytensor.compile import SharedVariable from pytensor.graph.utils import ValidatingScratchpad -from rich.progress import Progress +from rich.box import SIMPLE_HEAD +from rich.console import Console +from rich.progress import ( + BarColumn, + Progress, + Task, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) +from rich.style import Style +from rich.table import Column, Table from rich.theme import Theme from pymc.exceptions import BlockModelAccessError +if TYPE_CHECKING: + from pymc.step_methods.compound import BlockedStep, CompoundStep + + +ProgressBarType = Literal[ + "combined", + "split", + "combined+stats", + "stats+combined", + "split+stats", + "stats+split", +] + def __getattr__(name): if name == "dataset_to_point_list": @@ -55,6 +79,8 @@ def __getattr__(name): { "bar.complete": "#1764f4", "bar.finished": "green", + "progress.remaining": "none", + "progress.elapsed": "none", } ) @@ -556,8 +582,10 @@ class CustomProgress(Progress): it's `True`. """ - def __init__(self, *args, **kwargs): - self.is_enabled = kwargs.get("disable", None) is not True + def __init__(self, *args, disable=False, include_headers=False, **kwargs): + self.is_enabled = not disable + self.include_headers = include_headers + if self.is_enabled: super().__init__(*args, **kwargs) @@ -607,6 +635,318 @@ def update( ) return None + def make_tasks_table(self, tasks: Iterable[Task]) -> Table: + """Get a table to render the Progress display. + + Unlike the parent method, this one returns a full table (not a grid), allowing for column headings. + + Parameters + ---------- + tasks: Iterable[Task] + An iterable of Task instances, one per row of the table. + + Returns + ------- + table: Table + A table instance. + """ + + def call_column(column, task): + # Subclass rich.BarColumn and add a callback method to dynamically update the display + if hasattr(column, "callbacks"): + column.callbacks(task) + + return column(task) + + table_columns = ( + ( + Column(no_wrap=True) + if isinstance(_column, str) + else _column.get_table_column().copy() + ) + for _column in self.columns + ) + if self.include_headers: + table = Table( + *table_columns, + padding=(0, 1), + expand=self.expand, + show_header=True, + show_edge=True, + box=SIMPLE_HEAD, + ) + else: + table = Table.grid(*table_columns, padding=(0, 1), expand=self.expand) + + for task in tasks: + if task.visible: + table.add_row( + *( + ( + column.format(task=task) + if isinstance(column, str) + else call_column(column, task) + ) + for column in self.columns + ) + ) + + return table + + +class DivergenceBarColumn(BarColumn): + """Rich colorbar that changes color when a chain has detected a divergence.""" + + def __init__(self, *args, diverging_color="red", **kwargs): + from matplotlib.colors import to_rgb + + self.diverging_color = diverging_color + self.diverging_rgb = [int(x * 255) for x in to_rgb(self.diverging_color)] + + super().__init__(*args, **kwargs) + + self.non_diverging_style = self.complete_style + self.non_diverging_finished_style = self.finished_style + + def callbacks(self, task: "Task"): + divergences = task.fields.get("divergences", 0) + if isinstance(divergences, float | int) and divergences > 0: + self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) + self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) + else: + self.complete_style = self.non_diverging_style + self.finished_style = self.non_diverging_finished_style + + +class ProgressBarManager: + """Manage progress bars displayed during sampling.""" + + def __init__( + self, + step_method: "BlockedStep | CompoundStep", + chains: int, + draws: int, + tune: int, + progressbar: bool | ProgressBarType = True, + progressbar_theme: Theme | None = None, + ): + """ + Manage progress bars displayed during sampling. + + When sampling, Step classes are responsible for computing and exposing statistics that can be reported on + progress bars. Each Step implements two class methods: :meth:`pymc.step_methods.BlockedStep._progressbar_config` + and :meth:`pymc.step_methods.BlockedStep._make_update_stats_function`. `_progressbar_config` reports which + columns should be displayed on the progress bar, and `_make_update_stats_function` computes the statistics + that will be displayed on the progress bar. + + Parameters + ---------- + step_method: BlockedStep or CompoundStep + The step method being used to sample + chains: int + Number of chains being sampled + draws: int + Number of draws per chain + tune: int + Number of tuning steps per chain + progressbar: bool or ProgressType, optional + How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask + for one of the following: + - "combined": A single progress bar that displays the total progress across all chains. Only timing + information is shown. + - "split": A separate progress bar for each chain. Only timing information is shown. + - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all + chains. Aggregate sample statistics are also displayed. + - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain + are also displayed. + + If True, the default is "split+stats" is used. + + progressbar_theme: Theme, optional + The theme to use for the progress bar. Defaults to the default theme. + """ + if progressbar_theme is None: + progressbar_theme = default_progress_theme + + match progressbar: + case True: + self.combined_progress = False + self.full_stats = True + show_progress = True + case False: + self.combined_progress = False + self.full_stats = True + show_progress = False + case "combined": + self.combined_progress = True + self.full_stats = False + show_progress = True + case "split": + self.combined_progress = False + self.full_stats = False + show_progress = True + case "combined+stats" | "stats+combined": + self.combined_progress = True + self.full_stats = True + show_progress = True + case "split+stats" | "stats+split": + self.combined_progress = False + self.full_stats = True + show_progress = True + case _: + raise ValueError( + "Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), " + "one of 'combined', 'split', 'split+stats', or 'combined+stats." + ) + + progress_columns, progress_stats = step_method._progressbar_config(chains) + + self._progress = self.create_progress_bar( + progress_columns, + progressbar=progressbar, + progressbar_theme=progressbar_theme, + ) + + self.progress_stats = progress_stats + self.update_stats = step_method._make_update_stats_function() + + self._show_progress = show_progress + self.divergences = 0 + self.completed_draws = 0 + self.total_draws = draws + tune + self.desc = "Sampling chain" + self.chains = chains + + self._tasks: list[Task] | None = None # type: ignore[annotation-unchecked] + + def __enter__(self): + self._initialize_tasks() + + return self._progress.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return self._progress.__exit__(exc_type, exc_val, exc_tb) + + def _initialize_tasks(self): + if self.combined_progress: + self.tasks = [ + self._progress.add_task( + self.desc.format(self), + completed=0, + draws=0, + total=self.total_draws * self.chains - 1, + chain_idx=0, + sampling_speed=0, + speed_unit="draws/s", + **{stat: value[0] for stat, value in self.progress_stats.items()}, + ) + ] + + else: + self.tasks = [ + self._progress.add_task( + self.desc.format(self), + completed=0, + draws=0, + total=self.total_draws - 1, + chain_idx=chain_idx, + sampling_speed=0, + speed_unit="draws/s", + **{stat: value[chain_idx] for stat, value in self.progress_stats.items()}, + ) + for chain_idx in range(self.chains) + ] + + def compute_draw_speed(self, chain_idx, draws): + elapsed = self._progress.tasks[chain_idx].elapsed + speed = draws / max(elapsed, 1e-6) + + if speed > 1 or speed == 0: + unit = "draws/s" + else: + unit = "s/draws" + speed = 1 / speed + + return speed, unit + + def update(self, chain_idx, is_last, draw, tuning, stats): + if not self._show_progress: + return + + self.completed_draws += 1 + if self.combined_progress: + draw = self.completed_draws + chain_idx = 0 + + speed, unit = self.compute_draw_speed(chain_idx, draw) + + if not tuning and stats and stats[0].get("diverging"): + self.divergences += 1 + + self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx) + more_updates = ( + {stat: value[chain_idx] for stat, value in self.progress_stats.items()} + if self.full_stats + else {} + ) + + self._progress.update( + self.tasks[chain_idx], + completed=draw, + draws=draw, + sampling_speed=speed, + speed_unit=unit, + **more_updates, + ) + + if is_last: + self._progress.update( + self.tasks[chain_idx], + draws=draw + 1 if not self.combined_progress else draw, + **more_updates, + refresh=True, + ) + + def create_progress_bar(self, step_columns, progressbar, progressbar_theme): + columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] + + if self.full_stats: + columns += step_columns + + columns += [ + TextColumn( + "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", + table_column=Column("Sampling Speed", ratio=1), + ), + TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), + TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), + ] + + return CustomProgress( + DivergenceBarColumn( + table_column=Column("Progress", ratio=2), + diverging_color="tab:red", + complete_style=Style.parse("rgb(31,119,180)"), # tab:blue + finished_style=Style.parse("rgb(31,119,180)"), # tab:blue + ), + *columns, + console=Console(theme=progressbar_theme), + disable=not progressbar, + include_headers=True, + ) + + +def compute_draw_speed(elapsed, draws): + speed = draws / max(elapsed, 1e-6) + + if speed > 1 or speed == 0: + unit = "draws/s" + else: + unit = "s/draws" + speed = 1 / speed + + return speed, unit + RandomGeneratorState = namedtuple("RandomGeneratorState", ["bit_generator_state", "seed_seq_state"]) From e0e751199319e68f376656e2477c1543606c49c7 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 25 Jan 2025 13:29:21 +0100 Subject: [PATCH 22/25] Ignore inner unused RNG inputs in `collect_default_updates` --- pymc/pytensorf.py | 53 +++++++++++++++++++++++++++-------------- tests/test_pytensorf.py | 14 +++++++++++ 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 1f390b1771..eda2064821 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -855,35 +855,44 @@ def find_default_update(clients, rng: Variable) -> None | Variable: # Root case, RNG is not used elsewhere if not rng_clients: - return rng + return None if len(rng_clients) > 1: # Multiple clients are techincally fine if they are used in identical operations # We check if the default_update of each client would be the same - update, *other_updates = ( + all_updates = [ find_default_update( # Pass version of clients that includes only one the RNG clients at a time clients | {rng: [rng_client]}, rng, ) for rng_client in rng_clients - ) - if all(equal_computations([update], [other_update]) for other_update in other_updates): - return update - - warnings.warn( - f"RNG Variable {rng} has multiple distinct clients {rng_clients}, " - f"likely due to an inconsistent random graph. " - f"No default update will be returned.", - UserWarning, - ) - return None + ] + updates = [update for update in all_updates if update is not None] + if not updates: + return None + if len(updates) == 1: + return updates[0] + else: + update, *other_updates = updates + if all( + equal_computations([update], [other_update]) for other_update in other_updates + ): + return update + + warnings.warn( + f"RNG Variable {rng} has multiple distinct clients {rng_clients}, " + f"likely due to an inconsistent random graph. " + f"No default update will be returned.", + UserWarning, + ) + return None [client, _] = rng_clients[0] # RNG is an output of the function, this is not a problem if isinstance(client.op, Output): - return rng + return None # RNG is used by another operator, which should output an update for the RNG if isinstance(client.op, RandomVariable): @@ -912,18 +921,26 @@ def find_default_update(clients, rng: Variable) -> None | Variable: ) elif isinstance(client.op, OpFromGraph): try: - next_rng = collect_default_updates_inner_fgraph(client)[rng] - except (ValueError, KeyError): + next_rng = collect_default_updates_inner_fgraph(client).get(rng) + if next_rng is None: + # OFG either does not make use of this RNG or inconsistent use that will have emitted a warning + return None + except ValueError as exc: raise ValueError( f"No update found for at least one RNG used in OpFromGraph Op {client.op}.\n" "You can use `pytensorf.collect_default_updates` and include those updates as outputs." - ) + ) from exc else: # We don't know how this RNG should be updated. The user should provide an update manually return None # Recurse until we find final update for RNG - return find_default_update(clients, next_rng) + nested_next_rng = find_default_update(clients, next_rng) + if nested_next_rng is None: + # There were no more uses of this next_rng + return next_rng + else: + return nested_next_rng if inputs is None: inputs = [] diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index 0ea18dabe3..c434f1a9c7 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -619,6 +619,20 @@ def test_op_from_graph_updates(self): fn = compile([], x, random_seed=1) assert not (set(fn()) & set(fn())) + def test_unused_ofg_rng(self): + rng = pytensor.shared(np.random.default_rng()) + next_rng, x = pt.random.normal(rng=rng).owner.outputs + ofg1 = OpFromGraph([rng], [next_rng, x]) + ofg2 = OpFromGraph([rng, x], [x + 1]) + + next_rng, x = ofg1(rng) + y = ofg2(rng, x) + + # In all these cases the update should be the same + assert collect_default_updates([x]) == {rng: next_rng} + assert collect_default_updates([y]) == {rng: next_rng} + assert collect_default_updates([x, y]) == {rng: next_rng} + def test_replace_rng_nodes(): rng = pytensor.shared(np.random.default_rng()) From 268e13bde3e4863370e3b418e37f63023c123b20 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 29 Jan 2025 18:08:30 +0100 Subject: [PATCH 23/25] remove the Futurewarning in the test --- tests/test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_data.py b/tests/test_data.py index 5d370d02c0..0906ab8434 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -375,7 +375,7 @@ def test_implicit_coords_series(self, seeded_test): pd = pytest.importorskip("pandas") ser_sales = pd.Series( data=np.random.randint(low=0, high=30, size=22), - index=pd.date_range(start="2020-05-01", periods=22, freq="24H", name="date"), + index=pd.date_range(start="2020-05-01", periods=22, freq="24h", name="date"), name="sales", ) with pm.Model() as pmodel: From 3b6e3516301e09bf5711870ca974f3d7a28e2417 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Sat, 1 Feb 2025 13:05:31 -0600 Subject: [PATCH 24/25] Probability distributions guide update (#7671) --- .../guides/Probability_Distributions.rst | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/docs/source/guides/Probability_Distributions.rst b/docs/source/guides/Probability_Distributions.rst index 81530746cd..ae5abdb922 100644 --- a/docs/source/guides/Probability_Distributions.rst +++ b/docs/source/guides/Probability_Distributions.rst @@ -29,21 +29,32 @@ A variable requires at least a ``name`` argument, and zero or more model paramet Probability distributions are all subclasses of ``Distribution``, which in turn has two major subclasses: ``Discrete`` and ``Continuous``. In terms of data types, a ``Continuous`` random variable is given whichever floating point type is defined by ``pytensor.config.floatX``, while ``Discrete`` variables are given ``int16`` types when ``pytensor.config.floatX`` is ``float32``, and ``int64`` otherwise. -All distributions in ``pm.distributions`` will have two important methods: ``random()`` and ``logp()`` with the following signatures: +All distributions in ``pm.distributions`` are associated with two key functions: + +1. ``logp(dist, value)`` - Calculates log-probability at given value +2. ``draw(dist, size=...)`` - Generates random samples + +For example, with a normal distribution: :: - class SomeDistribution(Continuous): + with pm.Model(): + x = pm.Normal('x', mu=0, sigma=1) + + # Calculate log-probability + log_prob = pm.logp(x, 0.5) + + # Generate samples + samples = pm.draw(x, size=100) - def random(self, point=None, size=None): - ... - return random_samples +Custom distributions using ``CustomDist`` should provide logp via the ``dist`` parameter: + +:: - def logp(self, value): - ... - return total_log_prob + def custom_logp(value, mu): + return -0.5 * (value - mu)**2 -PyMC expects the ``logp()`` method to return a log-probability evaluated at the passed ``value`` argument. This method is used internally by all of the inference methods to calculate the model log-probability that is used for fitting models. The ``random()`` method is used to simulate values from the variable, and is used internally for posterior predictive checks. + custom_dist = pm.CustomDist('custom', dist=custom_logp, mu=0) Custom distributions @@ -58,7 +69,7 @@ An exponential survival function, where :math:`c=0` denotes failure (or non-surv f(c, t) = \left\{ \begin{array}{l} \exp(-\lambda t), \text{if c=1} \\ \lambda \exp(-\lambda t), \text{if c=0} \end{array} \right. -Such a function can be implemented as a PyMC distribution by writing a function that specifies the log-probability, then passing that function as a keyword argument to the ``DensityDist`` function, which creates an instance of a PyMC distribution with the custom function as its log-probability. +Such a function can be implemented as a PyMC distribution by writing a function that specifies the log-probability, then passing that function as a keyword argument to the ``CustomDist`` function, which creates an instance of a PyMC distribution with the custom function as its log-probability. For the exponential survival function, this is: @@ -67,7 +78,7 @@ For the exponential survival function, this is: def logp(value, t, lam): return (value * log(lam) - lam * t).sum() - exp_surv = pm.DensityDist('exp_surv', t, lam, logp=logp, observed=failure) + exp_surv = pm.CustomDist('exp_surv', dist=logp, t=t, lam=lam, observed=failure) Similarly, if a random number generator is required, a function returning random numbers corresponding to the probability distribution can be passed as the ``random`` argument. @@ -98,10 +109,10 @@ This allows for probabilities to be calculated and random numbers to be drawn. :: - >>> y.logp(4).eval() + >>> pm.logp(y, 4).eval() array(-1.5843639373779297, dtype=float32) - >>> y.random(size=3) + >>> pm.draw(y, size=3) array([5, 4, 3]) From d7d2be275caa43570ba28e667d9f2a1186e743e4 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 6 Feb 2025 18:19:42 +0800 Subject: [PATCH 25/25] bump pytensor version dependency --- conda-envs/environment-dev.yml | 2 +- conda-envs/environment-docs.yml | 2 +- conda-envs/environment-jax.yml | 2 +- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-dev.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- requirements-dev.txt | 2 +- requirements.txt | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index de0572e0a2..546e1277d4 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -13,7 +13,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.26.2,<2.27 +- pytensor>=2.26.2,<2.28 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index c399a3e24a..fd035676c3 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -11,7 +11,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.26.2,<2.27 +- pytensor>=2.26.2,<2.28 - python-graphviz - rich>=13.7.1 - scipy>=1.4.1 diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index 39deb8a41a..6bde602133 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -21,7 +21,7 @@ dependencies: - numpyro>=0.8.0 - pandas>=0.24.0 - pip -- pytensor>=2.26.2,<2.27 +- pytensor>=2.26.2,<2.28 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 79c57a44c6..20f0478998 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -14,7 +14,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.26.2,<2.27 +- pytensor>=2.26.2,<2.28 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index bbcba9149f..503cf125b2 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -13,7 +13,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.26.2,<2.27 +- pytensor>=2.26.2,<2.28 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 399fab811b..5136d997a3 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -16,7 +16,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.26.2,<2.27 +- pytensor>=2.26.2,<2.28 - python-graphviz - networkx - rich>=13.7.1 diff --git a/requirements-dev.txt b/requirements-dev.txt index e7e3644aae..b0ef92f505 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,7 +17,7 @@ numpydoc pandas>=0.24.0 polyagamma pre-commit>=2.8.0 -pytensor>=2.26.2,<2.27 +pytensor>=2.26.2,<2.28 pytest-cov>=2.5 pytest>=3.0 rich>=13.7.1 diff --git a/requirements.txt b/requirements.txt index 28c9456b5e..c1f82bb4dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ cachetools>=4.2.1 cloudpickle numpy>=1.25.0 pandas>=0.24.0 -pytensor>=2.26.1,<2.27 +pytensor>=2.26.1,<2.28 rich>=13.7.1 scipy>=1.4.1 threadpoolctl>=3.1.0,<4.0.0