diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d9b4b000c7..efbb457f74 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -69,6 +69,7 @@ jobs: tests/distributions/test_shape_utils.py tests/distributions/test_mixture.py tests/test_testing.py + tests/test_progress_bar.py - | tests/distributions/test_continuous.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6148c18564..0d5dfec0b2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -49,7 +49,7 @@ repos: - --exclude=versioneer.py - --last-year-present - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.7 + rev: v0.11.13 hooks: - id: ruff args: [--fix, --show-fixes] diff --git a/README.rst b/README.rst index 4af51b652b..cec9cf4d56 100644 --- a/README.rst +++ b/README.rst @@ -254,6 +254,7 @@ Domain specific - `Exoplanet `__: a toolkit for modeling of transit and/or radial velocity observations of exoplanets and other astronomical time series. - `beat `__: Bayesian Earthquake Analysis Tool. - `CausalPy `__: A package focussing on causal inference in quasi-experimental settings. +- `PyMC-Marketing `__: Bayesian marketing toolbox for marketing mix modeling, customer lifetime value, and more. Please contact us if your software is not listed here. diff --git a/conda-envs/environment-alternative-backends.yml b/conda-envs/environment-alternative-backends.yml index 5030e7bacf..fcf78c6991 100644 --- a/conda-envs/environment-alternative-backends.yml +++ b/conda-envs/environment-alternative-backends.yml @@ -11,7 +11,7 @@ dependencies: - cloudpickle - zarr>=2.5.0,<3 - numba -- nutpie >= 0.13.4 +- nutpie >= 0.15.1 # Jaxlib version must not be greater than jax version! - blackjax>=1.2.2 - jax>=0.4.28 diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index e161b66e1f..a49e0568ae 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -40,6 +40,6 @@ dependencies: - mypy=1.15.0 - types-cachetools - pip: - - git+https://github.com/pymc-devs/pymc-sphinx-theme + - pymc-sphinx-theme>=0.16.0 - numdifftools>=0.9.40 - mcbackend>=0.4.0 diff --git a/docs/source/api/samplers.rst b/docs/source/api/samplers.rst index 5a7caa0c73..14a39c3761 100644 --- a/docs/source/api/samplers.rst +++ b/docs/source/api/samplers.rst @@ -14,6 +14,7 @@ This submodule contains functions for MCMC and forward sampling. 
sample_posterior_predictive draw compute_deterministics + vectorize_over_posterior init_nuts sampling.jax.sample_blackjax_nuts sampling.jax.sample_numpyro_nuts diff --git a/pymc/__init__.py b/pymc/__init__.py index 684feac11f..69a29c97e8 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -66,7 +66,7 @@ def __set_compiler_flags(): ) from pymc.model.core import * from pymc.model.transform.conditioning import do, observe -from pymc.model_graph import model_to_graphviz, model_to_networkx +from pymc.model_graph import model_to_graphviz, model_to_mermaid, model_to_networkx from pymc.plots import * from pymc.printing import * from pymc.pytensorf import * diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index f0f0eec963..71f08da826 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -39,8 +39,9 @@ import pymc from pymc.model import Model, modelcontext +from pymc.progress_bar import CustomProgress, default_progress_theme from pymc.pytensorf import PointFunc, extract_obs_data -from pymc.util import CustomProgress, default_progress_theme, get_default_varnames +from pymc.util import get_default_varnames if TYPE_CHECKING: from pymc.backends.base import MultiTrace diff --git a/pymc/data.py b/pymc/data.py index 4b3538a340..507f547e5b 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -14,7 +14,6 @@ import io import urllib.request -import warnings from collections.abc import Sequence from copy import copy @@ -40,10 +39,8 @@ from pymc.vartypes import isgenerator __all__ = [ - "ConstantData", "Data", "Minibatch", - "MutableData", "get_data", ] BASE_URL = "https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/data/{filename}" @@ -218,66 +215,6 @@ def determine_coords( return coords, new_dims -def ConstantData( - name: str, - value, - *, - dims: Sequence[str] | None = None, - coords: dict[str, Sequence | np.ndarray] | None = None, - infer_dims_and_coords=False, - **kwargs, -) -> TensorConstant: - """Alias for ``pm.Data``. - - Registers the ``value`` as a :class:`~pytensor.tensor.TensorConstant` with the model. - For more information, please reference :class:`pymc.Data`. - """ - warnings.warn( - "ConstantData is deprecated. All Data variables are now mutable. Use Data instead.", - FutureWarning, - ) - - var = Data( - name, - value, - dims=dims, - coords=coords, - infer_dims_and_coords=infer_dims_and_coords, - **kwargs, - ) - return cast(TensorConstant, var) - - -def MutableData( - name: str, - value, - *, - dims: Sequence[str] | None = None, - coords: dict[str, Sequence | np.ndarray] | None = None, - infer_dims_and_coords=False, - **kwargs, -) -> SharedVariable: - """Alias for ``pm.Data``. - - Registers the ``value`` as a :class:`~pytensor.compile.sharedvalue.SharedVariable` - with the model. For more information, please reference :class:`pymc.Data`. - """ - warnings.warn( - "MutableData is deprecated. All Data variables are now mutable. Use Data instead.", - FutureWarning, - ) - - var = Data( - name, - value, - dims=dims, - coords=coords, - infer_dims_and_coords=infer_dims_and_coords, - **kwargs, - ) - return cast(SharedVariable, var) - - def Data( name: str, value, @@ -285,7 +222,6 @@ def Data( dims: Sequence[str] | None = None, coords: dict[str, Sequence | np.ndarray] | None = None, infer_dims_and_coords=False, - mutable: bool | None = None, **kwargs, ) -> SharedVariable | TensorConstant: """Create a data container that registers a data variable with the model. 
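
An orientation sketch (not part of the patch), assuming a toy model: with the
`ConstantData`/`MutableData` aliases and the `mutable` kwarg removed in this patch,
`pm.Data` is the single container (always a shared, mutable variable), and values are
swapped in with `pm.set_data`.

.. code-block:: python

    import numpy as np
    import pymc as pm

    with pm.Model() as model:
        x = pm.Data("x", np.array([1.0, 2.0, 3.0]))  # SharedVariable, mutable by default
        y = pm.Normal("y", mu=x, sigma=1.0)  # shape follows x by broadcasting

    with model:
        # Swap in new values (resizing is allowed) instead of rebuilding the model
        pm.set_data({"x": np.array([4.0, 5.0])})
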
@@ -380,11 +316,6 @@ def Data(
             "Pass them directly to `observed` if you want to trigger auto-imputation"
         )
 
-    if mutable is not None:
-        warnings.warn(
-            "Data is now always mutable. Specifying the `mutable` kwarg will raise an error in a future release",
-            FutureWarning,
-        )
     x = pytensor.shared(arr, name, **kwargs)
 
     if isinstance(dims, str):
diff --git a/pymc/distributions/custom.py b/pymc/distributions/custom.py
index 86aba12043..476d5bc41f 100644
--- a/pymc/distributions/custom.py
+++ b/pymc/distributions/custom.py
@@ -13,7 +13,6 @@
 #   limitations under the License.
 import functools
 import re
-import warnings
 
 from collections.abc import Callable, Sequence
 
@@ -715,13 +714,6 @@ def __new__(
         )
         dist_params = cls.parse_dist_params(dist_params)
         cls.check_valid_dist_random(dist, random, dist_params)
-        moment = kwargs.pop("moment", None)
-        if moment is not None:
-            warnings.warn(
-                "`moment` argument is deprecated. Use `support_point` instead.",
-                FutureWarning,
-            )
-            support_point = moment
         if dist is not None:
             kwargs.setdefault("class_name", f"CustomDist_{name}")
             return _CustomSymbolicDist(
diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py
index 7587edceac..b1d188d461 100644
--- a/pymc/distributions/discrete.py
+++ b/pymc/distributions/discrete.py
@@ -1155,29 +1155,58 @@ def support_point(rv, size, p):
             mode = pt.full(size, mode)
         return mode
 
-    def logp(value, p):
-        k = pt.shape(p)[-1]
-        value_clip = pt.clip(value, 0, k - 1)
+    @staticmethod
+    def _safe_index_value_p(value, p):
+        # Find the probability of the given value by indexing into p,
+        # after handling broadcasting and invalid values.
 
         # In the standard case p has one more dimension than value
         dim_diff = p.type.ndim - value.type.ndim
         if dim_diff > 1:
             # p broadcasts implicitly beyond value
-            value_clip = pt.shape_padleft(value_clip, dim_diff - 1)
+            value = pt.shape_padleft(value, dim_diff - 1)
         elif dim_diff < 1:
             # value broadcasts implicitly beyond p
             p = pt.shape_padleft(p, 1 - dim_diff)
 
-        a = pt.log(pt.take_along_axis(p, value_clip[..., None], axis=-1).squeeze(-1))
+        k = pt.shape(p)[-1]
+        value_clip = pt.clip(value, 0, k - 1).astype(int)
+        return value, pt.log(pt.take_along_axis(p, value_clip[..., None], axis=-1).squeeze(-1))
 
-        res = pt.switch(
+    def logp(value, p):
+        k = pt.shape(p)[-1]
+        value, safe_value_p = Categorical._safe_index_value_p(value, p)
+
+        value_p = pt.switch(
             pt.or_(pt.lt(value, 0), pt.gt(value, k - 1)),
             -np.inf,
-            a,
+            safe_value_p,
         )
 
         return check_parameters(
-            res,
+            value_p,
+            0 <= p,
+            p <= 1,
+            pt.isclose(pt.sum(p, axis=-1), 1),
+            msg="0 <= p <= 1, sum(p) = 1",
+        )
+
+    def logcdf(value, p):
+        k = pt.shape(p)[-1]
+        value, safe_value_p = Categorical._safe_index_value_p(value, p.cumsum(-1))
+
+        value_p = pt.switch(
+            pt.lt(value, 0),
+            -np.inf,
+            pt.switch(
+                pt.gt(value, k - 1),
+                0,
+                safe_value_p,
+            ),
+        )
+
+        return check_parameters(
+            value_p,
             0 <= p,
             p <= 1,
             pt.isclose(pt.sum(p, axis=-1), 1),
diff --git a/pymc/distributions/dist_math.py b/pymc/distributions/dist_math.py
index 3f675406f4..55efcb3b3c 100644
--- a/pymc/distributions/dist_math.py
+++ b/pymc/distributions/dist_math.py
@@ -18,8 +18,6 @@
 
 @author: johnsalvatier
 """
-import warnings
-
 from collections.abc import Iterable
 from functools import partial
 
@@ -419,12 +417,3 @@ def log_i0(x):
                 + 11025.0 / (98304.0 * x**4.0)
             ),
         )
-
-
-def incomplete_beta(a, b, value):
-    warnings.warn(
-        "incomplete_beta has been deprecated. 
Use pytensor.tensor.betainc instead.", - FutureWarning, - stacklevel=2, - ) - return pt.betainc(a, b, value) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 3f080b7e5b..f8712c51e5 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -60,7 +60,7 @@ convert_observed_data, floatX, ) -from pymc.util import UNSET, _add_future_warning_tag +from pymc.util import UNSET from pymc.vartypes import continuous_types, string_types __all__ = [ @@ -571,10 +571,7 @@ def dist( ndim_supp = cls.rv_op(*dist_params, **kwargs).owner.op.ndim_supp create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp) - rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs) - - _add_future_warning_tag(rv_out) - return rv_out + return cls.rv_op(*dist_params, size=create_size, **kwargs) @node_rewriter([SymbolicRandomVariable]) diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index 8afc5b2f1e..3cfda1a837 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -56,10 +56,8 @@ ] -class MarginalMixtureRV(SymbolicRandomVariable): - """A placeholder used to specify a distribution for a mixture sub-graph.""" - - _print_name = ("MarginalMixture", "\\operatorname{MarginalMixture}") +class _BaseMixtureRV(SymbolicRandomVariable): + """Base class SymbolicRandomVariable for Mixture and Hurdle RVs.""" @classmethod def rv_op(cls, weights, *components, size=None): @@ -139,7 +137,7 @@ def rv_op(cls, weights, *components, size=None): comps_s = ",".join(f"({s})" for _ in components) extended_signature = f"[rng],(w),{comps_s}->[rng],({s})" - return MarginalMixtureRV( + return cls( inputs=[mix_indexes_rng, weights, *components], outputs=[mix_indexes_rng_next, mix_out], extended_signature=extended_signature, @@ -161,117 +159,8 @@ def update(self, node: Apply): return {node.inputs[0]: node.outputs[0]} -class Mixture(Distribution): - R""" - Mixture distribution. - - Often used to model subpopulation heterogeneity - - .. math:: f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i) - - ======== ============================================ - Support :math:`\cup_{i = 1}^n \textrm{support}(f_i)` - Mean :math:`\sum_{i = 1}^n w_i \mu_i` - ======== ============================================ - - Parameters - ---------- - w : tensor_like of float - w >= 0 and w <= 1 - the mixture weights - comp_dists : iterable of unnamed distributions or single batched distribution - Distributions should be created via the `.dist()` API. If a single distribution - is passed, the last size dimension (not shape) determines the number of mixture - components (e.g. `pm.Poisson.dist(..., size=components)`) - :math:`f_1, \ldots, f_n` - - .. warning:: comp_dists will be cloned, rendering them independent of the ones passed as input. - - - Examples - -------- - .. 
code-block:: python - - # Mixture of 2 Poisson variables - with pm.Model() as model: - w = pm.Dirichlet("w", a=np.array([1, 1])) # 2 mixture weights - - lam1 = pm.Exponential("lam1", lam=1) - lam2 = pm.Exponential("lam2", lam=1) - - # As we just need the logp, rather than add a RV to the model, we need to call `.dist()` - # These two forms are equivalent, but the second benefits from vectorization - components = [ - pm.Poisson.dist(mu=lam1), - pm.Poisson.dist(mu=lam2), - ] - # `shape=(2,)` indicates 2 mixture components - components = pm.Poisson.dist(mu=pm.math.stack([lam1, lam2]), shape=(2,)) - - like = pm.Mixture("like", w=w, comp_dists=components, observed=data) - - - .. code-block:: python - - # Mixture of Normal and StudentT variables - with pm.Model() as model: - w = pm.Dirichlet("w", a=np.array([1, 1])) # 2 mixture weights - - mu = pm.Normal("mu", 0, 1) - - components = [ - pm.Normal.dist(mu=mu, sigma=1), - pm.StudentT.dist(nu=4, mu=mu, sigma=1), - ] - - like = pm.Mixture("like", w=w, comp_dists=components, observed=data) - - - .. code-block:: python - - # Mixture of (5 x 3) Normal variables - with pm.Model() as model: - # w is a stack of 5 independent size 3 weight vectors - # If shape was `(3,)`, the weights would be shared across the 5 replication dimensions - w = pm.Dirichlet("w", a=np.ones(3), shape=(5, 3)) - - # Each of the 3 mixture components has an independent mean - mu = pm.Normal("mu", mu=np.arange(3), sigma=1, shape=3) - - # These two forms are equivalent, but the second benefits from vectorization - components = [ - pm.Normal.dist(mu=mu[0], sigma=1, shape=(5,)), - pm.Normal.dist(mu=mu[1], sigma=1, shape=(5,)), - pm.Normal.dist(mu=mu[2], sigma=1, shape=(5,)), - ] - components = pm.Normal.dist(mu=mu, sigma=1, shape=(5, 3)) - - # The mixture is an array of 5 elements - # Each element can be thought of as an independent scalar mixture of 3 - # components with different means - like = pm.Mixture("like", w=w, comp_dists=components, observed=data) - - - .. 
code-block:: python - - # Mixture of 2 Dirichlet variables - with pm.Model() as model: - w = pm.Dirichlet("w", a=np.ones(2)) # 2 mixture weights - - # These two forms are equivalent, but the second benefits from vectorization - components = [ - pm.Dirichlet.dist(a=[1, 10, 100], shape=(3,)), - pm.Dirichlet.dist(a=[100, 10, 1], shape=(3,)), - ] - components = pm.Dirichlet.dist(a=[[1, 10, 100], [100, 10, 1]], shape=(2, 3)) - - # The mixture is an array of 3 elements - # Each element comes from only one of the two core Dirichlet components - like = pm.Mixture("like", w=w, comp_dists=components, observed=data) - """ - - rv_type = MarginalMixtureRV - rv_op = MarginalMixtureRV.rv_op +class _BaseMixtureDistribution(Distribution): + """Base class distribution for Mixture and Hurdle distributions.""" @classmethod def dist(cls, w, comp_dists, **kwargs): @@ -298,8 +187,6 @@ def dist(cls, w, comp_dists, **kwargs): # Check that components are not associated with a registered variable in the model components_ndim_supp = set() for dist in comp_dists: - # TODO: Allow these to not be a RandomVariable as long as we can call `ndim_supp` on them - # and resize them if not isinstance(dist, TensorVariable) or not isinstance( dist.owner.op, RandomVariable | SymbolicRandomVariable ): @@ -318,8 +205,8 @@ def dist(cls, w, comp_dists, **kwargs): return super().dist([w, *comp_dists], **kwargs) -@_change_dist_size.register(MarginalMixtureRV) -def change_marginal_mixture_size(op, dist, new_size, expand=False): +@_change_dist_size.register(_BaseMixtureRV) +def change_mixture_size(op, dist, new_size, expand=False): rng, weights, *components = dist.owner.inputs if expand: @@ -333,39 +220,32 @@ def change_marginal_mixture_size(op, dist, new_size, expand=False): old_size = components[0].shape[:size_dims] new_size = tuple(new_size) + tuple(old_size) - return Mixture.rv_op(weights, *components, size=new_size) + return op.rv_op(weights, *components, size=new_size) -@_logprob.register(MarginalMixtureRV) -def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs): - (value,) = values +@_support_point.register(_BaseMixtureRV) +def mixture_support_point(op, rv, rng, weights, *components): + ndim_supp = components[0].owner.op.ndim_supp + weights = pt.shape_padright(weights, ndim_supp) + mix_axis = -ndim_supp - 1 - # single component if len(components) == 1: - # Need to broadcast value across mixture axis - mix_axis = -components[0].owner.op.ndim_supp - 1 - components_logp = logp(components[0], pt.expand_dims(value, mix_axis)) + support_point_components = support_point(components[0]) + else: - components_logp = pt.stack( - [logp(component, value) for component in components], - axis=-1, + support_point_components = pt.stack( + [support_point(component) for component in components], + axis=mix_axis, ) - mix_logp = pt.logsumexp(pt.log(weights) + components_logp, axis=-1) - - mix_logp = check_parameters( - mix_logp, - 0 <= weights, - weights <= 1, - pt.isclose(pt.sum(weights, axis=-1), 1), - msg="0 <= weights <= 1, sum(weights) == 1", - ) - - return mix_logp + mix_support_point = pt.sum(weights * support_point_components, axis=mix_axis) + if components[0].dtype in discrete_types: + mix_support_point = pt.round(mix_support_point) + return mix_support_point -@_logcdf.register(MarginalMixtureRV) -def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs): +@_logcdf.register(_BaseMixtureRV) +def mixture_logcdf(op, value, rng, weights, *components, **kwargs): # single component if len(components) == 1: # 
Need to broadcast value across mixture axis @@ -390,27 +270,6 @@ def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs): return mix_logcdf -@_support_point.register(MarginalMixtureRV) -def marginal_mixture_support_point(op, rv, rng, weights, *components): - ndim_supp = components[0].owner.op.ndim_supp - weights = pt.shape_padright(weights, ndim_supp) - mix_axis = -ndim_supp - 1 - - if len(components) == 1: - support_point_components = support_point(components[0]) - - else: - support_point_components = pt.stack( - [support_point(component) for component in components], - axis=mix_axis, - ) - - mix_support_point = pt.sum(weights * support_point_components, axis=mix_axis) - if components[0].dtype in discrete_types: - mix_support_point = pt.round(mix_support_point) - return mix_support_point - - # List of transforms that can be used by Mixture, either because they do not require # special handling or because we have custom logic to enable them. If new default # transforms are implemented, this list and function should be updated @@ -431,8 +290,8 @@ class MixtureTransformWarning(UserWarning): pass -@_default_transform.register(MarginalMixtureRV) -def marginal_mixture_default_transform(op, rv): +@_default_transform.register(_BaseMixtureRV) +def mixture_default_transform(op, rv): def transform_warning(): warnings.warn( f"No safe default transform found for Mixture distribution {rv}. This can " @@ -491,6 +350,151 @@ def mixture_args_fn(rng, weights, *components): return default_transform +class MixtureRV(_BaseMixtureRV): + _print_name = ("Mixture", "\\operatorname{Mixture}") + + +class Mixture(_BaseMixtureDistribution): + R""" + Mixture distribution. + + Often used to model subpopulation heterogeneity + + .. math:: f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i) + + ======== ============================================ + Support :math:`\cup_{i = 1}^n \textrm{support}(f_i)` + Mean :math:`\sum_{i = 1}^n w_i \mu_i` + ======== ============================================ + + Parameters + ---------- + w : tensor_like of float + w >= 0 and w <= 1 + the mixture weights + comp_dists : iterable of unnamed distributions or single batched distribution + Distributions should be created via the `.dist()` API. If a single distribution + is passed, the last size dimension (not shape) determines the number of mixture + components (e.g. `pm.Poisson.dist(..., size=components)`) + :math:`f_1, \ldots, f_n` + + .. warning:: comp_dists will be cloned, rendering them independent of the ones passed as input. + + + Examples + -------- + .. code-block:: python + + # Mixture of 2 Poisson variables + with pm.Model() as model: + w = pm.Dirichlet("w", a=np.array([1, 1])) # 2 mixture weights + + lam1 = pm.Exponential("lam1", lam=1) + lam2 = pm.Exponential("lam2", lam=1) + + # As we just need the logp, rather than add a RV to the model, we need to call `.dist()` + # These two forms are equivalent, but the second benefits from vectorization + components = [ + pm.Poisson.dist(mu=lam1), + pm.Poisson.dist(mu=lam2), + ] + # `shape=(2,)` indicates 2 mixture components + components = pm.Poisson.dist(mu=pm.math.stack([lam1, lam2]), shape=(2,)) + + like = pm.Mixture("like", w=w, comp_dists=components, observed=data) + + + .. 
code-block:: python + + # Mixture of Normal and StudentT variables + with pm.Model() as model: + w = pm.Dirichlet("w", a=np.array([1, 1])) # 2 mixture weights + + mu = pm.Normal("mu", 0, 1) + + components = [ + pm.Normal.dist(mu=mu, sigma=1), + pm.StudentT.dist(nu=4, mu=mu, sigma=1), + ] + + like = pm.Mixture("like", w=w, comp_dists=components, observed=data) + + + .. code-block:: python + + # Mixture of (5 x 3) Normal variables + with pm.Model() as model: + # w is a stack of 5 independent size 3 weight vectors + # If shape was `(3,)`, the weights would be shared across the 5 replication dimensions + w = pm.Dirichlet("w", a=np.ones(3), shape=(5, 3)) + + # Each of the 3 mixture components has an independent mean + mu = pm.Normal("mu", mu=np.arange(3), sigma=1, shape=3) + + # These two forms are equivalent, but the second benefits from vectorization + components = [ + pm.Normal.dist(mu=mu[0], sigma=1, shape=(5,)), + pm.Normal.dist(mu=mu[1], sigma=1, shape=(5,)), + pm.Normal.dist(mu=mu[2], sigma=1, shape=(5,)), + ] + components = pm.Normal.dist(mu=mu, sigma=1, shape=(5, 3)) + + # The mixture is an array of 5 elements + # Each element can be thought of as an independent scalar mixture of 3 + # components with different means + like = pm.Mixture("like", w=w, comp_dists=components, observed=data) + + + .. code-block:: python + + # Mixture of 2 Dirichlet variables + with pm.Model() as model: + w = pm.Dirichlet("w", a=np.ones(2)) # 2 mixture weights + + # These two forms are equivalent, but the second benefits from vectorization + components = [ + pm.Dirichlet.dist(a=[1, 10, 100], shape=(3,)), + pm.Dirichlet.dist(a=[100, 10, 1], shape=(3,)), + ] + components = pm.Dirichlet.dist(a=[[1, 10, 100], [100, 10, 1]], shape=(2, 3)) + + # The mixture is an array of 3 elements + # Each element comes from only one of the two core Dirichlet components + like = pm.Mixture("like", w=w, comp_dists=components, observed=data) + """ + + rv_type = MixtureRV + rv_op = MixtureRV.rv_op + + +@_logprob.register(MixtureRV) +def mixture_logprob(op, values, rng, weights, *components, **kwargs): + (value,) = values + + # single component + if len(components) == 1: + # Need to broadcast value across mixture axis + mix_axis = -components[0].owner.op.ndim_supp - 1 + components_logp = logp(components[0], pt.expand_dims(value, mix_axis)) + else: + components_logp = pt.stack( + [logp(component, value) for component in components], + axis=-1, + ) + + mix_logp = pt.logsumexp(pt.log(weights) + components_logp, axis=-1) + + mix_logp = check_parameters( + mix_logp, + 0 <= weights, + weights <= 1, + pt.isclose(pt.sum(weights, axis=-1), 1), + msg="0 <= weights <= 1, sum(weights) == 1", + ) + + return mix_logp + + class NormalMixture: R""" Normal mixture distribution. @@ -799,34 +803,65 @@ def dist(cls, psi, mu=None, alpha=None, p=None, n=None, **kwargs): ) -def _hurdle_mixture(*, name, nonzero_p, nonzero_dist, dtype, max_n_steps=10_000, **kwargs): - """Create a hurdle mixtures (helper function). +class _HurdleRV(_BaseMixtureRV): + _print_name = ("Hurdle", "\\operatorname{Hurdle}") - If name is `None`, this function returns an unregistered variable - In hurdle models, the zeros come from a completely different process than the rest of the data. - In other words, the zeros are not inflated, they come from a different process. 
- """ - if dtype == "float": - zero = 0.0 - lower = np.finfo(pytensor.config.floatX).eps - elif dtype == "int": - zero = 0 - lower = 1 - else: - raise ValueError("dtype must be 'float' or 'int'") +class _Hurdle(_BaseMixtureDistribution): + rv_type = _HurdleRV + rv_op = _HurdleRV.rv_op - nonzero_p = pt.as_tensor_variable(nonzero_p) - weights = pt.stack([1 - nonzero_p, nonzero_p], axis=-1) - comp_dists = [ - DiracDelta.dist(zero), - Truncated.dist(nonzero_dist, lower=lower, max_n_steps=max_n_steps), - ] + @classmethod + def _create(cls, *, name, nonzero_p, nonzero_dist, max_n_steps=10_000, **kwargs): + """Create a hurdle mixture (helper function). + + If name is `None`, this function returns an unregistered variable + + In hurdle models, the zeros come from a completely different process than the rest of the data. + In other words, the zeros are not inflated, they come from a different process. + + Note: this is invalid for discrete nonzero distributions with mass below 0, as we simply truncate[lower=1]. + """ + dtype = nonzero_dist.dtype + + if dtype.startswith("int"): + # Need to truncate the distribution to exclude zero. + # Continuous distributions have "zero" mass at zero (and anywhere else), so can be used as is. + nonzero_dist = Truncated.dist(nonzero_dist, lower=1, max_n_steps=max_n_steps) + elif not dtype.startswith("float"): + raise ValueError(f"nonzero_dist dtype must be 'float' or 'int', got {dtype}") + + nonzero_p = pt.as_tensor_variable(nonzero_p) + weights = pt.stack([1 - nonzero_p, nonzero_p], axis=-1) + comp_dists = [ + DiracDelta.dist(np.asarray(0, dtype=dtype)), + nonzero_dist, + ] + + if name is not None: + return cls(name, weights, comp_dists, **kwargs) + else: + return cls.dist(weights, comp_dists, **kwargs) - if name is not None: - return Mixture(name, weights, comp_dists, **kwargs) - else: - return Mixture.dist(weights, comp_dists, **kwargs) + +@_logprob.register(_HurdleRV) +def marginal_hurdle_logprob(op, values, rng, weights, zero_dist, dist, **kwargs): + (value,) = values + + psi = weights[..., 1] + + hurdle_logp = pt.where( + pt.eq(value, 0), + pt.log(1 - psi), + pt.log(psi) + logp(dist, value), + ) + + return check_parameters( + hurdle_logp, + 0 <= psi, + psi <= 1, + msg="0 <= psi <= 1", + ) class HurdlePoisson: @@ -864,14 +899,20 @@ class HurdlePoisson: """ def __new__(cls, name, psi, mu, **kwargs): - return _hurdle_mixture( - name=name, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=mu), dtype="int", **kwargs + return _Hurdle._create( + name=name, + nonzero_p=psi, + nonzero_dist=Poisson.dist(mu=mu), + **kwargs, ) @classmethod def dist(cls, psi, mu, **kwargs): - return _hurdle_mixture( - name=None, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=mu), dtype="int", **kwargs + return _Hurdle._create( + name=None, + nonzero_p=psi, + nonzero_dist=Poisson.dist(mu=mu), + **kwargs, ) @@ -914,21 +955,19 @@ class HurdleNegativeBinomial: """ def __new__(cls, name, psi, mu=None, alpha=None, p=None, n=None, **kwargs): - return _hurdle_mixture( + return _Hurdle._create( name=name, nonzero_p=psi, nonzero_dist=NegativeBinomial.dist(mu=mu, alpha=alpha, p=p, n=n), - dtype="int", **kwargs, ) @classmethod def dist(cls, psi, mu=None, alpha=None, p=None, n=None, **kwargs): - return _hurdle_mixture( + return _Hurdle._create( name=None, nonzero_p=psi, nonzero_dist=NegativeBinomial.dist(mu=mu, alpha=alpha, p=p, n=n), - dtype="int", **kwargs, ) @@ -963,24 +1002,28 @@ class HurdleGamma: Alternative shape parameter (mu > 0). 
sigma : tensor_like of float, optional Alternative scale parameter (sigma > 0). + + .. warning:: + HurdleGamma distributions cannot be sampled correctly with MCMC methods, + as this would require a specialized step sampler. They are intended to be used as + observed variables, and/or sampled exclusively with forward methods like + `sample_prior_predictive` and `sample_posterior_predictive`. """ def __new__(cls, name, psi, alpha=None, beta=None, mu=None, sigma=None, **kwargs): - return _hurdle_mixture( + return _Hurdle._create( name=name, nonzero_p=psi, nonzero_dist=Gamma.dist(alpha=alpha, beta=beta, mu=mu, sigma=sigma), - dtype="float", **kwargs, ) @classmethod def dist(cls, psi, alpha=None, beta=None, mu=None, sigma=None, **kwargs): - return _hurdle_mixture( + return _Hurdle._create( name=None, nonzero_p=psi, nonzero_dist=Gamma.dist(alpha=alpha, beta=beta, mu=mu, sigma=sigma), - dtype="float", **kwargs, ) @@ -1015,23 +1058,27 @@ class HurdleLogNormal: tau : tensor_like of float, optional Scale parameter (tau > 0). (only required if sigma is not specified). Defaults to 1. + + .. warning:: + HurdleLogNormal distributions cannot be sampled correctly with MCMC methods, + as this would require a specialized step sampler. They are intended to be used as + observed variables, and/or sampled exclusively with forward methods like + `sample_prior_predictive` and `sample_posterior_predictive`. """ def __new__(cls, name, psi, mu=0, sigma=None, tau=None, **kwargs): - return _hurdle_mixture( + return _Hurdle._create( name=name, nonzero_p=psi, nonzero_dist=LogNormal.dist(mu=mu, sigma=sigma, tau=tau), - dtype="float", **kwargs, ) @classmethod def dist(cls, psi, mu=0, sigma=None, tau=None, **kwargs): - return _hurdle_mixture( + return _Hurdle._create( name=None, nonzero_p=psi, nonzero_dist=LogNormal.dist(mu=mu, sigma=sigma, tau=tau), - dtype="float", **kwargs, ) diff --git a/pymc/distributions/moments/means.py b/pymc/distributions/moments/means.py index f183ace5db..0e3129935e 100644 --- a/pymc/distributions/moments/means.py +++ b/pymc/distributions/moments/means.py @@ -70,7 +70,7 @@ ) from pymc.distributions.discrete import DiscreteUniformRV from pymc.distributions.distribution import DiracDeltaRV -from pymc.distributions.mixture import MarginalMixtureRV +from pymc.distributions.mixture import MixtureRV from pymc.distributions.multivariate import ( CARRV, DirichletMultinomialRV, @@ -300,7 +300,7 @@ def lognormal_mean(op, rv, rng, size, mu, sigma): return maybe_resize(pt.exp(mu + 0.5 * sigma**2), size) -@_mean.register(MarginalMixtureRV) +@_mean.register(MixtureRV) def marginal_mixture_mean(op, rv, rng, weights, *components): ndim_supp = components[0].owner.op.ndim_supp weights = pt.shape_padright(weights, ndim_supp) diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 6f54aba2d1..8ef378b76b 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -43,7 +43,6 @@ from pymc.exceptions import ShapeError from pymc.pytensorf import PotentialShapeType -from pymc.util import _add_future_warning_tag def to_tuple(shape): @@ -264,7 +263,6 @@ def change_dist_size( op = dist.owner.op new_dist = _change_dist_size(op, dist, new_size=new_size, expand=expand) - _add_future_warning_tag(new_dist) new_dist.name = dist.name for k, v in dist.tag.__dict__.items(): diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 87347add1e..f18e1f4c25 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ 
-286,13 +286,6 @@ class GaussianRandomWalk(PredefinedRandomWalk): @classmethod def get_dists(cls, mu=0.0, sigma=1.0, *, init_dist=None, **kwargs): - if "init" in kwargs: - warnings.warn( - "init parameter is now called init_dist. Using init will raise an error in a future release.", - FutureWarning, - ) - init_dist = kwargs.pop("init") - if init_dist is None: warnings.warn( "Initial distribution not specified, defaulting to `Normal.dist(0, 100)`." @@ -340,14 +333,6 @@ class MvGaussianRandomWalk(PredefinedRandomWalk): @classmethod def get_dists(cls, mu, *, cov=None, tau=None, chol=None, lower=True, init_dist=None, **kwargs): - if "init" in kwargs: - warnings.warn( - "init parameter is now called init_dist. Using init will raise an error " - "in a future release.", - FutureWarning, - ) - init_dist = kwargs.pop("init") - if init_dist is None: warnings.warn( "Initial distribution not specified, defaulting to `MvNormal.dist(0, I*100)`." @@ -396,14 +381,6 @@ class MvStudentTRandomWalk(PredefinedRandomWalk): def get_dists( cls, *, nu, mu, scale=None, tau=None, chol=None, lower=True, init_dist=None, **kwargs ): - if "init" in kwargs: - warnings.warn( - "init parameter is now called init_dist. Using init will raise an error " - "in a future release.", - FutureWarning, - ) - init_dist = kwargs.pop("init") - if init_dist is None: warnings.warn( "Initial distribution not specified, defaulting to `MvNormal.dist(0, I*100)`." @@ -588,13 +565,6 @@ def dist( sigma = pt.as_tensor_variable(sigma) rhos = pt.atleast_1d(pt.as_tensor_variable(rho)) - if "init" in kwargs: - warnings.warn( - "init parameter is now called init_dist. Using init will raise an error in a future release.", - FutureWarning, - ) - init_dist = kwargs.pop("init") - ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order) steps = get_support_shape_1d( support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=ar_order diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index ebdaf3c3e1..b1e02fd1f9 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import warnings from functools import singledispatch @@ -48,22 +47,6 @@ ] -def __getattr__(name): - if name in ("univariate_ordered", "multivariate_ordered"): - warnings.warn(f"{name} has been deprecated, use ordered instead.", FutureWarning) - return ordered - - if name in ("univariate_sum_to_1", "multivariate_sum_to_1"): - warnings.warn(f"{name} has been deprecated, use sum_to_1 instead.", FutureWarning) - return sum_to_1 - - if name == "RVTransform": - warnings.warn("RVTransform has been renamed to Transform", FutureWarning) - return Transform - - raise AttributeError(f"module {__name__} has no attribute {name}") - - @singledispatch def _default_transform(op: Op, rv: TensorVariable): """Return default transform for a given Distribution `Op`.""" @@ -100,9 +83,7 @@ class Ordered(Transform): name = "ordered" - def __init__(self, ndim_supp=None, positive=False, ascending=True): - if ndim_supp is not None: - warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning) + def __init__(self, positive=False, ascending=True): self.positive = positive self.ascending = ascending @@ -142,10 +123,6 @@ class SumTo1(Transform): name = "sumto1" - def __init__(self, ndim_supp=None): - if ndim_supp is not None: - warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning) - def backward(self, value, *inputs): remaining = 1 - pt.sum(value[..., :], axis=-1, keepdims=True) return pt.concatenate([value[..., :], remaining], axis=-1) diff --git a/pymc/initial_point.py b/pymc/initial_point.py index c276a5c496..df9419c744 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -240,12 +240,6 @@ def make_initial_point_expression( strategy = default_strategy if isinstance(strategy, str): - if strategy == "moment": - strategy = "support_point" - warnings.warn( - "The 'moment' strategy is deprecated. Use 'support_point' instead.", - FutureWarning, - ) if strategy == "support_point": try: value = support_point(variable) diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index 5c7f28e661..4b8808a3bd 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -48,17 +48,6 @@ from pytensor.tensor.random.op import RandomVariable -def __getattr__(name): - if name == "MeasurableVariable": - warnings.warn( - f"{name} has been deprecated in favor of MeasurableOp. Importing will fail in a future release.", - FutureWarning, - ) - return MeasurableOp - - raise AttributeError(f"module {__name__} has no attribute {name}") - - @singledispatch def _logprob( op: Op, @@ -236,13 +225,16 @@ class ValuedRV(Op): and breaking the dependency of `b` on `a`. The new nodes isolate the graphs between conditioning points. 
""" + view_map = {0: [0]} + def make_node(self, rv, value): assert isinstance(rv, Variable) assert isinstance(value, Variable) return Apply(self, [rv, value], [rv.type(name=rv.name)]) def perform(self, node, inputs, out): - raise NotImplementedError("ValuedVar should not be present in the final graph!") + warnings.warn("ValuedVar should not be present in the final graph!") + out[0][0] = inputs[0] def infer_shape(self, fgraph, node, input_shapes): return [input_shapes[0]] diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index 6fd4a5489e..9a856fc5c2 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -98,23 +98,7 @@ def _warn_rvs_in_inferred_graph(graph: TensorVariable | Sequence[TensorVariable] ) -def _deprecate_warn_missing_rvs(warn_rvs, kwargs): - if "warn_missing_rvs" in kwargs: - warnings.warn( - "Argument `warn_missing_rvs` was renamed to `warn_rvs` and will be removed in a future release", - FutureWarning, - ) - if warn_rvs is None: - warn_rvs = kwargs.pop("warn_missing_rvs") - else: - raise ValueError("Can't set both warn_rvs and warn_missing_rvs") - else: - if warn_rvs is None: - warn_rvs = True - return warn_rvs, kwargs - - -def logp(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> TensorVariable: +def logp(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> TensorVariable: """Create a graph for the log-probability of a random variable. Parameters @@ -200,8 +184,6 @@ def normal_logp(value, mu, sigma): pm.CustomDist("x", mu, sigma, logp=normal_logp) """ - warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs) - value = pt.as_tensor_variable(value, dtype=rv.dtype) try: return _logprob_helper(rv, value, **kwargs) @@ -216,7 +198,7 @@ def normal_logp(value, mu, sigma): return expr -def logcdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> TensorVariable: +def logcdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> TensorVariable: """Create a graph for the log-CDF of a random variable. Parameters @@ -301,7 +283,6 @@ def normal_logcdf(value, mu, sigma): pm.CustomDist("x", mu, sigma, logcdf=normal_logcdf) """ - warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs) value = pt.as_tensor_variable(value, dtype=rv.dtype) try: return _logcdf_helper(rv, value, **kwargs) @@ -317,7 +298,7 @@ def normal_logcdf(value, mu, sigma): return expr -def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> TensorVariable: +def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> TensorVariable: """Create a graph for the inverse CDF of a random variable. Parameters @@ -384,7 +365,6 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens print(exp_rv_icdf_fn(value=0.9, mu=0.0)) # 3.60222448 """ - warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs) value = pt.as_tensor_variable(value, dtype="floatX") try: return _icdf_helper(rv, value, **kwargs) @@ -400,16 +380,9 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens return expr -RVS_IN_JOINT_LOGP_GRAPH_MSG = ( - "Random variables detected in the logp graph: %s.\n" - "This can happen when DensityDist logp or Interval transform functions reference nonlocal variables,\n" - "or when not all rvs have a corresponding value variable." 
-) - - def conditional_logp( rv_values: dict[TensorVariable, TensorVariable], - warn_rvs=None, + warn_rvs=True, ir_rewriter: GraphRewriter | None = None, extra_rewrites: GraphRewriter | NodeRewriter | None = None, **kwargs, @@ -474,8 +447,6 @@ def conditional_logp( from the respective `RandomVariable`. """ - warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs) - fgraph = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter) if extra_rewrites is not None: @@ -563,7 +534,11 @@ def conditional_logp( if warn_rvs: rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs) if rvs_in_logp_expressions: - warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning) + warnings.warn( + f"Random variables detected in the logp graph: {rvs_in_logp_expressions}.\n" + "This can happen when not all random variables have a corresponding value variable.", + UserWarning, + ) return values_to_logprobs @@ -611,24 +586,10 @@ def transformed_conditional_logp( rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logp_terms_list) if rvs_in_logp_expressions: - raise ValueError(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions) + raise ValueError( + f"Random variables detected in the logp graph: {rvs_in_logp_expressions}.\n" + "This can happen when mixing variables from different models, " + "or when CustomDist logp or Interval transform functions reference nonlocal variables." + ) return logp_terms_list - - -def factorized_joint_logprob(*args, **kwargs): - warnings.warn( - "`factorized_joint_logprob` was renamed to `conditional_logp`. " - "The function will be removed in a future release", - FutureWarning, - ) - return conditional_logp(*args, **kwargs) - - -def joint_logp(*args, **kwargs): - warnings.warn( - "`joint_logp` was renamed to `transformed_conditional_logp`. " - "The function will be removed in a future release", - FutureWarning, - ) - return transformed_conditional_logp(*args, **kwargs) diff --git a/pymc/logprob/transform_value.py b/pymc/logprob/transform_value.py index 4a28d5cd4a..2e1b96d343 100644 --- a/pymc/logprob/transform_value.py +++ b/pymc/logprob/transform_value.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from collections.abc import Sequence @@ -40,7 +41,8 @@ def make_node(self, tran_value: TensorVariable, value: TensorVariable): return Apply(self, [tran_value, value], [tran_value.type()]) def perform(self, node, inputs, outputs): - raise NotImplementedError("These `Op`s should be removed from graphs used for computation.") + warnings.warn("TransformedValue should not be present in the final graph!") + outputs[0][0] = inputs[0] def infer_shape(self, fgraph, node, input_shapes): return [input_shapes[0]] diff --git a/pymc/math.py b/pymc/math.py index 1845dd5111..13655f5345 100644 --- a/pymc/math.py +++ b/pymc/math.py @@ -181,6 +181,8 @@ "zeros_like", ] +from pymc.util import UNSET + def kronecker(*Ks): r"""Return the Kronecker product of arguments. 
@@ -279,16 +281,6 @@ def logdiffexp(a, b): return a + pt.log1mexp(b - a) -def logdiffexp_numpy(a, b): - """Return log(exp(a) - exp(b)).""" - warnings.warn( - "pymc.math.logdiffexp_numpy is being deprecated.", - FutureWarning, - stacklevel=2, - ) - return a + log1mexp_numpy(b - a, negative_input=True) - - invlogit = sigmoid @@ -302,7 +294,7 @@ def logit(p): return pt.log(p / (floatX(1) - p)) -def log1mexp(x, *, negative_input=False): +def log1mexp(x, *, negative_input=UNSET): r"""Return log(1 - exp(-x)). This function is numerically more stable than the naive approach. @@ -316,50 +308,20 @@ def log1mexp(x, *, negative_input=False): "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package" """ - if not negative_input: - warnings.warn( - "pymc.math.log1mexp will expect a negative input in a future " - "version of PyMC.\n To suppress this warning set `negative_input=True`", - FutureWarning, - stacklevel=2, - ) - x = -x + if negative_input is not UNSET: + if not negative_input: + raise ValueError( + "log1mexp with negative_input=False is no longer supported. Negate the input yourself before calling the function." + ) + else: + warnings.warn( + "log1mexp with negative_input=True is now the default behavior. Specifying will fail in a future release of PyMC. Simply omit it", + FutureWarning, + ) return pt.log1mexp(x) -def log1mexp_numpy(x, *, negative_input=False): - """Return log(1 - exp(x)). - - This function is numerically more stable than the naive approach. - - For details, see - https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf - """ - warnings.warn( - "pymc.math.log1mexp_numpy is being deprecated.", - FutureWarning, - stacklevel=2, - ) - x = np.asarray(x, dtype="float") - - if not negative_input: - warnings.warn( - "pymc.math.log1mexp_numpy will expect a negative input in a future " - "version of PyMC.\n To suppress this warning set `negative_input=True`", - FutureWarning, - stacklevel=2, - ) - x = -x - - out = np.empty_like(x) - mask = x < -0.6931471805599453 # log(1/2) - out[mask] = np.log1p(-np.exp(x[mask])) - mask = ~mask - out[mask] = np.log(-np.expm1(x[mask])) - return out - - def flatten_list(tensors): return pt.concatenate([var.ravel() for var in tensors]) diff --git a/pymc/model/core.py b/pymc/model/core.py index d040114075..66e633e15e 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -52,7 +52,6 @@ from pymc.logprob.basic import transformed_conditional_logp from pymc.logprob.transforms import Transform from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values -from pymc.model_graph import model_to_graphviz from pymc.pytensorf import ( PointFunc, SeedSequenceSeed, @@ -68,7 +67,6 @@ UNSET, VarName, WithMemoization, - _add_future_warning_tag, _UnsetType, get_transformed_name, get_value_vars_from_user_vars, @@ -438,6 +436,13 @@ def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: """Exit the context manager.""" _ = MODEL_MANAGER.active_contexts.pop() + def _display_(self): + import marimo as mo + + from pymc.model_graph import model_to_mermaid + + return mo.mermaid(model_to_mermaid(self)) + @staticmethod def _validate_name(name): if name.endswith(":"): @@ -450,19 +455,12 @@ def __init__( coords=None, check_bounds=True, *, - coords_mutable=None, model: _UnsetType | None | Model = UNSET, ): self.name = self._validate_name(name) self.check_bounds = check_bounds self._parent = model if not isinstance(model, _UnsetType) else MODEL_MANAGER.parent_context - if coords_mutable is not None: - 
warnings.warn( - "All coords are now mutable by default. coords_mutable will be removed in a future release.", - FutureWarning, - ) - if self.parent is not None: self.named_vars = treedict(parent=self.parent.named_vars) self.named_vars_to_dims = treedict(parent=self.parent.named_vars_to_dims) @@ -492,9 +490,6 @@ def __init__( self._coords = {} self._dim_lengths = {} self.add_coords(coords) - if coords_mutable is not None: - for name, values in coords_mutable.items(): - self.add_coord(name, values, mutable=True) from pymc.printing import str_for_model @@ -921,7 +916,6 @@ def add_coord( self, name: str, values: Sequence | np.ndarray | None = None, - mutable: bool | None = None, *, length: int | Variable | None = None, ): @@ -935,19 +929,10 @@ def add_coord( values : optional, array_like Coordinate values or ``None`` (for auto-numbering). If ``None`` is passed, a ``length`` must be specified. - mutable : bool - Whether the created dimension should be resizable. - Default is False. length : optional, scalar A scalar of the dimensions length. Defaults to ``pytensor.tensor.constant(len(values))``. """ - if mutable is not None: - warnings.warn( - "Coords are now always mutable. Specifying `mutable` will raise an error in a future release", - FutureWarning, - ) - if name in {"draw", "chain", "__sample__"}: raise ValueError( "Dimensions can not be named `draw`, `chain` or `__sample__`, " @@ -964,6 +949,7 @@ def add_coord( if name in self.coords: if not np.array_equal(values, self.coords[name]): raise ValueError(f"Duplicate and incompatible coordinate: {name}.") + return if length is not None and not isinstance(length, int | Variable): raise ValueError( f"The `length` passed for the '{name}' coord must be an int, PyTensor Variable or None." @@ -1222,7 +1208,6 @@ def register_rv( """ name = self.name_for(name) rv_var.name = name - _add_future_warning_tag(rv_var) # Associate previously unknown dimension names with # the length of the corresponding RV dimension. @@ -1434,9 +1419,6 @@ def create_value_var( rv_var, *rv_var.owner.inputs ).tag.test_value - _add_future_warning_tag(value_var) - rv_var.tag.value_var = value_var - self.rvs_to_transforms[rv_var] = transform self.rvs_to_values[rv_var] = value_var self.values_to_rvs[value_var] = rv_var @@ -1700,23 +1682,6 @@ def profile( return f.profile - def update_start_vals(self, a: dict[str, np.ndarray], b: dict[str, np.ndarray]): - r"""Update point `a` with `b`, without overwriting existing keys. - - Values specified for transformed variables in `a` will be recomputed - conditional on the values of `b` and stored in `b`. - - Parameters - ---------- - a : dict - - b : dict - """ - raise FutureWarning( - "The `Model.update_start_vals` method was removed." - " To change initial values you may set the items of `Model.initial_values` directly." - ) - def eval_rv_shapes(self) -> dict[str, tuple[int, ...]]: """Evaluate shapes of untransformed AND transformed free variables. 
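
An aside (not in the patch), assuming a toy model: with `coords_mutable` and the
`mutable` flag of `add_coord` dropped above, every dim is resizable through the
regular `pm.set_data`/`Model.set_dim` path, with no opt-in required.

.. code-block:: python

    import numpy as np
    import pymc as pm

    with pm.Model(coords={"city": ["A", "B"]}) as model:
        temp = pm.Data("temp", np.array([10.0, 12.0]), dims="city")
        mu = pm.Normal("mu", mu=temp, dims="city")

    with model:
        # Grow the dim and its data together
        pm.set_data({"temp": np.array([8.0, 9.0, 11.0])}, coords={"city": ["A", "B", "C"]})
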
@@ -2038,6 +2003,8 @@ def to_graphviz( # creates the file `schools.pdf` schools.to_graphviz().render("schools") """ + from pymc.model_graph import model_to_graphviz + return model_to_graphviz( model=self, var_names=var_names, diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 3a241948af..50fd5227d6 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -21,16 +21,11 @@ from typing import Any, cast from pytensor import function -from pytensor.graph import Apply from pytensor.graph.basic import ancestors, walk -from pytensor.scalar.basic import Cast -from pytensor.tensor.elemwise import Elemwise -from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.shape import Shape from pytensor.tensor.variable import TensorVariable -import pymc as pm - +from pymc.model.core import modelcontext from pymc.util import VarName, get_default_varnames, get_var_name __all__ = ( @@ -241,42 +236,32 @@ class ModelGraph: def __init__(self, model): self.model = model self._all_var_names = get_default_varnames(self.model.named_vars, include_transformed=False) + self._all_vars = {model[var_name] for var_name in self._all_var_names} self.var_list = self.model.named_vars.values() def get_parent_names(self, var: TensorVariable) -> set[VarName]: - if var.owner is None or var.owner.inputs is None: + if var.owner is None: return set() - def _filter_non_parameter_inputs(var): - node = var.owner - if isinstance(node.op, Shape): - # Don't show shape-related dependencies - return [] - if isinstance(node.op, RandomVariable): - # Filter out rng and size parameters or RandomVariable nodes - return node.op.dist_params(node) - else: - # Otherwise return all inputs - return node.inputs - - blockers = set(self.model.named_vars) + named_vars = self._all_vars def _expand(x): - nonlocal blockers - if x.name in blockers: + if x in named_vars: + # Don't go beyond named_vars return [x] - if isinstance(x.owner, Apply): - return reversed(_filter_non_parameter_inputs(x)) - return [] - - parents = set() - for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand): - # Only consider nodes that are in the named model variables. 
- vname = getattr(x, "name", None) - if isinstance(vname, str) and vname in self._all_var_names: - parents.add(VarName(vname)) + if x.owner is None: + return [] + if isinstance(x.owner.op, Shape): + # Don't propagate shape-related dependencies + return [] + # Continue walking the graph through the inputs + return x.owner.inputs - return parents + return { + cast(VarName, ancestor.name) # type: ignore[union-attr] + for ancestor in walk(nodes=var.owner.inputs, expand=_expand) + if ancestor in named_vars + } def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]: if var_names is None: @@ -312,35 +297,28 @@ def make_compute_graph( self, var_names: Iterable[VarName] | None = None ) -> dict[VarName, set[VarName]]: """Get map of var_name -> set(input var names) for the model.""" + model = self.model + named_vars = self._all_vars input_map: dict[VarName, set[VarName]] = defaultdict(set) - for var_name in self.vars_to_plot(var_names): - var = self.model[var_name] - parent_name = self.get_parent_names(var) - input_map[var_name] = input_map[var_name].union(parent_name) - - if var in self.model.observed_RVs: - obs_node = self.model.rvs_to_values[var] - - # loop created so that the elif block can go through this again - # and remove any intermediate ops, notably dtype casting, to observations - while True: - obs_name = obs_node.name - if obs_name and obs_name != var_name: - input_map[var_name] = input_map[var_name].difference({obs_name}) - input_map[obs_name] = input_map[obs_name].union({var_name}) - break - elif ( - # for cases where observations are cast to a certain dtype - # see issue 5795: https://github.com/pymc-devs/pymc/issues/5795 - obs_node.owner - and isinstance(obs_node.owner.op, Elemwise) - and isinstance(obs_node.owner.op.scalar_op, Cast) - ): - # we can retrieve the observation node by going up the graph - obs_node = obs_node.owner.inputs[0] - else: - break + var_names_to_plot = self.vars_to_plot(var_names) + for var_name in var_names_to_plot: + parent_names = self.get_parent_names(model[var_name]) + input_map[var_name].update(parent_names) + + for var_name in var_names_to_plot: + if (var := model[var_name]) in model.observed_RVs: + # Make observed `Data` variables flow from the observed RV, and not the other way around + # (In the generative graph they usually inform shape of the observed RV) + # We have to iterate over the ancestors of the observed values because there can be + # deterministic operations in between the `Data` variable and the observed value. + obs_var = model.rvs_to_values[var] + for ancestor in ancestors([obs_var]): + if ancestor not in named_vars: + continue + obs_name = cast(VarName, ancestor.name) + input_map[var_name].discard(obs_name) + input_map[obs_name].add(var_name) return input_map @@ -361,7 +339,7 @@ def get_plates( plates = defaultdict(set) # TODO: Evaluate all RV shapes at once - # This should help find discrepencies, and + # This should help find discrepancies, and # avoids unnecessary function compiles for determining labels. 
dim_lengths: dict[str, int] = { dim_name: fast_eval(value).item() for dim_name, value in self.model.dim_lengths.items() @@ -429,6 +407,14 @@ def edges( for parent in parents ] + def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]: + """Get all nodes in the model graph.""" + plates = plates or self.get_plates() + nodes = [] + for plate in plates: + nodes.extend(plate.variables) + return nodes + def make_graph( name: str, @@ -654,7 +640,7 @@ def model_to_networkx( stacklevel=2, ) - model = pm.modelcontext(model) + model = modelcontext(model) graph = ModelGraph(model) return make_networkx( name=model.name, @@ -769,7 +755,7 @@ def model_to_graphviz( stacklevel=2, ) - model = pm.modelcontext(model) + model = modelcontext(model) graph = ModelGraph(model) return make_graph( model.name, @@ -785,3 +771,137 @@ def model_to_graphviz( if include_dim_lengths else create_plate_label_without_dim_length, ) + + +def _create_mermaid_node_name(name: str) -> str: + return name.replace(":", "_").replace(" ", "_") + + +def _build_mermaid_node(node: NodeInfo) -> list[str]: + var = node.var + node_type = node.node_type + name = cast(str, var.name) + node_name = _create_mermaid_node_name(name) + if node_type == NodeType.DATA: + return [ + f"{node_name}[{var.name} ~ Data]", + f"{node_name}@{{ shape: db }}", + ] + elif node_type == NodeType.OBSERVED_RV: + return [ + f"{node_name}([{name} ~ {random_variable_symbol(var)}])", + f"{node_name}@{{ shape: rounded }}", + f"style {node_name} fill:#757575", + ] + + elif node_type == NodeType.FREE_RV: + return [ + f"{node_name}([{name} ~ {random_variable_symbol(var)}])", + f"{node_name}@{{ shape: rounded }}", + ] + elif node_type == NodeType.DETERMINISTIC: + return [ + f"{node_name}([{name} ~ Deterministic])", + f"{node_name}@{{ shape: rect }}", + ] + elif node_type == NodeType.POTENTIAL: + return [ + f"{node_name}([{name} ~ Potential])", + f"{node_name}@{{ shape: diam }}", + f"style {node_name} fill:#f0f0f0", + ] + + return [] + + +def _build_mermaid_nodes(nodes) -> list[str]: + node_lines = [] + for node in nodes: + node_lines.extend(_build_mermaid_node(node)) + + return node_lines + + +def _build_mermaid_edges(edges) -> list[str]: + """Return a list of Mermaid edge definitions.""" + edge_lines = [] + for child, parent in edges: + child_id = _create_mermaid_node_name(child) + parent_id = _create_mermaid_node_name(parent) + edge_lines.append(f"{parent_id} --> {child_id}") + return edge_lines + + +def _build_mermaid_plates(plates, include_dim_lengths) -> list[str]: + plate_lines = [] + for plate in plates: + if not plate.dim_info: + continue + + plate_label_func = ( + create_plate_label_with_dim_length + if include_dim_lengths + else create_plate_label_without_dim_length + ) + plate_label = plate_label_func(plate.dim_info) + plate_name = f'subgraph "{plate_label}"' + plate_lines.append(plate_name) + for var in plate.variables: + plate_lines.append(f" {var.var.name}") + plate_lines.append("end") + + return plate_lines + + +def model_to_mermaid(model=None, *, var_names=None, include_dim_lengths: bool = True) -> str: + """Produce a Mermaid diagram string from a PyMC model. + + Parameters + ---------- + model : pm.Model + The model to plot. Not required when called from inside a modelcontext. + var_names : iterable of variable names, optional + Subset of variables to be plotted that identify a subgraph with respect to the entire model graph + include_dim_lengths : bool + Include the dim lengths in the plate label. Default is True. 
+ + Returns + ------- + str + Mermaid diagram string representing the model graph. + + Examples + -------- + Visualize a simple PyMC model + + .. code-block:: python + + import pymc as pm + + with pm.Model() as model: + mu = pm.Normal("mu", mu=0, sigma=1) + sigma = pm.HalfNormal("sigma", sigma=1) + + pm.Normal("obs", mu=mu, sigma=sigma, observed=[1, 2, 3]) + + print(pm.model_to_mermaid(model)) + + + """ + model = modelcontext(model) + graph = ModelGraph(model) + plates = sorted(graph.get_plates(var_names=var_names), key=lambda plate: hash(plate.dim_info)) + edges = sorted(graph.edges(var_names=var_names)) + nodes = sorted(graph.nodes(plates=plates), key=lambda node: cast(str, node.var.name)) + + return "\n".join( + [ + "graph TD", + "%% Nodes:", + *_build_mermaid_nodes(nodes), + "\n%% Edges:", + *_build_mermaid_edges(edges), + "\n%% Plates:", + *_build_mermaid_plates(plates, include_dim_lengths=include_dim_lengths), + ] + ) diff --git a/pymc/plots/__init__.py b/pymc/plots/__init__.py index 49068d5369..f1c63ce581 100644 --- a/pymc/plots/__init__.py +++ b/pymc/plots/__init__.py @@ -19,9 +19,7 @@ See https://arviz-devs.github.io/arviz/ for details on plots. """ -import functools import sys -import warnings import arviz as az @@ -32,39 +30,4 @@ setattr(sys.modules[__name__], attr, obj) -def alias_deprecation(func, alias: str): - original = func.__name__ - - @functools.wraps(func) - def wrapped(*args, **kwargs): - raise FutureWarning( - f"The function `{alias}` from PyMC was an alias for `{original}` from ArviZ. " - "It was removed in PyMC 4.0. " - f"Switch to `pymc.{original}` or `arviz.{original}`." - ) - - return wrapped - - -# Aliases of ArviZ functions -autocorrplot = alias_deprecation(az.plot_autocorr, alias="autocorrplot") -forestplot = alias_deprecation(az.plot_forest, alias="forestplot") -kdeplot = alias_deprecation(az.plot_kde, alias="kdeplot") -energyplot = alias_deprecation(az.plot_energy, alias="energyplot") -densityplot = alias_deprecation(az.plot_density, alias="densityplot") -pairplot = alias_deprecation(az.plot_pair, alias="pairplot") -traceplot = alias_deprecation(az.plot_trace, alias="traceplot") -compareplot = alias_deprecation(az.plot_compare, alias="compareplot") - - -__all__ = ( - *az.plots.__all__, - "autocorrplot", - "compareplot", - "forestplot", - "kdeplot", - "traceplot", - "energyplot", - "densityplot", - "pairplot", -) +__all__ = az.plots.__all__ diff --git a/pymc/progress_bar.py b/pymc/progress_bar.py new file mode 100644 index 0000000000..7299584307 --- /dev/null +++ b/pymc/progress_bar.py @@ -0,0 +1,425 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
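+"""Progress bar utilities for PyMC sampling.
+
+Split out of ``pymc.util``; hosts ``CustomProgress``, ``RecolorOnFailureBarColumn``
+and ``ProgressBarManager``, which step methods feed through their
+``_progressbar_config`` and ``_make_progressbar_update_functions`` methods.
+"""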
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Literal
+
+from rich.box import SIMPLE_HEAD
+from rich.console import Console
+from rich.progress import (
+    BarColumn,
+    Progress,
+    Task,
+    TextColumn,
+    TimeElapsedColumn,
+    TimeRemainingColumn,
+)
+from rich.style import Style
+from rich.table import Column, Table
+from rich.theme import Theme
+
+if TYPE_CHECKING:
+    from pymc.step_methods.compound import BlockedStep, CompoundStep
+
+ProgressBarType = Literal[
+    "combined",
+    "split",
+    "combined+stats",
+    "stats+combined",
+    "split+stats",
+    "stats+split",
+]
+default_progress_theme = Theme(
+    {
+        "bar.complete": "#1764f4",
+        "bar.finished": "green",
+        "progress.remaining": "none",
+        "progress.elapsed": "none",
+    }
+)
+
+
+class CustomProgress(Progress):
+    """A child of Progress that allows disabling progress bars and their container.
+
+    The implementation simply checks an `is_enabled` flag and generates the progress bar only if
+    it's `True`.
+    """
+
+    def __init__(self, *args, disable=False, include_headers=False, **kwargs):
+        self.is_enabled = not disable
+        self.include_headers = include_headers
+
+        if self.is_enabled:
+            super().__init__(*args, **kwargs)
+
+    def __enter__(self):
+        """Enter the context manager."""
+        if self.is_enabled:
+            self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Exit the context manager."""
+        if self.is_enabled:
+            super().__exit__(exc_type, exc_val, exc_tb)
+
+    def add_task(self, *args, **kwargs):
+        if self.is_enabled:
+            return super().add_task(*args, **kwargs)
+        return None
+
+    def advance(self, task_id, advance=1) -> None:
+        if self.is_enabled:
+            super().advance(task_id, advance)
+        return None
+
+    def update(
+        self,
+        task_id,
+        *,
+        total=None,
+        completed=None,
+        advance=None,
+        description=None,
+        visible=None,
+        refresh=False,
+        **fields,
+    ):
+        if self.is_enabled:
+            super().update(
+                task_id,
+                total=total,
+                completed=completed,
+                advance=advance,
+                description=description,
+                visible=visible,
+                refresh=refresh,
+                **fields,
+            )
+        return None
+
+    def make_tasks_table(self, tasks: Iterable[Task]) -> Table:
+        """Get a table to render the Progress display.
+
+        Unlike the parent method, this one returns a full table (not a grid), allowing for column headings.
+
+        Parameters
+        ----------
+        tasks: Iterable[Task]
+            An iterable of Task instances, one per row of the table.
+
+        Returns
+        -------
+        table: Table
+            A table instance.
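+
+        Notes
+        -----
+        Columns that expose a ``callbacks`` method (e.g. ``RecolorOnFailureBarColumn``)
+        are called with each task before rendering, so they can restyle themselves
+        per task.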
+        """
+
+        def call_column(column, task):
+            # Columns subclassing rich.BarColumn may define a `callbacks` method to dynamically update the display
+            if hasattr(column, "callbacks"):
+                column.callbacks(task)
+
+            return column(task)
+
+        table_columns = (
+            (
+                Column(no_wrap=True)
+                if isinstance(_column, str)
+                else _column.get_table_column().copy()
+            )
+            for _column in self.columns
+        )
+        if self.include_headers:
+            table = Table(
+                *table_columns,
+                padding=(0, 1),
+                expand=self.expand,
+                show_header=True,
+                show_edge=True,
+                box=SIMPLE_HEAD,
+            )
+        else:
+            table = Table.grid(*table_columns, padding=(0, 1), expand=self.expand)
+
+        for task in tasks:
+            if task.visible:
+                table.add_row(
+                    *(
+                        (
+                            column.format(task=task)
+                            if isinstance(column, str)
+                            else call_column(column, task)
+                        )
+                        for column in self.columns
+                    )
+                )
+
+        return table
+
+
+class RecolorOnFailureBarColumn(BarColumn):
+    """Rich colorbar that changes color when a chain has detected a failure."""
+
+    def __init__(self, *args, failing_color="red", **kwargs):
+        from matplotlib.colors import to_rgb
+
+        self.failing_color = failing_color
+        self.failing_rgb = [int(x * 255) for x in to_rgb(self.failing_color)]
+
+        super().__init__(*args, **kwargs)
+
+        self.default_complete_style = self.complete_style
+        self.default_finished_style = self.finished_style
+
+    def callbacks(self, task: "Task"):
+        if task.fields["failing"]:
+            self.complete_style = Style.parse("rgb({},{},{})".format(*self.failing_rgb))
+            self.finished_style = Style.parse("rgb({},{},{})".format(*self.failing_rgb))
+        else:
+            # Recovered from failing; restore the default styles
+            self.complete_style = self.default_complete_style
+            self.finished_style = self.default_finished_style
+
+
+class ProgressBarManager:
+    """Manage progress bars displayed during sampling."""
+
+    def __init__(
+        self,
+        step_method: "BlockedStep | CompoundStep",
+        chains: int,
+        draws: int,
+        tune: int,
+        progressbar: bool | ProgressBarType = True,
+        progressbar_theme: Theme | None = None,
+    ):
+        """
+        Manage progress bars displayed during sampling.
+
+        When sampling, Step classes are responsible for computing and exposing statistics that can be reported on
+        progress bars. Each Step implements two class methods: :meth:`pymc.step_methods.BlockedStep._progressbar_config`
+        and :meth:`pymc.step_methods.BlockedStep._make_progressbar_update_functions`. `_progressbar_config` reports which
+        columns should be displayed on the progress bar, and `_make_progressbar_update_functions` returns functions that
+        compute the statistics displayed on the progress bar.
+
+        Parameters
+        ----------
+        step_method: BlockedStep or CompoundStep
+            The step method being used to sample
+        chains: int
+            Number of chains being sampled
+        draws: int
+            Number of draws per chain
+        tune: int
+            Number of tuning steps per chain
+        progressbar: bool or ProgressBarType, optional
+            How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
+            for one of the following:
+            - "combined": A single progress bar that displays the total progress across all chains. Only timing
+              information is shown.
+            - "split": A separate progress bar for each chain. Only timing information is shown.
+            - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all
+              chains. Aggregate sample statistics are also displayed.
+            - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain
+              are also displayed.
+
+            If True, the default "split+stats" is used.
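+            For example, ``pm.sample(progressbar="combined")`` would show a single
+            aggregated bar without sampler statistics.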
+
+        progressbar_theme: Theme, optional
+            The theme to use for the progress bar. Defaults to the default theme.
+        """
+        if progressbar_theme is None:
+            progressbar_theme = default_progress_theme
+
+        match progressbar:
+            case True:
+                self.combined_progress = False
+                self.full_stats = True
+                show_progress = True
+            case False:
+                self.combined_progress = False
+                self.full_stats = True
+                show_progress = False
+            case "combined":
+                self.combined_progress = True
+                self.full_stats = False
+                show_progress = True
+            case "split":
+                self.combined_progress = False
+                self.full_stats = False
+                show_progress = True
+            case "combined+stats" | "stats+combined":
+                self.combined_progress = True
+                self.full_stats = True
+                show_progress = True
+            case "split+stats" | "stats+split":
+                self.combined_progress = False
+                self.full_stats = True
+                show_progress = True
+            case _:
+                raise ValueError(
+                    "Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), "
+                    "one of 'combined', 'split', 'split+stats', or 'combined+stats'."
+                )
+
+        progress_columns, progress_stats = step_method._progressbar_config(chains)
+
+        self._progress = self.create_progress_bar(
+            progress_columns,
+            progressbar=progressbar,
+            progressbar_theme=progressbar_theme,
+        )
+        self.progress_stats = progress_stats
+        self.update_stats_functions = step_method._make_progressbar_update_functions()
+
+        self._show_progress = show_progress
+        self.completed_draws = 0
+        self.total_draws = draws + tune
+        self.desc = "Sampling chain"
+        self.chains = chains
+
+        self._tasks: list[Task] | None = None  # type: ignore[annotation-unchecked]
+
+    def __enter__(self):
+        self._initialize_tasks()
+
+        return self._progress.__enter__()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        return self._progress.__exit__(exc_type, exc_val, exc_tb)
+
+    def _initialize_tasks(self):
+        if self.combined_progress:
+            self.tasks = [
+                self._progress.add_task(
+                    self.desc.format(self),
+                    completed=0,
+                    draws=0,
+                    total=self.total_draws * self.chains - 1,
+                    chain_idx=0,
+                    sampling_speed=0,
+                    speed_unit="draws/s",
+                    failing=False,
+                    **{stat: value[0] for stat, value in self.progress_stats.items()},
+                )
+            ]
+
+        else:
+            self.tasks = [
+                self._progress.add_task(
+                    self.desc.format(self),
+                    completed=0,
+                    draws=0,
+                    total=self.total_draws - 1,
+                    chain_idx=chain_idx,
+                    sampling_speed=0,
+                    speed_unit="draws/s",
+                    failing=False,
+                    **{stat: value[chain_idx] for stat, value in self.progress_stats.items()},
+                )
+                for chain_idx in range(self.chains)
+            ]
+
+    @staticmethod
+    def compute_draw_speed(elapsed, draws):
+        speed = draws / max(elapsed, 1e-6)
+
+        if speed > 1 or speed == 0:
+            unit = "draws/s"
+        else:
+            unit = "s/draws"
+            speed = 1 / speed
+
+        return speed, unit
+
+    def update(self, chain_idx, is_last, draw, tuning, stats):
+        if not self._show_progress:
+            return
+
+        self.completed_draws += 1
+        if self.combined_progress:
+            draw = self.completed_draws
+            chain_idx = 0
+
+        elapsed = self._progress.tasks[chain_idx].elapsed
+        speed, unit = self.compute_draw_speed(elapsed, draw)
+
+        failing = False
+        all_step_stats = {}
+
+        chain_progress_stats = [
+            update_stats_fn(step_stats)
+            for update_stats_fn, step_stats in zip(self.update_stats_functions, stats, strict=True)
+        ]
+        for step_stats in chain_progress_stats:
+            for key, val in step_stats.items():
+                if key == "failing":
+                    failing |= val
+                    continue
+                if not self.full_stats:
+                    # Only care about the "failing" flag
+                    continue
+
+                if key in all_step_stats:
+                    # TODO: Figure out how to integrate duplicate / non-scalar keys,
ignoring them for now + continue + else: + all_step_stats[key] = val + + self._progress.update( + self.tasks[chain_idx], + completed=draw, + draws=draw, + sampling_speed=speed, + speed_unit=unit, + failing=failing, + **all_step_stats, + ) + + if is_last: + self._progress.update( + self.tasks[chain_idx], + draws=draw + 1 if not self.combined_progress else draw, + failing=failing, + **all_step_stats, + refresh=True, + ) + + def create_progress_bar(self, step_columns, progressbar, progressbar_theme): + columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] + + if self.full_stats: + columns += step_columns + + columns += [ + TextColumn( + "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", + table_column=Column("Sampling Speed", ratio=1), + ), + TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), + TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), + ] + + return CustomProgress( + RecolorOnFailureBarColumn( + table_column=Column("Progress", ratio=2), + failing_color="tab:red", + complete_style=Style.parse("rgb(31,119,180)"), # tab:blue + finished_style=Style.parse("rgb(31,119,180)"), # tab:blue + ), + *columns, + console=Console(theme=progressbar_theme), + disable=not progressbar, + include_headers=True, + ) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 78eb3f7bbc..f1d69c9282 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings -from collections.abc import Callable, Generator, Iterable, Sequence +from collections.abc import Iterable, Sequence from typing import cast import numpy as np @@ -33,7 +33,6 @@ clone_get_equiv, equal_computations, graph_inputs, - walk, ) from pytensor.graph.fg import FunctionGraph, Output from pytensor.scalar.basic import Cast @@ -55,11 +54,9 @@ PotentialShapeType = int | np.ndarray | Sequence[int | Variable] | TensorVariable - __all__ = [ "CallableTensor", "compile", - "compile_pymc", "cont_inputs", "convert_data", "convert_observed_data", @@ -173,38 +170,6 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray: raise TypeError(f"Data cannot be extracted from {x}") -def walk_model( - graphs: Iterable[TensorVariable], - stop_at_vars: set[TensorVariable] | None = None, - expand_fn: Callable[[TensorVariable], Iterable[TensorVariable]] = lambda var: [], -) -> Generator[TensorVariable, None, None]: - """Walk model graphs and yield their nodes. - - Parameters - ---------- - graphs - The graphs to walk. - stop_at_vars - A list of variables at which the walk will terminate. - expand_fn - A function that returns the next variable(s) to be traversed. - """ - warnings.warn("walk_model will be removed in a future relase of PyMC", FutureWarning) - - if stop_at_vars is None: - stop_at_vars = set() - - def expand(var): - new_vars = expand_fn(var) - - if var.owner and var not in stop_at_vars: - new_vars.extend(reversed(var.owner.inputs)) - - return new_vars - - yield from walk(graphs, expand, bfs=False) - - def replace_vars_in_graphs( graphs: Iterable[Variable], replacements: dict[Variable, Variable], @@ -594,12 +559,19 @@ def join_nonshared_inputs( class PointFunc: """Wraps so a function so it takes a dict of arguments instead of arguments.""" + __slots__ = ("f",) + def __init__(self, f): self.f = f def __call__(self, state): return self.f(**state) + def __getattr__(self, item): + """Allow access to the original function attributes.""" + # This is only reached if `__getattribute__` fails. 
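+        # E.g. `PointFunc(fn).profile` transparently resolves to the compiled
+        # pytensor function's `profile` attribute.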
+        return getattr(self.f, item)
+

class CallableTensor:
    """Turns a symbolic variable with one input into a function that returns symbolic arguments with the one variable replaced with the input."""

@@ -954,14 +926,6 @@ def compile(
     return pytensor_function


-def compile_pymc(*args, **kwargs):
-    warnings.warn(
-        "compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC",
-        FutureWarning,
-    )
-    return compile(*args, **kwargs)
-
-
 def constant_fold(
     xs: Sequence[TensorVariable], raise_not_constant: bool = True
 ) -> tuple[np.ndarray | Variable, ...]:
diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py
index b1f9c39895..d65c6c0118 100644
--- a/pymc/sampling/forward.py
+++ b/pymc/sampling/forward.py
@@ -28,9 +28,11 @@

 import numpy as np
 import xarray
+import xarray as xr

 from arviz import InferenceData
 from pytensor import tensor as pt
+from pytensor.graph import vectorize_graph
 from pytensor.graph.basic import (
     Apply,
     Constant,
@@ -42,7 +44,7 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor.random.var import RandomGeneratorSharedVariable
 from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
-from pytensor.tensor.variable import TensorConstant
+from pytensor.tensor.variable import TensorConstant, TensorVariable
 from rich.console import Console
 from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
 from rich.theme import Theme
@@ -52,13 +54,14 @@
 from pymc.backends.arviz import _DefaultTrace, dataset_to_point_list
 from pymc.backends.base import MultiTrace
 from pymc.blocking import PointType
+from pymc.distributions.shape_utils import change_dist_size
+from pymc.logprob.utils import rvs_in_graph
 from pymc.model import Model, modelcontext
+from pymc.progress_bar import CustomProgress, default_progress_theme
 from pymc.pytensorf import compile
 from pymc.util import (
-    CustomProgress,
     RandomState,
     _get_seeds_per_chain,
-    default_progress_theme,
     get_default_varnames,
     point_wrapper,
 )
@@ -68,6 +71,7 @@
     "draw",
     "sample_posterior_predictive",
     "sample_prior_predictive",
+    "vectorize_over_posterior",
 )

 ArrayLike: TypeAlias = np.ndarray | list[float]
@@ -984,3 +988,99 @@ def sample_posterior_predictive(
         idata.extend(idata_pp)
         return idata
     return idata_pp
+
+
+def vectorize_over_posterior(
+    outputs: list[Variable],
+    posterior: xr.Dataset,
+    input_rvs: list[Variable],
+    allow_rvs_in_graph: bool = True,
+    sample_dims: tuple[str, ...] = ("chain", "draw"),
+) -> list[Variable]:
+    """Vectorize outputs over posterior samples of a subset of the input rvs.
+
+    This function creates a new graph for the supplied outputs, where the required
+    subset of input rvs is replaced by their posterior samples (chain and draw
+    dimensions are flattened). The other input tensors are kept as is.
+
+    Parameters
+    ----------
+    outputs : list[Variable]
+        The list of variables to vectorize over the posterior samples.
+    posterior : xr.Dataset
+        The posterior samples to use as replacements for the `input_rvs`.
+    input_rvs : list[Variable]
+        The list of random variables to replace with their posterior samples.
+    allow_rvs_in_graph : bool
+        Whether to allow random variables to be present in the graph. If False,
+        an error will be raised if any random variables are found in the graph. If
+        True, the remaining random variables will be resized to match the number of
+        draws from the posterior.
+    sample_dims : tuple[str, ...]
+        The dimensions of the posterior samples to use for vectorizing the `input_rvs`.
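+        Defaults to ``("chain", "draw")``.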
+
+
+    Returns
+    -------
+    vectorized_outputs : list[Variable]
+        The vectorized variables.
+
+    Raises
+    ------
+    RuntimeError
+        If random variables are found in the graph and `allow_rvs_in_graph` is False
+    """
+    # Identify which free RVs are needed to compute `outputs`
+    needed_rvs: list[TensorVariable] = [
+        cast(TensorVariable, rv)
+        for rv in ancestors(outputs, blockers=input_rvs)
+        if rv in set(input_rvs)
+    ]
+
+    # Replace needed_rvs with actual posterior samples
+    batch_shape = tuple([len(posterior.coords[dim]) for dim in sample_dims])
+    replace_dict: dict[Variable, Variable] = {}
+    for rv in needed_rvs:
+        posterior_samples = posterior[rv.name].data
+
+        replace_dict[rv] = pt.constant(posterior_samples.astype(rv.dtype), name=rv.name)
+
+    # Replace the rvs that remain in the graph with resized versions
+    all_rvs = rvs_in_graph(outputs)
+
+    # Once we give values for the needed_rvs (setting them to their posterior samples),
+    # we need to identify the rvs that only depend on these conditioned values, and
+    # don't depend on any other rvs or output nodes.
+    # These variables need to be resized because they won't be resized implicitly by
+    # the replacement of the needed_rvs or other random variables in the graph when we
+    # later call vectorize_graph.
+    independent_rvs: list[TensorVariable] = []
+    for rv in [
+        rv
+        for rv in general_toposort(  # type: ignore[call-overload]
+            all_rvs, lambda x: x.owner.inputs if x.owner is not None else None
+        )
+        if rv in all_rvs
+    ]:
+        rv_ancestors = ancestors([rv], blockers=[*needed_rvs, *independent_rvs, *outputs])
+        if (
+            rv not in needed_rvs
+            and not ({*outputs, *independent_rvs} & set(rv_ancestors))
+            and {var for var in rv_ancestors if var in all_rvs} <= {rv, *needed_rvs}
+        ):
+            independent_rvs.append(rv)
+    for rv in independent_rvs:
+        replace_dict[rv] = change_dist_size(rv, new_size=batch_shape, expand=True)
+
+    # Vectorize across samples
+    vectorized_outputs = list(vectorize_graph(outputs, replace=replace_dict))
+    for vectorized_output, output in zip(vectorized_outputs, outputs):
+        vectorized_output.name = output.name
+
+    if not allow_rvs_in_graph:
+        remaining_rvs = rvs_in_graph(vectorized_outputs)
+        if remaining_rvs:
+            raise RuntimeError(
+                f"The following random variables were found in the extracted graph: {remaining_rvs}"
+            )
+    return vectorized_outputs
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index f2dfa6e9c2..542797caa8 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -54,6 +54,7 @@
 from pymc.exceptions import SamplingError
 from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
 from pymc.model import Model, modelcontext
+from pymc.progress_bar import ProgressBarManager, ProgressBarType, default_progress_theme
 from pymc.sampling.parallel import Draw, _cpu_count
 from pymc.sampling.population import _sample_population
 from pymc.stats.convergence import (
@@ -65,12 +66,9 @@
 from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
 from pymc.step_methods.hmc import quadpotential
 from pymc.util import (
-    ProgressBarManager,
-    ProgressBarType,
     RandomSeed,
     RandomState,
     _get_seeds_per_chain,
-    default_progress_theme,
     drop_warning_stat,
     get_random_generator,
     get_untransformed_name,
@@ -331,11 +329,7 @@ def _sample_external_nuts(
             "`idata_kwargs` are currently ignored by the nutpie sampler",
             UserWarning,
         )
-        if var_names is not None:
-            warnings.warn(
-                "`var_names` are currently ignored by the nutpie sampler",
-                UserWarning,
-            )
+
        compile_kwargs = {}
        nuts_sampler_kwargs = 
nuts_sampler_kwargs.copy() for kwarg in ("backend", "gradient_backend"): @@ -343,6 +337,7 @@ def _sample_external_nuts( compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg) compiled_model = nutpie.compile_pymc_model( model, + var_names=var_names, **compile_kwargs, ) t_start = time.time() @@ -755,11 +750,9 @@ def joined_blas_limiter(): ) if random_seed == -1: - warnings.warn( - "Setting random_seed = -1 is deprecated. Pass `None` to not specify a seed.", - FutureWarning, + raise ValueError( + "Setting random_seed = -1 is not allowed. Pass `None` to not specify a seed." ) - random_seed = None elif isinstance(random_seed, tuple | list): warnings.warn( "A list or tuple of random_seed no longer specifies the specific random_seed of each chain. " diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index af2106ce6f..6e229b9606 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -33,10 +33,9 @@ from pymc.backends.zarr import ZarrChain from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError +from pymc.progress_bar import ProgressBarManager, default_progress_theme from pymc.util import ( - ProgressBarManager, RandomGeneratorState, - default_progress_theme, get_state_from_generator, random_generator_from_state, ) diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 92de63d0c2..5bd1771704 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -30,6 +30,7 @@ from pymc.backends.zarr import ZarrChain from pymc.initial_point import PointType from pymc.model import Model, modelcontext +from pymc.progress_bar import CustomProgress from pymc.stats.convergence import log_warning_stats from pymc.step_methods import CompoundStep from pymc.step_methods.arraystep import ( @@ -39,7 +40,6 @@ ) from pymc.step_methods.compound import StepMethodState from pymc.step_methods.metropolis import DEMetropolis -from pymc.util import CustomProgress __all__ = () diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index f3176f464b..5afd398281 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -39,10 +39,11 @@ from pymc.distributions.distribution import _support_point from pymc.logprob.abstract import _icdf, _logcdf, _logprob from pymc.model import Model, modelcontext +from pymc.progress_bar import CustomProgress from pymc.sampling.parallel import _cpu_count from pymc.smc.kernels import IMH from pymc.stats.convergence import log_warnings, run_convergence_checks -from pymc.util import CustomProgress, RandomState, _get_seeds_per_chain +from pymc.util import RandomState, _get_seeds_per_chain def sample_smc( diff --git a/pymc/stats/convergence.py b/pymc/stats/convergence.py index d32831c8be..7dc880ba0b 100644 --- a/pymc/stats/convergence.py +++ b/pymc/stats/convergence.py @@ -144,9 +144,15 @@ def warn_divergences(idata: arviz.InferenceData) -> list[SamplerWarning]: n_div = int(diverging.sum()) if n_div == 0: return [] + + if n_div == 1: + verb, word = "was", "divergence" + else: + verb, word = "were", "divergences" + warning = SamplerWarning( WarningType.DIVERGENCES, - f"There were {n_div} divergences after tuning. Increase `target_accept` or reparameterize.", + f"There {verb} {n_div} {word} after tuning. 
Increase `target_accept` or reparameterize.", "error", ) return [warning] diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index d07b070f0f..a9cae903f0 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -189,11 +189,11 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - return stats + def _make_progressbar_update_functions(): + def update_stats(step_stats): + return step_stats - return update_stats + return (update_stats,) # Hack for creating the class correctly when unpickling. def __getnewargs_ex__(self): @@ -332,16 +332,11 @@ def _progressbar_config(self, n_chains=1): return columns, stats - def _make_update_stats_function(self): - update_fns = [method._make_update_stats_function() for method in self.methods] - - def update_stats(stats, step_stats, chain_idx): - for step_stat, update_fn in zip(step_stats, update_fns): - stats = update_fn(stats, step_stat, chain_idx) - - return stats - - return update_stats + def _make_progressbar_update_functions(self): + update_functions = [] + for method in self.methods: + update_functions.extend(method._make_progressbar_update_functions()) + return update_functions def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]: diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index e8c96e8c4b..297b095e23 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -184,6 +184,7 @@ def __init__( self._step_rand = step_rand self._num_divs_sample = 0 + self.divergences = 0 @abstractmethod def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData: @@ -266,11 +267,15 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: divergence_info=info_store, ) + diverging = bool(hmc_step.divergence_info) + if not self.tune: + self.divergences += diverging self.iter_count += 1 stats: dict[str, Any] = { "tune": self.tune, - "diverging": bool(hmc_step.divergence_info), + "diverging": diverging, + "divergences": self.divergences, "perf_counter_diff": perf_end - perf_start, "process_time_diff": process_end - process_start, "perf_counter_start": perf_start, @@ -288,6 +293,8 @@ def reset_tuning(self, start=None): self.reset(start=None) def reset(self, start=None): + self.iter_count = 0 + self.divergences = 0 self.tune = True self.potential.reset() diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index 565c1fd78b..1697341bc8 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -19,6 +19,9 @@ import numpy as np +from rich.progress import TextColumn +from rich.table import Column + from pymc.stats.convergence import SamplerWarning from pymc.step_methods.compound import Competence from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData @@ -55,6 +58,7 @@ class HamiltonianMC(BaseHMC): "accept": (np.float64, []), "diverging": (bool, []), "energy_error": (np.float64, []), + "divergences": (np.int64, []), "energy": (np.float64, []), "path_length": (np.float64, []), "accepted": (bool, []), @@ -202,3 +206,32 @@ def competence(var, has_grad): if var.dtype in discrete_types or not has_grad: return Competence.INCOMPATIBLE return Competence.COMPATIBLE + + @staticmethod + def _progressbar_config(n_chains=1): + columns = [ + TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)), + 
TextColumn("{task.fields[n_steps]}", table_column=Column("Grad evals", ratio=1)), + ] + + stats = { + "divergences": [0] * n_chains, + "n_steps": [0] * n_chains, + } + + return columns, stats + + @staticmethod + def _make_progressbar_update_functions(): + def update_stats(stats): + return { + key: stats[key] + for key in ( + "divergences", + "n_steps", + ) + } | { + "failing": stats["divergences"] > 0, + } + + return (update_stats,) diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 18707c3592..0f19d3c087 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -115,6 +115,7 @@ class NUTS(BaseHMC): "step_size_bar": (np.float64, []), "tree_size": (np.float64, []), "diverging": (bool, []), + "divergences": (int, []), "energy_error": (np.float64, []), "energy": (np.float64, []), "max_energy_error": (np.float64, []), @@ -248,19 +249,13 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] + def _make_progressbar_update_functions(): + def update_stats(stats): + return {key: stats[key] for key in ("divergences", "step_size", "tree_size")} | { + "failing": stats["divergences"] > 0, + } - if not step_stats["tune"]: - stats["divergences"][chain_idx] += step_stats["diverging"] - - stats["step_size"][chain_idx] = step_stats["step_size"] - stats["tree_size"][chain_idx] = step_stats["tree_size"] - return stats - - return update_stats + return (update_stats,) # A proposal for the next position diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 70c650653d..2cd2e1369e 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -346,18 +346,14 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] + def _make_progressbar_update_functions(): + def update_stats(step_stats): + return { + "accept_rate" if key == "accept" else key: step_stats[key] + for key in ("tune", "accept", "scaling") + } - stats["tune"][chain_idx] = step_stats["tune"] - stats["accept_rate"][chain_idx] = step_stats["accept"] - stats["scaling"][chain_idx] = step_stats["scaling"] - - return stats - - return update_stats + return (update_stats,) def tune(scale, acc_rate): @@ -684,7 +680,6 @@ def competence(var): class CategoricalGibbsMetropolisState(StepMethodState): shuffle_dims: bool dimcats: list[tuple] - tune: bool class CategoricalGibbsMetropolis(ArrayStep): @@ -767,10 +762,6 @@ def __init__( else: raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'") - # Doesn't actually tune, but it's required to emit a sampler stat - # that indicates whether a draw was done in a tuning phase. 
- self.tune = True - if compile_kwargs is None: compile_kwargs = {} super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng) @@ -800,10 +791,8 @@ def astep_unif(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType if accepted: logp_curr = logp_prop - stats = { - "tune": self.tune, - } - return q, [stats] + # This step doesn't have any tunable parameters + return q, [{"tune": False}] def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: logp = args[0] @@ -820,7 +809,8 @@ def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType for dim, k in dimcats: logp_curr = self.metropolis_proportional(q, logp, logp_curr, dim, k) - return q, [] + # This step doesn't have any tunable parameters + return q, [{"tune": False}] def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: raise NotImplementedError() diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 9c10acfdf4..180ac1c882 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -212,15 +212,8 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] + def _make_progressbar_update_functions(): + def update_stats(step_stats): + return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}} - stats["tune"][chain_idx] = step_stats["tune"] - stats["nstep_out"][chain_idx] = step_stats["nstep_out"] - stats["nstep_in"][chain_idx] = step_stats["nstep_in"] - - return stats - - return update_stats + return (update_stats,) diff --git a/pymc/testing.py b/pymc/testing.py index a5fdc28327..b016c25ad1 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -13,7 +13,6 @@ # limitations under the License. import functools as ft import itertools as it -import warnings from collections.abc import Callable, Sequence from typing import Any @@ -665,14 +664,6 @@ def check_selfconsistency_discrete_logcdf( ) -def assert_moment_is_expected(model, expected, check_finite_logp=True): - warnings.warn( - "assert_moment_is_expected is deprecated. 
Use assert_support_point_is_expected instead.", - FutureWarning, - ) - assert_support_point_is_expected(model, expected, check_finite_logp=check_finite_logp) - - def assert_support_point_is_expected(model, expected, check_finite_logp=True): fn = make_initial_point_fn( model=model, diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 2fbbba6339..1385f33483 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -36,9 +36,8 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.initial_point import make_initial_point_fn from pymc.model import modelcontext +from pymc.progress_bar import CustomProgress, default_progress_theme from pymc.util import ( - CustomProgress, - default_progress_theme, get_default_varnames, get_value_vars_from_user_vars, ) diff --git a/pymc/util.py b/pymc/util.py index 979b3beebf..3f108b8b03 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -14,12 +14,11 @@ import functools import re -import warnings from collections import namedtuple -from collections.abc import Iterable, Sequence +from collections.abc import Sequence from copy import deepcopy -from typing import TYPE_CHECKING, Literal, NewType, cast +from typing import NewType, cast import arviz import cloudpickle @@ -29,61 +28,11 @@ from cachetools import LRUCache, cachedmethod from pytensor import Variable from pytensor.compile import SharedVariable -from pytensor.graph.utils import ValidatingScratchpad -from rich.box import SIMPLE_HEAD -from rich.console import Console -from rich.progress import ( - BarColumn, - Progress, - Task, - TextColumn, - TimeElapsedColumn, - TimeRemainingColumn, -) -from rich.style import Style -from rich.table import Column, Table -from rich.theme import Theme from pymc.exceptions import BlockModelAccessError -if TYPE_CHECKING: - from pymc.step_methods.compound import BlockedStep, CompoundStep - - -ProgressBarType = Literal[ - "combined", - "split", - "combined+stats", - "stats+combined", - "split+stats", - "stats+split", -] - - -def __getattr__(name): - if name == "dataset_to_point_list": - warnings.warn( - f"{name} has been moved to backends.arviz. Importing from util will fail in a future release.", - FutureWarning, - ) - from pymc.backends.arviz import dataset_to_point_list - - return dataset_to_point_list - - raise AttributeError(f"module {__name__} has no attribute {name}") - - VarName = NewType("VarName", str) -default_progress_theme = Theme( - { - "bar.complete": "#1764f4", - "bar.finished": "green", - "progress.remaining": "none", - "progress.elapsed": "none", - } -) - class _UnsetType: """Type for the `UNSET` object to make it look nice in `help(...)` outputs.""" @@ -540,34 +489,6 @@ def get_value_vars_from_user_vars(vars: Variable | Sequence[Variable], model) -> return value_vars -class _FutureWarningValidatingScratchpad(ValidatingScratchpad): - def __getattribute__(self, name): - for deprecated_names, alternative in ( - (("value_var", "observations"), "model.rvs_to_values[rv]"), - (("transform",), "model.rvs_to_transforms[rv]"), - ): - if name in deprecated_names: - try: - super().__getattribute__(name) - except AttributeError: - pass - else: - warnings.warn( - f"The tag attribute {name} is deprecated. 
Use {alternative} instead", - FutureWarning, - ) - return super().__getattribute__(name) - - -def _add_future_warning_tag(var) -> None: - old_tag = var.tag - if not isinstance(old_tag, _FutureWarningValidatingScratchpad): - new_tag = _FutureWarningValidatingScratchpad("test_value", var.type.filter) - for k, v in old_tag.__dict__.items(): - new_tag.__dict__.setdefault(k, v) - var.tag = new_tag - - def makeiter(a): if isinstance(a, tuple | list): return a @@ -575,379 +496,6 @@ def makeiter(a): return [a] -class CustomProgress(Progress): - """A child of Progress that allows to disable progress bars and its container. - - The implementation simply checks an `is_enabled` flag and generates the progress bar only if - it's `True`. - """ - - def __init__(self, *args, disable=False, include_headers=False, **kwargs): - self.is_enabled = not disable - self.include_headers = include_headers - - if self.is_enabled: - super().__init__(*args, **kwargs) - - def __enter__(self): - """Enter the context manager.""" - if self.is_enabled: - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Exit the context manager.""" - if self.is_enabled: - super().__exit__(exc_type, exc_val, exc_tb) - - def add_task(self, *args, **kwargs): - if self.is_enabled: - return super().add_task(*args, **kwargs) - return None - - def advance(self, task_id, advance=1) -> None: - if self.is_enabled: - super().advance(task_id, advance) - return None - - def update( - self, - task_id, - *, - total=None, - completed=None, - advance=None, - description=None, - visible=None, - refresh=False, - **fields, - ): - if self.is_enabled: - super().update( - task_id, - total=total, - completed=completed, - advance=advance, - description=description, - visible=visible, - refresh=refresh, - **fields, - ) - return None - - def make_tasks_table(self, tasks: Iterable[Task]) -> Table: - """Get a table to render the Progress display. - - Unlike the parent method, this one returns a full table (not a grid), allowing for column headings. - - Parameters - ---------- - tasks: Iterable[Task] - An iterable of Task instances, one per row of the table. - - Returns - ------- - table: Table - A table instance. 
- """ - - def call_column(column, task): - # Subclass rich.BarColumn and add a callback method to dynamically update the display - if hasattr(column, "callbacks"): - column.callbacks(task) - - return column(task) - - table_columns = ( - ( - Column(no_wrap=True) - if isinstance(_column, str) - else _column.get_table_column().copy() - ) - for _column in self.columns - ) - if self.include_headers: - table = Table( - *table_columns, - padding=(0, 1), - expand=self.expand, - show_header=True, - show_edge=True, - box=SIMPLE_HEAD, - ) - else: - table = Table.grid(*table_columns, padding=(0, 1), expand=self.expand) - - for task in tasks: - if task.visible: - table.add_row( - *( - ( - column.format(task=task) - if isinstance(column, str) - else call_column(column, task) - ) - for column in self.columns - ) - ) - - return table - - -class DivergenceBarColumn(BarColumn): - """Rich colorbar that changes color when a chain has detected a divergence.""" - - def __init__(self, *args, diverging_color="red", **kwargs): - from matplotlib.colors import to_rgb - - self.diverging_color = diverging_color - self.diverging_rgb = [int(x * 255) for x in to_rgb(self.diverging_color)] - - super().__init__(*args, **kwargs) - - self.non_diverging_style = self.complete_style - self.non_diverging_finished_style = self.finished_style - - def callbacks(self, task: "Task"): - divergences = task.fields.get("divergences", 0) - if isinstance(divergences, float | int) and divergences > 0: - self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) - self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) - else: - self.complete_style = self.non_diverging_style - self.finished_style = self.non_diverging_finished_style - - -class ProgressBarManager: - """Manage progress bars displayed during sampling.""" - - def __init__( - self, - step_method: "BlockedStep | CompoundStep", - chains: int, - draws: int, - tune: int, - progressbar: bool | ProgressBarType = True, - progressbar_theme: Theme | None = None, - ): - """ - Manage progress bars displayed during sampling. - - When sampling, Step classes are responsible for computing and exposing statistics that can be reported on - progress bars. Each Step implements two class methods: :meth:`pymc.step_methods.BlockedStep._progressbar_config` - and :meth:`pymc.step_methods.BlockedStep._make_update_stats_function`. `_progressbar_config` reports which - columns should be displayed on the progress bar, and `_make_update_stats_function` computes the statistics - that will be displayed on the progress bar. - - Parameters - ---------- - step_method: BlockedStep or CompoundStep - The step method being used to sample - chains: int - Number of chains being sampled - draws: int - Number of draws per chain - tune: int - Number of tuning steps per chain - progressbar: bool or ProgressType, optional - How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask - for one of the following: - - "combined": A single progress bar that displays the total progress across all chains. Only timing - information is shown. - - "split": A separate progress bar for each chain. Only timing information is shown. - - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all - chains. Aggregate sample statistics are also displayed. - - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain - are also displayed. 
- - If True, the default is "split+stats" is used. - - progressbar_theme: Theme, optional - The theme to use for the progress bar. Defaults to the default theme. - """ - if progressbar_theme is None: - progressbar_theme = default_progress_theme - - match progressbar: - case True: - self.combined_progress = False - self.full_stats = True - show_progress = True - case False: - self.combined_progress = False - self.full_stats = True - show_progress = False - case "combined": - self.combined_progress = True - self.full_stats = False - show_progress = True - case "split": - self.combined_progress = False - self.full_stats = False - show_progress = True - case "combined+stats" | "stats+combined": - self.combined_progress = True - self.full_stats = True - show_progress = True - case "split+stats" | "stats+split": - self.combined_progress = False - self.full_stats = True - show_progress = True - case _: - raise ValueError( - "Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), " - "one of 'combined', 'split', 'split+stats', or 'combined+stats." - ) - - progress_columns, progress_stats = step_method._progressbar_config(chains) - - self._progress = self.create_progress_bar( - progress_columns, - progressbar=progressbar, - progressbar_theme=progressbar_theme, - ) - - self.progress_stats = progress_stats - self.update_stats = step_method._make_update_stats_function() - - self._show_progress = show_progress - self.divergences = 0 - self.completed_draws = 0 - self.total_draws = draws + tune - self.desc = "Sampling chain" - self.chains = chains - - self._tasks: list[Task] | None = None # type: ignore[annotation-unchecked] - - def __enter__(self): - self._initialize_tasks() - - return self._progress.__enter__() - - def __exit__(self, exc_type, exc_val, exc_tb): - return self._progress.__exit__(exc_type, exc_val, exc_tb) - - def _initialize_tasks(self): - if self.combined_progress: - self.tasks = [ - self._progress.add_task( - self.desc.format(self), - completed=0, - draws=0, - total=self.total_draws * self.chains - 1, - chain_idx=0, - sampling_speed=0, - speed_unit="draws/s", - **{stat: value[0] for stat, value in self.progress_stats.items()}, - ) - ] - - else: - self.tasks = [ - self._progress.add_task( - self.desc.format(self), - completed=0, - draws=0, - total=self.total_draws - 1, - chain_idx=chain_idx, - sampling_speed=0, - speed_unit="draws/s", - **{stat: value[chain_idx] for stat, value in self.progress_stats.items()}, - ) - for chain_idx in range(self.chains) - ] - - def compute_draw_speed(self, chain_idx, draws): - elapsed = self._progress.tasks[chain_idx].elapsed - speed = draws / max(elapsed, 1e-6) - - if speed > 1 or speed == 0: - unit = "draws/s" - else: - unit = "s/draws" - speed = 1 / speed - - return speed, unit - - def update(self, chain_idx, is_last, draw, tuning, stats): - if not self._show_progress: - return - - self.completed_draws += 1 - if self.combined_progress: - draw = self.completed_draws - chain_idx = 0 - - speed, unit = self.compute_draw_speed(chain_idx, draw) - - if not tuning and stats and stats[0].get("diverging"): - self.divergences += 1 - - self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx) - more_updates = ( - {stat: value[chain_idx] for stat, value in self.progress_stats.items()} - if self.full_stats - else {} - ) - - self._progress.update( - self.tasks[chain_idx], - completed=draw, - draws=draw, - sampling_speed=speed, - speed_unit=unit, - **more_updates, - ) - - if is_last: - self._progress.update( - 
self.tasks[chain_idx], - draws=draw + 1 if not self.combined_progress else draw, - **more_updates, - refresh=True, - ) - - def create_progress_bar(self, step_columns, progressbar, progressbar_theme): - columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] - - if self.full_stats: - columns += step_columns - - columns += [ - TextColumn( - "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", - table_column=Column("Sampling Speed", ratio=1), - ), - TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), - TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), - ] - - return CustomProgress( - DivergenceBarColumn( - table_column=Column("Progress", ratio=2), - diverging_color="tab:red", - complete_style=Style.parse("rgb(31,119,180)"), # tab:blue - finished_style=Style.parse("rgb(31,119,180)"), # tab:blue - ), - *columns, - console=Console(theme=progressbar_theme), - disable=not progressbar, - include_headers=True, - ) - - -def compute_draw_speed(elapsed, draws): - speed = draws / max(elapsed, 1e-6) - - if speed > 1 or speed == 0: - unit = "draws/s" - else: - unit = "s/draws" - speed = 1 / speed - - return speed, unit - - RandomGeneratorState = namedtuple("RandomGeneratorState", ["bit_generator_state", "seed_seq_state"]) diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index d9da7fb786..b83c1db4a3 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -23,7 +23,7 @@ import pymc as pm -from pymc.util import CustomProgress, default_progress_theme +from pymc.progress_bar import CustomProgress, default_progress_theme from pymc.variational import test_functions from pymc.variational.approximations import Empirical, FullRank, MeanField from pymc.variational.operators import KL, KSD diff --git a/requirements-dev.txt b/requirements-dev.txt index 840f3d8063..cabfc740af 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,7 +4,6 @@ arviz>=0.13.0 cachetools>=4.2.1 cloudpickle -git+https://github.com/pymc-devs/pymc-sphinx-theme ipython>=7.16 jupyter-sphinx mcbackend>=0.4.0 @@ -16,6 +15,7 @@ numpydoc pandas>=0.24.0 polyagamma pre-commit>=2.8.0 +pymc-sphinx-theme>=0.16.0 pytensor>=2.31.2,<2.32 pytest-cov>=2.5 pytest>=3.0 diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 409e255d75..032fbc938b 100755 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -168,7 +168,7 @@ def check_no_unexpected_results(mypy_lines: Iterator[str]): for section, sdf in df.reset_index().groupby(args.groupby): print(f"\n\n[{section}]") for row in sdf.itertuples(): - print(f"{row.file}:{row.line}: {row.type}: {row.message}") + print(f"{row.file}:{row.line}: {row.type} [{row.errorcode}]: {row.message}") print() else: print( diff --git a/tests/distributions/test_censored.py b/tests/distributions/test_censored.py index 6e8b0f9dcd..21dce537b0 100644 --- a/tests/distributions/test_censored.py +++ b/tests/distributions/test_censored.py @@ -17,6 +17,7 @@ import pymc as pm +from pymc import logp from pymc.distributions.shape_utils import change_dist_size @@ -110,3 +111,18 @@ def test_dist_broadcasted_by_lower_upper(self): pm.Normal.dist(size=(3, 4, 2)), lower=np.zeros((2,)), upper=np.zeros((4, 2)) ) assert tuple(x.owner.inputs[0].shape.eval()) == (3, 4, 2) + + def test_censored_categorical(self): + cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2], shape=(5,)) + + np.testing.assert_allclose( + logp(cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(), + [0, 0.1, 0.2, 0.2, 0.3, 0.2, 0], + ) + + censored_cat = 
pm.Censored.dist(cat, lower=1, upper=3, shape=(5,)) + + np.testing.assert_allclose( + logp(censored_cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(), + [0, 0, 0.3, 0.2, 0.5, 0, 0], + ) diff --git a/tests/distributions/test_custom.py b/tests/distributions/test_custom.py index dba68c26e6..7e05b2d02d 100644 --- a/tests/distributions/test_custom.py +++ b/tests/distributions/test_custom.py @@ -196,15 +196,6 @@ def test_custom_dist_default_support_point_univariate(self, support_point, size, assert isinstance(x.owner.op, CustomDistRV) assert_support_point_is_expected(model, expected, check_finite_logp=False) - def test_custom_dist_moment_future_warning(self): - moment = lambda rv, size, *rv_inputs: 5 * pt.ones(size, dtype=rv.dtype) # noqa: E731 - with Model() as model: - with pytest.warns( - FutureWarning, match="`moment` argument is deprecated. Use `support_point` instead." - ): - x = CustomDist("x", moment=moment, size=()) - assert_support_point_is_expected(model, 5, check_finite_logp=False) - @pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str) def test_custom_dist_custom_support_point_univariate(self, size): def density_support_point(rv, size, mu): diff --git a/tests/distributions/test_discrete.py b/tests/distributions/test_discrete.py index 24eeb504c9..55e8c23128 100644 --- a/tests/distributions/test_discrete.py +++ b/tests/distributions/test_discrete.py @@ -367,43 +367,58 @@ def test_poisson(self): @pytest.mark.parametrize("n", [2, 3, 4]) def test_categorical(self, n): + domain = Domain(range(n), dtype="int64", edges=(0, n)) + paramdomains = {"p": Simplex(n)} + check_logp( pm.Categorical, - Domain(range(n), dtype="int64", edges=(0, n)), - {"p": Simplex(n)}, + domain, + paramdomains, lambda value, p: categorical_logpdf(value, p), ) - def test_categorical_logp_batch_dims(self): + check_selfconsistency_discrete_logcdf( + pm.Categorical, + domain, + paramdomains, + ) + + @pytest.mark.parametrize("method", (logp, logcdf), ids=lambda x: x.__name__) + def test_categorical_logp_batch_dims(self, method): # Core case p = np.array([0.2, 0.3, 0.5]) value = np.array(2.0) - logp_expr = logp(pm.Categorical.dist(p=p, shape=value.shape), value) - assert logp_expr.type.ndim == 0 - np.testing.assert_allclose(logp_expr.eval(), np.log(0.5)) + expr = method(pm.Categorical.dist(p=p, shape=value.shape), value) + assert expr.type.ndim == 0 + expected_p = 0.5 if method is logp else 1.0 + np.testing.assert_allclose(expr.exp().eval(), expected_p) # Explicit batched value broadcasts p bcast_p = p[None] # shape (1, 3) batch_value = np.array([0, 1]) # shape(3,) - logp_expr = logp(pm.Categorical.dist(p=bcast_p, shape=batch_value.shape), batch_value) - assert logp_expr.type.ndim == 1 - np.testing.assert_allclose(logp_expr.eval(), np.log([0.2, 0.3])) + expr = method(pm.Categorical.dist(p=bcast_p, shape=batch_value.shape), batch_value) + assert expr.type.ndim == 1 + expected_p = [0.2, 0.3] if method is logp else [0.2, 0.5] + np.testing.assert_allclose(expr.exp().eval(), expected_p) + + # Implicit batch value broadcasts p + expr = method(pm.Categorical.dist(p=p, shape=()), batch_value) + assert expr.type.ndim == 1 + expected_p = [0.2, 0.3] if method is logp else [0.2, 0.5] + np.testing.assert_allclose(expr.exp().eval(), expected_p) # Explicit batched value and batched p batch_p = np.array([p[::-1], p]) - logp_expr = logp(pm.Categorical.dist(p=batch_p, shape=batch_value.shape), batch_value) - assert logp_expr.type.ndim == 1 - np.testing.assert_allclose(logp_expr.eval(), np.log([0.5, 0.3])) - - # Implicit batch value 
broadcasts p - logp_expr = logp(pm.Categorical.dist(p=p, shape=()), batch_value) - assert logp_expr.type.ndim == 1 - np.testing.assert_allclose(logp_expr.eval(), np.log([0.2, 0.3])) + expr = method(pm.Categorical.dist(p=batch_p, shape=batch_value.shape), batch_value) + assert expr.type.ndim == 1 + expected_p = [0.5, 0.3] if method is logp else [0.5, 0.5] + np.testing.assert_allclose(expr.exp().eval(), expected_p) # Implicit batch p broadcasts value - logp_expr = logp(pm.Categorical.dist(p=batch_p, shape=None), value) - assert logp_expr.type.ndim == 1 - np.testing.assert_allclose(logp_expr.eval(), np.log([0.2, 0.5])) + expr = method(pm.Categorical.dist(p=batch_p, shape=None), value) + assert expr.type.ndim == 1 + expected_p = [0.2, 0.5] if method is logp else [1.0, 1.0] + np.testing.assert_allclose(expr.exp().eval(), expected_p) @pytensor.config.change_flags(compute_test_value="raise") def test_categorical_bounds(self): diff --git a/tests/distributions/test_dist_math.py b/tests/distributions/test_dist_math.py index 39b9cfdd04..ff213954a9 100644 --- a/tests/distributions/test_dist_math.py +++ b/tests/distributions/test_dist_math.py @@ -30,7 +30,6 @@ clipped_beta_rvs, factln, i0e, - incomplete_beta, multigammaln, ) from pymc.logprob.utils import ParameterValueError @@ -176,9 +175,3 @@ def ref_multigammaln(a, b): for x in xvals: if np.all(x > 0.5 * (p - 1)): check_vals(multigammaln_, ref_multigammaln, x, p) - - -def test_incomplete_beta_deprecation(): - with pytest.warns(FutureWarning, match="incomplete_beta has been deprecated"): - res = incomplete_beta(3, 5, 0.5).eval() - assert np.isclose(res, pt.betainc(3, 5, 0.5).eval()) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index d7e2bbd0a1..f84fb8f869 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -52,7 +52,6 @@ check_logcdf, check_logp, ) -from pymc.util import _FutureWarningValidatingScratchpad class TestBugfixes: @@ -92,8 +91,14 @@ def test_all_distributions_have_support_points(): dists = (getattr(dist_module, dist) for dist in dist_module.__all__) dists = (dist for dist in dists if isinstance(dist, DistributionMeta)) + generic_func = _support_point.dispatch(object) missing_support_points = { - dist for dist in dists if getattr(dist, "rv_type", None) not in _support_point.registry + dist + for dist in dists + if ( + getattr(dist, "rv_type", None) is not None + and _support_point.dispatch(dist.rv_type) is generic_func + ) } # Ignore super classes @@ -233,43 +238,6 @@ def rv_op(cls, size=None, rng=None): assert resized_rv.type.shape == (5,) -def test_tag_future_warning_dist(): - # Test no unexpected warnings - with warnings.catch_warnings(): - warnings.simplefilter("error") - - x = pm.Normal.dist() - assert isinstance(x.tag, _FutureWarningValidatingScratchpad) - - x.tag.banana = "banana" - assert x.tag.banana == "banana" - - # Check we didn't break test_value filtering - x.tag.test_value = np.array(1) - assert x.tag.test_value == 1 - with pytest.raises(TypeError, match="Wrong number of dimensions"): - x.tag.test_value = np.array([1, 1]) - assert x.tag.test_value == 1 - - # No warning if deprecated attribute is not present - with pytest.raises(AttributeError): - x.tag.value_var - - # Warning if present - x.tag.value_var = "1" - with pytest.warns(FutureWarning, match="Use model.rvs_to_values"): - value_var = x.tag.value_var - assert value_var == "1" - - # Check that PyMC method that copies tag contents does not erase special tag - new_x = 
change_dist_size(x, new_size=5) - assert new_x.tag is not x.tag - assert isinstance(new_x.tag, _FutureWarningValidatingScratchpad) - with pytest.warns(FutureWarning, match="Use model.rvs_to_values"): - value_var = new_x.tag.value_var - assert value_var == "1" - - def test_distribution_op_registered(): """Test that returned Ops are registered as virtual subclasses of the respective PyMC distributions.""" assert isinstance(Normal.dist().owner.op, Normal) diff --git a/tests/distributions/test_mixture.py b/tests/distributions/test_mixture.py index 7fd00bcb5a..28a09744b6 100644 --- a/tests/distributions/test_mixture.py +++ b/tests/distributions/test_mixture.py @@ -49,13 +49,12 @@ Poisson, StickBreakingWeights, Triangular, - Truncated, Uniform, ZeroInflatedBinomial, ZeroInflatedNegativeBinomial, ZeroInflatedPoisson, ) -from pymc.distributions.mixture import MixtureTransformWarning +from pymc.distributions.mixture import MixtureTransformWarning, _Hurdle from pymc.distributions.shape_utils import change_dist_size, to_tuple from pymc.distributions.transforms import _default_transform from pymc.logprob.basic import logp @@ -1563,26 +1562,24 @@ def test_zero_inflated_dists_dtype_and_broadcast(self, dist, non_psi_args): assert x.eval().shape == (3,) -class TestHurdleMixtures: +class TestHurdleDistributions: @staticmethod - def check_hurdle_mixture_graph(dist): + def check_hurdle_graph(dist): # Assert it's a mixture - assert isinstance(dist.owner.op, Mixture) + assert isinstance(dist.owner.op, _Hurdle) # Extract the distribution for zeroes and nonzeroes zero_dist, nonzero_dist = dist.owner.inputs[-2:] # Assert ops are of the right type assert isinstance(zero_dist.owner.op, DiracDelta) - assert isinstance(nonzero_dist.owner.op, Truncated) - return zero_dist, nonzero_dist def test_hurdle_poisson_graph(self): # There's nothing special in these values psi, mu = 0.3, 4 dist = HurdlePoisson.dist(psi=psi, mu=mu) - _, nonzero_dist = self.check_hurdle_mixture_graph(dist) + _, nonzero_dist = self.check_hurdle_graph(dist) # Assert the truncated distribution is of the right type assert isinstance(nonzero_dist.owner.op.base_rv_op, Poisson) @@ -1593,7 +1590,7 @@ def test_hurdle_poisson_graph(self): def test_hurdle_negativebinomial_graph(self): psi, p, n = 0.2, 0.6, 10 dist = HurdleNegativeBinomial.dist(psi=psi, p=p, n=n) - _, nonzero_dist = self.check_hurdle_mixture_graph(dist) + _, nonzero_dist = self.check_hurdle_graph(dist) assert isinstance(nonzero_dist.owner.op.base_rv_op, NegativeBinomial) assert nonzero_dist.owner.inputs[-4].data == n @@ -1602,22 +1599,24 @@ def test_hurdle_negativebinomial_graph(self): def test_hurdle_gamma_graph(self): psi, alpha, beta = 0.25, 3, 4 dist = HurdleGamma.dist(psi=psi, alpha=alpha, beta=beta) - _, nonzero_dist = self.check_hurdle_mixture_graph(dist) + _, nonzero_dist = self.check_hurdle_graph(dist) # Under the hood it uses the shape-scale parametrization of the Gamma distribution. # So the second value is the reciprocal of the rate (i.e. 
1 / beta) - assert isinstance(nonzero_dist.owner.op.base_rv_op, Gamma) - assert nonzero_dist.owner.inputs[-4].data == alpha - assert nonzero_dist.owner.inputs[-3].eval() == 1 / beta + assert isinstance(nonzero_dist.owner.op, Gamma) + alpha_param, reciprocal_beta_param = nonzero_dist.owner.op.dist_params(nonzero_dist.owner) + assert alpha_param.data == alpha + assert reciprocal_beta_param.eval() == 1 / beta def test_hurdle_lognormal_graph(self): psi, mu, sigma = 0.1, 2, 2.5 dist = HurdleLogNormal.dist(psi=psi, mu=mu, sigma=sigma) - _, nonzero_dist = self.check_hurdle_mixture_graph(dist) + _, nonzero_dist = self.check_hurdle_graph(dist) - assert isinstance(nonzero_dist.owner.op.base_rv_op, LogNormal) - assert nonzero_dist.owner.inputs[-4].data == mu - assert nonzero_dist.owner.inputs[-3].data == sigma + assert isinstance(nonzero_dist.owner.op, LogNormal) + mu_param, sigma_param = nonzero_dist.owner.op.dist_params(nonzero_dist.owner) + assert mu_param.data == mu + assert sigma_param.data == sigma @pytest.mark.parametrize( "dist, psi, non_psi_args", @@ -1699,11 +1698,7 @@ def logp_fn(value, psi, alpha, beta): if value == 0: return np.log(1 - psi) else: - return ( - np.log(psi) - + st.gamma.logpdf(value, alpha, scale=1.0 / beta) - - np.log(1 - st.gamma.cdf(np.finfo(float).eps, alpha, scale=1.0 / beta)) - ) + return np.log(psi) + st.gamma.logpdf(value, alpha, scale=1.0 / beta) check_logp(HurdleGamma, Rplus, {"psi": Unit, "alpha": Rplusbig, "beta": Rplusbig}, logp_fn) @@ -1712,10 +1707,6 @@ def logp_fn(value, psi, mu, sigma): if value == 0: return np.log(1 - psi) else: - return ( - np.log(psi) - + st.lognorm.logpdf(value, sigma, 0, np.exp(mu)) - - np.log(1 - st.lognorm.cdf(np.finfo(float).eps, sigma, 0, np.exp(mu))) - ) + return np.log(psi) + st.lognorm.logpdf(value, sigma, 0, np.exp(mu)) check_logp(HurdleLogNormal, Rplus, {"psi": Unit, "mu": R, "sigma": Rplusbig}, logp_fn) diff --git a/tests/distributions/test_timeseries.py b/tests/distributions/test_timeseries.py index 197296a5a4..e9d3e76159 100644 --- a/tests/distributions/test_timeseries.py +++ b/tests/distributions/test_timeseries.py @@ -461,26 +461,6 @@ def test_mvstudentt(self, param): assert isinstance(init_dist.owner.op, Dirichlet) assert isinstance(innovation_dist.owner.op, MvStudentT) - @pytest.mark.parametrize( - "distribution, init_dist, build_kwargs", - [ - (GaussianRandomWalk, Normal.dist(), {}), - ( - MvGaussianRandomWalk, - Dirichlet.dist(np.ones(3)), - {"mu": np.zeros(3), "tau": np.eye(3)}, - ), - ( - MvStudentTRandomWalk, - Dirichlet.dist(np.ones(3)), - {"nu": 4, "mu": np.zeros(3), "tau": np.eye(3)}, - ), - ], - ) - def test_init_deprecated_arg(self, distribution, init_dist, build_kwargs): - with pytest.warns(FutureWarning, match="init parameter is now called init_dist"): - distribution.dist(init=init_dist, steps=10, **build_kwargs) - class TestAR: def test_order1_logp(self): @@ -713,10 +693,6 @@ def test_support_point(self, size, expected): AR("x", rho=[0, 0], init_dist=init_dist, steps=5, size=size) assert_support_point_is_expected(model, expected, check_finite_logp=False) - def test_init_deprecated_arg(self): - with pytest.warns(FutureWarning, match="init parameter is now called init_dist"): - AR.dist(rho=[1, 2, 3], init=Normal.dist(), shape=(10,)) - def test_change_dist_size(self): base_dist = AR.dist(rho=[0.5, 0.5], init_dist=pm.Normal.dist(size=(2,)), shape=(3, 10)) diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index e28052bab9..26bc8b1bf8 100644 --- 
a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -151,9 +151,6 @@ def test_sum_to_1(): check_vector_transform(tr.sum_to_1, Simplex(2)) check_vector_transform(tr.sum_to_1, Simplex(4)) - with pytest.warns(FutureWarning, match="ndim_supp argument is deprecated"): - tr.SumTo1(2) - check_jacobian_det( tr.sum_to_1, Vector(Unit, 2), @@ -161,13 +158,6 @@ def test_sum_to_1(): floatX(np.array([0, 0])), lambda x: x[:-1], ) - check_jacobian_det( - tr.multivariate_sum_to_1, - Vector(Unit, 2), - pt.vector, - floatX(np.array([0, 0])), - lambda x: x[:-1], - ) def test_log(): @@ -271,9 +261,6 @@ def test_circular(): def test_ordered(): check_vector_transform(tr.ordered, SortedVector(6)) - with pytest.warns(FutureWarning, match="ndim_supp argument is deprecated"): - tr.Ordered(1) - check_jacobian_det( tr.ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False ) @@ -678,23 +665,3 @@ def log_jac_det(self, value, *inputs): match="are not allowed to broadcast together. There is a bug in the implementation of either one", ): m.logp(jacobian=jacobian_val) - - -def test_deprecated_ndim_supp_transforms(): - with pytest.warns(FutureWarning, match="deprecated"): - tr.Ordered(ndim_supp=1) - - with pytest.warns(FutureWarning, match="deprecated"): - assert tr.univariate_ordered == tr.ordered - - with pytest.warns(FutureWarning, match="deprecated"): - assert tr.multivariate_ordered == tr.ordered - - with pytest.warns(FutureWarning, match="deprecated"): - tr.SumTo1(ndim_supp=1) - - with pytest.warns(FutureWarning, match="deprecated"): - assert tr.univariate_sum_to_1 == tr.sum_to_1 - - with pytest.warns(FutureWarning, match="deprecated"): - assert tr.multivariate_sum_to_1 == tr.sum_to_1 diff --git a/tests/logprob/test_basic.py b/tests/logprob/test_basic.py index f6014d78a3..5e17dd8a8e 100644 --- a/tests/logprob/test_basic.py +++ b/tests/logprob/test_basic.py @@ -201,7 +201,7 @@ def test_persist_inputs(): assert y_vv_2 in ancestors([logp_2_combined]) -def test_warn_random_found_factorized_joint_logprob(): +def test_warn_rvs_conditional_logp(): x_rv = pt.random.normal(name="x") y_rv = pt.random.normal(x_rv, 1, name="y") @@ -369,7 +369,7 @@ def test_probability_inference_fails(func, func_name): (icdf, "ppf", 0.7), ], ) -def test_warn_random_found_probability_inference(func, scipy_func, test_value): +def test_warn_rvs_probability_derivation(func, scipy_func, test_value): # Fail if unexpected warning is issued with warnings.catch_warnings(): warnings.simplefilter("error") @@ -436,3 +436,26 @@ def test_ir_rewrite_does_not_disconnect_valued_rvs(): logp_b.eval({a_value: np.pi, b_value: np.e}), stats.norm.logpdf(np.e, np.pi * 8, 1), ) + + +def test_ir_ops_can_be_evaluated_with_warning(): + _eval_values = [None, None] + + def my_logp(value, lam): + nonlocal _eval_values + _eval_values[0] = value.eval() + _eval_values[1] = lam.eval({"lam_log__": -1.5}) + return value * lam + + with pm.Model() as m: + lam = pm.Exponential("lam") + pm.CustomDist("y", lam, logp=my_logp, observed=[0, 1, 2]) + + with pytest.warns( + UserWarning, match="TransformedValue should not be present in the final graph" + ): + with pytest.warns(UserWarning, match="ValuedVar should not be present in the final graph"): + m.logp() + + assert _eval_values[0].sum() == 3 + assert _eval_values[1] == np.exp(-1.5) diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 2ab30235bd..b55699b569 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -249,7 
+249,7 @@ class TestLocScaleRVTransform: @pytest.mark.parametrize( "rv_size, loc_type, addition", [ - (None, pt.scalar, True), + ((), pt.scalar, True), (2, pt.vector, False), ((2, 1), pt.col, True), ], @@ -288,7 +288,7 @@ def test_loc_transform_rv(self, rv_size, loc_type, addition): @pytest.mark.parametrize( "rv_size, scale_type, product", [ - (None, pt.scalar, True), + ((), pt.scalar, True), (1, pt.TensorType("floatX", (True,)), True), ((2, 3), pt.matrix, False), ], diff --git a/tests/model/test_core.py b/tests/model/test_core.py index 4375a17ad2..814cb114d4 100644 --- a/tests/model/test_core.py +++ b/tests/model/test_core.py @@ -33,6 +33,7 @@ import scipy.stats as st from pytensor.graph import graph_inputs +from pytensor.graph.basic import get_var_by_name from pytensor.raise_op import Assert from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.variable import TensorConstant @@ -55,7 +56,6 @@ from pymc.logprob.basic import transformed_conditional_logp from pymc.logprob.transforms import IntervalTransform from pymc.model import Point, ValueGradFunction, modelcontext -from pymc.util import _FutureWarningValidatingScratchpad from pymc.variational.minibatch_rv import MinibatchRandomVariable from tests.models import simple_model @@ -858,6 +858,19 @@ def test_nested_model_coords(): assert set(m2.named_vars_to_dims) < set(m1.named_vars_to_dims) +def test_multiple_add_coords_with_same_name(): + coord = {"dim1": ["a", "b", "c"]} + with pm.Model(coords=coord) as m: + a = pm.Normal("a", dims="dim1") + with pm.Model(coords=coord) as nested_m: + b = pm.Normal("b", dims="dim1") + m.add_coords(coord) + c = pm.Normal("c", dims="dim1") + d = pm.Deterministic("d", a + b + c) + variables = get_var_by_name([d], "dim1") + assert len(variables) == 1 and variables[0] is m.dim_lengths["dim1"] + + class TestSetUpdateCoords: def test_shapeerror_from_set_data_dimensionality(self): with pm.Model() as pmodel: @@ -1634,55 +1647,6 @@ def test_deterministic(self): ) -def test_tag_future_warning_model(): - # Test no unexpected warnings - with warnings.catch_warnings(): - warnings.simplefilter("error") - - model = pm.Model() - - x = pt.random.normal() - x.tag.something_else = "5" - x.tag.test_value = 0 - assert not isinstance(x.tag, _FutureWarningValidatingScratchpad) - - # Test that model changes the tag type, but copies existing contents - x = model.register_rv(x, name="x", transform=log) - assert isinstance(x.tag, _FutureWarningValidatingScratchpad) - assert x.tag.something_else == "5" - assert x.tag.test_value == 0 - - # Test expected warnings - with pytest.warns(FutureWarning, match="model.rvs_to_values"): - x_value = x.tag.value_var - - assert isinstance(x_value.tag, _FutureWarningValidatingScratchpad) - with pytest.warns(FutureWarning, match="model.rvs_to_transforms"): - transform = x_value.tag.transform - assert transform is log - - with pytest.raises(AttributeError): - x.tag.observations - - # Cloning a node will keep the same tag type and contents - y = x.owner.clone().default_output() - assert y is not x - assert y.tag is not x.tag - assert isinstance(y.tag, _FutureWarningValidatingScratchpad) - y = model.register_rv(y, name="y", observed=5) - assert isinstance(y.tag, _FutureWarningValidatingScratchpad) - - # Test expected warnings - with pytest.warns(FutureWarning, match="model.rvs_to_values"): - y_value = y.tag.value_var - with pytest.warns(FutureWarning, match="model.rvs_to_values"): - y_obs = y.tag.observations - assert y_value is y_obs - assert y_value.eval() == 5 - - assert 
isinstance(y_value.tag, _FutureWarningValidatingScratchpad) - - class TestModelDebug: @pytest.mark.parametrize("fn", ("logp", "dlogp", "random")) def test_no_problems(self, fn, capfd): @@ -1786,7 +1750,7 @@ def school_model(J: int) -> pm.Model: ) def test_graphviz_call_function(self, var_names, filenames) -> None: model = self.school_model(J=8) - with patch("pymc.model.core.model_to_graphviz") as mock_model_to_graphviz: + with patch("pymc.model_graph.model_to_graphviz") as mock_model_to_graphviz: model.to_graphviz(var_names=var_names, save=filenames) mock_model_to_graphviz.assert_called_once_with( model=model, diff --git a/tests/model/test_fgraph.py b/tests/model/test_fgraph.py index 178eb39683..a3f04e3ce0 100644 --- a/tests/model/test_fgraph.py +++ b/tests/model/test_fgraph.py @@ -108,7 +108,7 @@ def test_data(inline_views): with pm.Model(coords={"test_dim": range(3)}) as m_old: x = pm.Data("x", [0.0, 1.0, 2.0], dims=("test_dim",)) y = pm.Data("y", [10.0, 11.0, 12.0], dims=("test_dim",)) - sigma = pm.MutableData("sigma", [1.0], shape=(1,)) + sigma = pm.Data("sigma", [1.0], shape=(1,)) b0 = pm.Data("b0", np.zeros((1,)), shape=((1,))) b1 = pm.Normal("b1", 1.0, sigma=1e-8) mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",)) diff --git a/tests/model/transform/test_conditioning.py b/tests/model/transform/test_conditioning.py index 828fac737b..fa9ce71246 100644 --- a/tests/model/transform/test_conditioning.py +++ b/tests/model/transform/test_conditioning.py @@ -29,6 +29,7 @@ observe, remove_value_transforms, ) +from pymc.model.transform.optimization import freeze_dims_and_data from pymc.variational.minibatch_rv import create_minibatch_rv @@ -176,10 +177,13 @@ def test_do_posterior_predictive(): def test_do_constant(mutable): rng = np.random.default_rng(seed=122) with pm.Model() as m: - x = pm.Data("x", 0, mutable=mutable) + x = pm.Data("x", 0) y = pm.Normal("y", x, 1e-3) - do_m = do(m, {x: 105}) + if not mutable: + m = freeze_dims_and_data(m, data=["x"]) + + do_m = do(m, {m["x"]: 105}) assert pm.draw(do_m["y"], random_seed=rng) > 100 diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index d3b41bf667..df8bb2dbf2 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -28,17 +28,22 @@ from pytensor import Mode, shared from pytensor.compile import SharedVariable from pytensor.graph import graph_inputs +from pytensor.graph.basic import get_var_by_name, variable_depends_on +from pytensor.tensor.variable import TensorConstant from scipy import stats import pymc as pm from pymc.backends.base import MultiTrace +from pymc.logprob.utils import rvs_in_graph +from pymc.model.transform.optimization import freeze_dims_and_data from pymc.pytensorf import compile from pymc.sampling.forward import ( compile_forward_sampling_function, get_constant_coords, get_vars_in_point_list, observed_dependent_deterministics, + vectorize_over_posterior, ) from pymc.testing import fast_unstable_sampling_mode @@ -1801,3 +1806,156 @@ def test_sample_prior_predictive_samples_deprecated_warns() -> None: match = "The samples argument has been deprecated" with pytest.warns(DeprecationWarning, match=match): pm.sample_prior_predictive(model=m, samples=10) + + +@pytest.fixture(params=["deterministic", "observed", "conditioned_on_observed"]) +def variable_to_vectorize(request): + if request.param == "deterministic": + return ["y"] + elif request.param == "conditioned_on_observed": + return ["z", "z_downstream"] + else: + return ["z"] + + 
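+# The remaining fixtures combine with variable_to_vectorize to enumerate the
+# scenarios exercised by test_vectorize_over_posterior below: which outputs get
+# vectorized, whether random variables may remain in the vectorized graph,
+# whether the model nests one random variable inside another, and which RVs
+# are read from the posterior rather than resampled.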
+@pytest.fixture(params=["allow_rvs_in_graph", "disallow_rvs_in_graph"]) +def allow_rvs_in_graph(request): + if request.param == "allow_rvs_in_graph": + return True + else: + return False + + +@pytest.fixture(scope="module", params=["nested_random_variables", "no_nested_random_variables"]) +def has_nested_random_variables(request): + return request.param == "nested_random_variables" + + +@pytest.fixture(scope="module") +def model_to_vectorize(has_nested_random_variables): + with pm.Model() as model: + if not has_nested_random_variables: + x_parent = 0.0 + else: + x_parent = pm.Normal("x_parent") + x = pm.Normal("x", mu=x_parent) + d = pm.Data("d", np.array([1, 2, 3])) + obs = pm.Data("obs", np.ones_like(d.get_value())) + y = pm.Deterministic("y", x * d) + z = pm.Gamma("z", mu=pt.exp(y), sigma=pt.exp(y) * 0.1, observed=obs) + pm.Deterministic("z_downstream", z * 2) + + with model: + idata = pm.sample_prior_predictive(100) + idata.add_groups({"posterior": idata.prior}) + return freeze_dims_and_data(model), idata + + +@pytest.fixture(params=["rv_from_posterior", "resample_rv"]) +def input_rv_names(request, has_nested_random_variables): + if request.param == "rv_from_posterior": + if has_nested_random_variables: + return ["x_parent", "x"] + else: + return ["x"] + else: + return [] + + +def test_vectorize_over_posterior( + variable_to_vectorize, + input_rv_names, + allow_rvs_in_graph, + model_to_vectorize, +): + model, idata = model_to_vectorize + + if not allow_rvs_in_graph and (len(input_rv_names) == 0 or "z" in variable_to_vectorize): + with pytest.raises( + RuntimeError, + match="The following random variables found in the extracted graph", + ): + vectorize_over_posterior( + outputs=[model[name] for name in variable_to_vectorize], + posterior=idata.posterior, + input_rvs=[model[name] for name in input_rv_names], + allow_rvs_in_graph=allow_rvs_in_graph, + ) + else: + vectorized = vectorize_over_posterior( + outputs=[model[name] for name in variable_to_vectorize], + posterior=idata.posterior, + input_rvs=[model[name] for name in input_rv_names], + allow_rvs_in_graph=allow_rvs_in_graph, + ) + assert all( + vectorized_var is not model[name] + for vectorized_var, name in zip(vectorized, variable_to_vectorize) + ) + assert all(vectorized_var.type.shape == (1, 100, 3) for vectorized_var in vectorized) + assert all(variable_depends_on(vectorized_var, model["d"]) for vectorized_var in vectorized) + if len(vectorized) == 2: + assert variable_depends_on( + vectorized[variable_to_vectorize.index("z_downstream")], + vectorized[variable_to_vectorize.index("z")], + ) + if len(input_rv_names) > 0: + for input_rv_name in input_rv_names: + if input_rv_name == "x_parent": + assert len(get_var_by_name(vectorized, input_rv_name)) == 0 + else: + [vectorized_rv] = get_var_by_name(vectorized, input_rv_name) + rv_posterior = idata.posterior[input_rv_name].data + assert isinstance(vectorized_rv, TensorConstant) + assert np.all(vectorized_rv.value == rv_posterior) + else: + batch_shape = ( + len(idata.posterior.coords["chain"]), + len(idata.posterior.coords["draw"]), + ) + original_rvs = rvs_in_graph([model[name] for name in variable_to_vectorize]) + expected_rv_shapes = {(*batch_shape, *rv.type.shape) for rv in original_rvs} + rvs = rvs_in_graph(vectorized) + assert {rv.type.shape for rv in rvs} == expected_rv_shapes + + +def test_vectorize_over_posterior_matches_sample(): + rng = np.random.default_rng(1234) + with pm.Model() as model: + x = pm.Normal("x") + sigma = 0.1 + obs = pm.Normal("obs", x, sigma, 
observed=rng.normal(size=10)) + det = pm.Deterministic("det", obs + 1) + + chains = 2 + draws = 100 + x_posterior = np.broadcast_to(100 * np.arange(chains)[..., None], (chains, draws)) + with model: + posterior = xr.Dataset( + { + "x": xr.DataArray( + x_posterior, + dims=("chain", "draw"), + coords={"chain": np.arange(chains), "draw": np.arange(draws)}, + ) + } + ) + idata = InferenceData(posterior=posterior) + with model: + pp = pm.sample_posterior_predictive(idata, var_names=["obs", "det"], random_seed=1234) + vectorized = vectorize_over_posterior( + outputs=[obs, det], + posterior=posterior, + input_rvs=[x], + allow_rvs_in_graph=True, + ) + [vect_obs, vect_det] = compile(inputs=[], outputs=vectorized, random_seed=1234)() + assert pp.posterior_predictive["obs"].shape == vect_obs.shape + assert pp.posterior_predictive["det"].shape == vect_det.shape + np.testing.assert_allclose(vect_obs + 1, vect_det) + np.testing.assert_allclose( + pp.posterior_predictive["obs"].mean(dim=("chain", "draw")), + vect_obs.mean(axis=(0, 1)), + atol=0.6 / np.sqrt(10000), + ) + assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1) diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index 2d32277061..ed3fc09dd0 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -15,8 +15,9 @@ import numpy as np import numpy.testing as npt import pytest +import xarray as xr -from pymc import Data, Model, Normal, sample +from pymc import Data, Deterministic, HalfNormal, Model, Normal, sample @pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) @@ -86,3 +87,55 @@ def test_step_args(): ) npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) + + +@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) +def test_sample_var_names(nuts_sampler): + seed = 1234 + kwargs = { + "chains": 1, + "tune": 100, + "draws": 100, + "random_seed": seed, + "progressbar": False, + "compute_convergence_checks": False, + } + + # Generate data + rng = np.random.default_rng(seed) + + group = rng.choice(list("ABCD"), size=100) + x = rng.normal(size=100) + y = rng.normal(size=100) + + group_values, group_idx = np.unique(group, return_inverse=True) + + coords = {"group": group_values} + + # Create model + with Model(coords=coords) as model: + b_group = Normal("b_group", dims="group") + b_x = Normal("b_x") + mu = Deterministic("mu", b_group[group_idx] + b_x * x) + sigma = HalfNormal("sigma") + Normal("y", mu=mu, sigma=sigma, observed=y) + + free_RVs = [var.name for var in model.free_RVs] + + with model: + # Sample with and without var_names, but always with the same seed + idata_1 = sample(nuts_sampler=nuts_sampler, **kwargs) + # Remove the last free RV from the sampling + idata_2 = sample(nuts_sampler=nuts_sampler, var_names=free_RVs[:-1], **kwargs) + + assert "mu" in idata_1.posterior + assert "mu" not in idata_2.posterior + + assert free_RVs[-1] in idata_1.posterior + assert free_RVs[-1] not in idata_2.posterior + + for var in free_RVs[:-1]: + assert var in idata_1.posterior + assert var in idata_2.posterior + + xr.testing.assert_allclose(idata_1.posterior[var], idata_2.posterior[var]) diff --git a/tests/stats/test_convergence.py b/tests/stats/test_convergence.py index 1f7ba44791..52d5c5048c 100644 --- a/tests/stats/test_convergence.py +++ b/tests/stats/test_convergence.py @@ -16,19 +16,27 @@ import arviz import numpy as np +import pytest from pymc.stats import convergence -def 
test_warn_divergences(): +@pytest.mark.parametrize( + "diverging, expected_phrase", + [ + pytest.param([1, 0, 1, 0], "were 2 divergences after tuning", id="plural"), + pytest.param([1, 0, 0, 0], "was 1 divergence after tuning", id="singular"), + ], +) +def test_warn_divergences(diverging, expected_phrase): idata = arviz.from_dict( sample_stats={ - "diverging": np.array([[1, 0, 1, 0], [0, 0, 0, 0]]).astype(bool), + "diverging": np.array([diverging, [0, 0, 0, 0]]).astype(bool), } ) warns = convergence.warn_divergences(idata) assert len(warns) == 1 - assert "2 divergences after tuning" in warns[0].message + assert expected_phrase in warns[0].message def test_warn_treedepth(): diff --git a/tests/step_methods/hmc/test_nuts.py b/tests/step_methods/hmc/test_nuts.py index 432418a33a..8d497f3011 100644 --- a/tests/step_methods/hmc/test_nuts.py +++ b/tests/step_methods/hmc/test_nuts.py @@ -148,6 +148,7 @@ def test_sampler_stats(self): expected_stat_names = { "depth", "diverging", + "divergences", "energy", "energy_error", "model_logp", diff --git a/tests/test_initial_point.py b/tests/test_initial_point.py index bda39bd1f6..e13ed5279c 100644 --- a/tests/test_initial_point.py +++ b/tests/test_initial_point.py @@ -294,15 +294,6 @@ class MyNormalDistribution(pm.Normal): assert np.isclose(res["x"], np.pi) - def test_future_warning_moment(self): - with pm.Model() as m: - pm.Normal("x", initval="moment") - with pytest.warns( - FutureWarning, - match="The 'moment' strategy is deprecated. Use 'support_point' instead.", - ): - ip = m.initial_point(random_seed=42) - def test_pickling_issue_5090(): with pm.Model() as model: diff --git a/tests/test_math.py b/tests/test_math.py index 3f811fc2b7..eeee1f164e 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -15,7 +15,6 @@ import warnings import numpy as np -import numpy.testing as npt import pytensor import pytensor.tensor as pt import pytest @@ -29,10 +28,8 @@ kron_solve_lower, kronecker, log1mexp, - log1mexp_numpy, logdet, logdiffexp, - logdiffexp_numpy, probit, ) from pymc.pytensorf import floatX @@ -126,70 +123,22 @@ def test_probit(): np.testing.assert_allclose(invprobit(probit(p)).eval(), p, atol=1e-5) -def test_log1mexp(): - vals = np.array([-1, 0, 1e-20, 1e-4, 10, 100, 1e20]) - vals_ = vals.copy() - # import mpmath - # mpmath.mp.dps = 1000 - # [float(mpmath.log(1 - mpmath.exp(-x))) for x in vals] - expected = np.array( - [ - np.nan, - -np.inf, - -46.051701859880914, - -9.210390371559516, - -4.540096037048921e-05, - -3.720075976020836e-44, - 0.0, - ] - ) - actual = pt.log1mexp(-vals).eval() - npt.assert_allclose(actual, expected) - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning) - warnings.filterwarnings("ignore", "invalid value encountered in log", RuntimeWarning) - with pytest.warns(FutureWarning, match="deprecated"): - actual_ = log1mexp_numpy(-vals, negative_input=True) - npt.assert_allclose(actual_, expected) - # Check that input was not changed in place - npt.assert_allclose(vals, vals_) - - -@pytest.mark.filterwarnings("error") -def test_log1mexp_numpy_no_warning(): - """Assert RuntimeWarning is not raised for very small numbers""" - with pytest.warns(FutureWarning, match="deprecated"): - log1mexp_numpy(-1e-25, negative_input=True) - - -def test_log1mexp_numpy_integer_input(): - with pytest.warns(FutureWarning, match="deprecated"): - assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval()) - - @pytest.mark.filterwarnings("error") def 
test_log1mexp_deprecation_warnings(): - with pytest.warns(FutureWarning, match="deprecated"): - with pytest.warns( - FutureWarning, - match="pymc.math.log1mexp_numpy will expect a negative input", - ): - res_pos = log1mexp_numpy(2) - - res_neg = log1mexp_numpy(-2, negative_input=True) + with pytest.raises( + ValueError, + match="log1mexp with negative_input=False is no longer supported", + ): + log1mexp(2, negative_input=False).eval() - with pytest.warns( - FutureWarning, - match="pymc.math.log1mexp will expect a negative input", - ): - res_pos_at = log1mexp(2).eval() + with pytest.warns(FutureWarning): + res_1 = log1mexp(-2, negative_input=True).eval() - res_neg_at = log1mexp(-2, negative_input=True).eval() + res_2 = log1mexp(-2).eval() + res_ref = pt.log1mexp(-2).eval() - assert np.isclose(res_pos, res_neg) - assert np.isclose(res_pos_at, res_neg) - assert np.isclose(res_neg_at, res_neg) + assert np.isclose(res_ref, res_1) + assert np.isclose(res_ref, res_2) def test_logdiffexp(): @@ -197,8 +146,6 @@ def test_logdiffexp(): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning) b = np.log([0, 1, 2, 3]) - with pytest.warns(FutureWarning, match="deprecated"): - assert np.allclose(logdiffexp_numpy(a, b), 0) assert np.allclose(logdiffexp(a, b).eval(), 0) diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index 70c10d4106..94b797f8e7 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -13,6 +13,8 @@ # limitations under the License. import warnings +from textwrap import dedent + import numpy as np import pytensor import pytensor.tensor as pt @@ -31,6 +33,7 @@ NodeType, Plate, model_to_graphviz, + model_to_mermaid, model_to_networkx, ) @@ -513,14 +516,39 @@ def test_model_graph_with_intermediate_named_variables(): with pm.Model() as m1: a = pm.Normal("a", 0, 1, shape=3) pm.Normal("b", a.mean(axis=-1), 1) - assert dict(ModelGraph(m1).make_compute_graph()) == {"a": set(), "b": {"a"}} + assert ModelGraph(m1).make_compute_graph() == {"a": set(), "b": {"a"}} with pm.Model() as m2: a = pm.Normal("a", 0, 1) b = a + 1 b.name = "b" pm.Normal("c", b, 1) - assert dict(ModelGraph(m2).make_compute_graph()) == {"a": set(), "c": {"a"}} + assert ModelGraph(m2).make_compute_graph() == {"a": set(), "c": {"a"}} + + # Regression test for https://github.com/pymc-devs/pymc/issues/7397 + with pm.Model() as m3: + data = pt.as_tensor_variable( + np.ones((5, 3)), + name="C", + ) + # C has the same name as the `data` tensor above + # That name collision used to be wrongly picked up as a dependency + C = pm.Deterministic("C", data) + # D depends on a tensor named `C`, which is not actually a variable in the model + D = pm.Deterministic("D", data) + # E actually depends on the model variable `C` + E = pm.Deterministic("E", C) + assert ModelGraph(m3).make_compute_graph() == {"C": set(), "D": set(), "E": {"C"}} + + +def test_model_graph_complex_observed_dependency(): + with pm.Model() as model: + x = pm.Data("x", [0]) + y = pm.Data("y", [0]) + observed = pt.exp(x) + pt.log(y) + pm.Normal("obs", mu=0, observed=observed) + + assert ModelGraph(model).make_compute_graph() == {"obs": set(), "x": {"obs"}, "y": {"obs"}} @pytest.fixture @@ -629,3 +657,40 @@ def test_scalars_dim_info() -> None: ] assert graph.edges() == [] + + +def test_model_to_mermaid(simple_model): + expected_mermaid_string = dedent(""" + graph TD + %% Nodes: + a([a ~ Normal]) + a@{ shape: rounded } + b([b ~ Normal]) + b@{ shape: rounded } + c([c ~ Normal]) + c@{ shape: rounded } + + %% 
Edges: + a --> b + b --> c + + %% Plates: + """) + assert model_to_mermaid(simple_model) == expected_mermaid_string.strip() + + +def test_model_to_mermaid_with_variable_with_space(): + with pm.Model() as variable_with_space: + pm.Normal("plant growth") + + expected_mermaid_string = dedent(""" + graph TD + %% Nodes: + plant_growth([plant growth ~ Normal]) + plant_growth@{ shape: rounded } + + %% Edges: + + %% Plates: + """) + assert model_to_mermaid(variable_with_space) == expected_mermaid_string.strip() diff --git a/tests/test_printing.py b/tests/test_printing.py index 917a2d1ee2..879b25b96b 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -141,11 +141,11 @@ def setup_class(self): r"beta ~ Normal(0, 10)", r"Z ~ MultivariateNormal(f(), f())", r"nb_with_p_n ~ NegativeBinomial(10, nbp)", - r"zip ~ MarginalMixture(f(), DiracDelta(0), Poisson(5))", + r"zip ~ Mixture(f(), DiracDelta(0), Poisson(5))", r"w ~ Dirichlet()", ( - r"nested_mix ~ MarginalMixture(w, " - r"MarginalMixture(f(), DiracDelta(0), Poisson(5)), " + r"nested_mix ~ Mixture(w, " + r"Mixture(f(), DiracDelta(0), Poisson(5)), " r"Censored(Bernoulli(0.5), -1, 1))" ), r"Y_obs ~ Normal(mu, sigma)", @@ -159,9 +159,9 @@ def setup_class(self): r"beta ~ Normal", r"Z ~ MultivariateNormal", r"nb_with_p_n ~ NegativeBinomial", - r"zip ~ MarginalMixture", + r"zip ~ Mixture", r"w ~ Dirichlet", - r"nested_mix ~ MarginalMixture", + r"nested_mix ~ Mixture", r"Y_obs ~ Normal", r"pot ~ Potential", r"pred ~ Deterministic", @@ -173,11 +173,11 @@ def setup_class(self): r"$\text{beta} \sim \operatorname{Normal}(0,~10)$", r"$\text{Z} \sim \operatorname{MultivariateNormal}(f(),~f())$", r"$\text{nb\_with\_p\_n} \sim \operatorname{NegativeBinomial}(10,~\text{nbp})$", - r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Poisson}(5))$", + r"$\text{zip} \sim \operatorname{Mixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Poisson}(5))$", r"$\text{w} \sim \operatorname{Dirichlet}(\text{})$", ( - r"$\text{nested\_mix} \sim \operatorname{MarginalMixture}(\text{w}," - r"~\operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Poisson}(5))," + r"$\text{nested\_mix} \sim \operatorname{Mixture}(\text{w}," + r"~\operatorname{Mixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Poisson}(5))," r"~\operatorname{Censored}(\operatorname{Bernoulli}(0.5),~-1,~1))$" ), r"$\text{Y\_obs} \sim \operatorname{Normal}(\text{mu},~\text{sigma})$", @@ -191,9 +191,9 @@ def setup_class(self): r"$\text{beta} \sim \operatorname{Normal}$", r"$\text{Z} \sim \operatorname{MultivariateNormal}$", r"$\text{nb\_with\_p\_n} \sim \operatorname{NegativeBinomial}$", - r"$\text{zip} \sim \operatorname{MarginalMixture}$", + r"$\text{zip} \sim \operatorname{Mixture}$", r"$\text{w} \sim \operatorname{Dirichlet}$", - r"$\text{nested\_mix} \sim \operatorname{MarginalMixture}$", + r"$\text{nested\_mix} \sim \operatorname{Mixture}$", r"$\text{Y\_obs} \sim \operatorname{Normal}$", r"$\text{pot} \sim \operatorname{Potential}$", r"$\text{pred} \sim \operatorname{Deterministic}", @@ -276,7 +276,7 @@ def test_model_latex_repr_mixture_model(): "\\begin{array}{rcl}", "\\text{w} &\\sim & " "\\operatorname{Dirichlet}(\\text{})\\\\\\text{mix} &\\sim & " - "\\operatorname{MarginalMixture}(\\text{w},~\\operatorname{Normal}(0,~5),~\\operatorname{StudentT}(7,~0,~1))", + "\\operatorname{Mixture}(\\text{w},~\\operatorname{Normal}(0,~5),~\\operatorname{StudentT}(7,~0,~1))", "\\end{array}", "$$", ] diff --git 
a/tests/test_progress_bar.py b/tests/test_progress_bar.py new file mode 100644 index 0000000000..6687db1ae8 --- /dev/null +++ b/tests/test_progress_bar.py @@ -0,0 +1,46 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pymc as pm + + +def test_progressbar_nested_compound(): + # Regression test for https://github.com/pymc-devs/pymc/issues/7721 + + with pm.Model(): + a = pm.Poisson("a", mu=10) + b = pm.Binomial("b", n=a, p=0.8) + c = pm.Poisson("c", mu=11) + d = pm.Dirichlet("d", a=[c, b]) + + step = pm.CompoundStep( + [ + pm.CompoundStep([pm.Metropolis(a), pm.Metropolis(b), pm.Metropolis(c)]), + pm.NUTS([d]), + ] + ) + + kwargs = { + "draws": 10, + "tune": 10, + "chains": 2, + "compute_convergence_checks": False, + "step": step, + } + + # We don't parametrize to avoid recompiling the model functions + for cores in (1, 2): + pm.sample(**kwargs, cores=cores, progressbar=True) # default is split+stats + pm.sample(**kwargs, cores=cores, progressbar="combined") + pm.sample(**kwargs, cores=cores, progressbar="split") + pm.sample(**kwargs, cores=cores, progressbar=False) diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index 34360397a3..f7efd5f6d4 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -25,7 +25,6 @@ from pytensor.compile import UnusedInputError from pytensor.compile.builders import OpFromGraph from pytensor.graph.basic import Variable, equal_computations -from pytensor.tensor.random.basic import normal, uniform from pytensor.tensor.subtensor import AdvancedIncSubtensor import pymc as pm @@ -36,6 +35,7 @@ from pymc.exceptions import NotConstantValueError from pymc.logprob.utils import ParameterValueError from pymc.pytensorf import ( + PointFunc, collect_default_updates, compile, constant_fold, @@ -46,7 +46,6 @@ replace_rng_nodes, replace_vars_in_graphs, reseed_rngs, - walk_model, ) from pymc.vartypes import int_types @@ -61,7 +60,7 @@ ) def test_pd_dataframe_as_tensor_variable(np_array: np.ndarray) -> None: df = pd.DataFrame(np_array) - np.testing.assert_array_equal(x=pt.as_tensor_variable(x=df).eval(), y=np_array) + np.testing.assert_array_equal(pt.as_tensor_variable(df).eval(), np_array) @pytest.mark.parametrize( @@ -70,7 +69,7 @@ def test_pd_dataframe_as_tensor_variable(np_array: np.ndarray) -> None: ) def test_pd_series_as_tensor_variable(np_array: np.ndarray) -> None: df = pd.Series(np_array) - np.testing.assert_array_equal(x=pt.as_tensor_variable(x=df).eval(), y=np_array) + np.testing.assert_array_equal(pt.as_tensor_variable(df).eval(), np_array) def test_pd_as_tensor_variable_multiindex() -> None: @@ -81,7 +80,7 @@ def test_pd_as_tensor_variable_multiindex() -> None: df = pd.DataFrame({"A": [12.0, 80.0, 30.0, 20.0], "B": [120.0, 700.0, 30.0, 20.0]}, index=index) np_array = np.array([[12.0, 80.0, 30.0, 20.0], [120.0, 700.0, 30.0, 20.0]]).T assert isinstance(df.index, pd.MultiIndex) - np.testing.assert_array_equal(x=pt.as_tensor_variable(x=df).eval(), y=np_array) + 
np.testing.assert_array_equal(pt.as_tensor_variable(df).eval(), np_array) class TestBroadcasting: @@ -283,42 +282,7 @@ def test_pandas_to_array_pandas_index(): np.testing.assert_array_equal(result, expected) -def test_walk_model(): - a = pt.vector("a") - b = uniform(0.0, a, name="b") - c = pt.log(b) - c.name = "c" - d = pt.vector("d") - e = normal(c, d, name="e") - - test_graph = pt.exp(e + 1) - - with pytest.warns(FutureWarning): - res = list(walk_model((test_graph,))) - assert a in res - assert b in res - assert c in res - assert d in res - assert e in res - - with pytest.warns(FutureWarning): - res = list(walk_model((test_graph,), stop_at_vars={c})) - assert a not in res - assert b not in res - assert c in res - assert d in res - assert e in res - - with pytest.warns(FutureWarning): - res = list(walk_model((test_graph,), stop_at_vars={b})) - assert a not in res - assert b in res - assert c in res - assert d in res - assert e in res - - -class TestCompilePyMC: +class TestCompile: def test_check_bounds_flag(self): """Test that CheckParameterValue Ops are replaced or removed when using compile_pymc""" logp = pt.ones(3) @@ -780,3 +744,17 @@ def test_hessian_sign_change_warning(func): res_neg = func(f, vars=[x]) res = func(f, vars=[x], negate_output=False) assert equal_computations([res_neg], [-res]) + + +def test_point_func(): + x, y = pt.vectors("x", "y") + outs = x * 2 + y**2 + f = compile([x, y], outs) + + point_f = PointFunc(f) + np.testing.assert_allclose(point_f({"y": [3], "x": [2]}), [4 + 9]) + + # Check we can access other methods of the wrapped pytensor function + dprint_res = point_f.dprint(file="str") + expected_dprint_res = point_f.f.dprint(file="str") + assert dprint_res == expected_dprint_res
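+
+# Note: PointFunc unpacks a point dictionary into the wrapped pytensor function
+# by name, so the key order in the call above is irrelevant, and it forwards
+# other attribute lookups (such as dprint) to the wrapped function.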