diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index a1aa85c7e..450b46e30 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -3,7 +3,7 @@ channels: - conda-forge - nodefaults dependencies: -- pymc>=5.20 +- pymc>=5.21 - pytest-cov>=2.5 - pytest>=3.0 - dask diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 1d1eb7745..d2a3e8934 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -10,7 +10,7 @@ dependencies: - xhistogram - statsmodels - numba<=0.60.0 -- pymc>=5.20 +- pymc>=5.21 - pip: - blackjax - scikit-learn diff --git a/pymc_extras/inference/pathfinder/importance_sampling.py b/pymc_extras/inference/pathfinder/importance_sampling.py index 3b4a0ee78..8d04c077a 100644 --- a/pymc_extras/inference/pathfinder/importance_sampling.py +++ b/pymc_extras/inference/pathfinder/importance_sampling.py @@ -20,7 +20,7 @@ class ImportanceSamplingResult: samples: NDArray pareto_k: float | None = None warnings: list[str] = field(default_factory=list) - method: str = "none" + method: str = "psis" def importance_sampling( @@ -28,7 +28,7 @@ def importance_sampling( logP: NDArray, logQ: NDArray, num_draws: int, - method: Literal["psis", "psir", "identity", "none"] | None, + method: Literal["psis", "psir", "identity"] | None, random_seed: int | None = None, ) -> ImportanceSamplingResult: """Pareto Smoothed Importance Resampling (PSIR) @@ -44,8 +44,15 @@ def importance_sampling( log probability values of proposal distribution, shape (L, M) num_draws : int number of draws to return where num_draws <= samples.shape[0] - method : str, optional - importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths. + method : str or None, optional + Method to apply sampling based on log importance weights (logP - logQ). + Options are: + "psis" : Pareto Smoothed Importance Sampling (default) + Recommended for more stable results. + "psir" : Pareto Smoothed Importance Resampling + Less stable than PSIS. + "identity" : Applies log importance weights directly without resampling. + None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is the number of model parameters. Other methods return samples of size (num_draws, N). random_seed : int | None Returns @@ -71,11 +78,11 @@ def importance_sampling( warnings = [] num_paths, _, N = samples.shape - if method == "none": + if method is None: warnings.append( "Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
) - return ImportanceSamplingResult(samples=samples, warnings=warnings) + return ImportanceSamplingResult(samples=samples, warnings=warnings, method=method) else: samples = samples.reshape(-1, N) logP = logP.ravel() logQ = logQ.ravel() @@ -91,17 +98,16 @@ def importance_sampling( _warnings.filterwarnings( "ignore", category=RuntimeWarning, message="overflow encountered in exp" ) - if method == "psis": - replace = False - logiw, pareto_k = az.psislw(logiw) - elif method == "psir": - replace = True - logiw, pareto_k = az.psislw(logiw) - elif method == "identity": - replace = False - pareto_k = None - else: - raise ValueError(f"Invalid importance sampling method: {method}") + match method: + case "psis": + replace = False + logiw, pareto_k = az.psislw(logiw) + case "psir": + replace = True + logiw, pareto_k = az.psislw(logiw) + case "identity": + replace = False + pareto_k = None # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI. # Pareto k may not be a good diagnostic for Pathfinder. diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index 8f79d9665..dfe5fc6a0 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -60,6 +60,7 @@ from pytensor.tensor import TensorConstant, TensorVariable from rich.console import Console, Group from rich.padding import Padding +from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.table import Table from rich.text import Text @@ -155,7 +156,7 @@ def convert_flat_trace_to_idata( postprocessing_backend: Literal["cpu", "gpu"] = "cpu", inference_backend: Literal["pymc", "blackjax"] = "pymc", model: Model | None = None, - importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis", + importance_sampling: Literal["psis", "psir", "identity"] | None = "psis", ) -> az.InferenceData: """convert flattened samples to arviz InferenceData format. @@ -180,7 +181,7 @@ def convert_flat_trace_to_idata( arviz inference data object """ - if importance_sampling == "none": + if importance_sampling is None: # samples.ndim == 3 in this case, otherwise ndim == 2 num_paths, num_pdraws, N = samples.shape samples = samples.reshape(-1, N) @@ -219,7 +220,7 @@ def convert_flat_trace_to_idata( fn.trust_input = True result = fn(*list(trace.values())) - if importance_sampling == "none": + if importance_sampling is None: result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result] elif inference_backend == "blackjax": @@ -1188,7 +1189,7 @@ class MultiPathfinderResult: elbo_argmax: NDArray | None = None lbfgs_status: Counter = field(default_factory=Counter) path_status: Counter = field(default_factory=Counter) - importance_sampling: str = "none" + importance_sampling: str | None = "psis" warnings: list[str] = field(default_factory=list) pareto_k: float | None = None @@ -1257,7 +1258,7 @@ def with_warnings(self, warnings: list[str]) -> Self: def with_importance_sampling( self, num_draws: int, - method: Literal["psis", "psir", "identity", "none"] | None, + method: Literal["psis", "psir", "identity"] | None, random_seed: int | None = None, ) -> Self: """perform importance sampling""" @@ -1395,7 +1396,7 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]: path_status_message = { - PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. 
A high occurrence suggests the model's default initial point + jitter is may be too close to the mean posterior and a poor exploration of the parameter space. Consider increasing jitter if this occurence is high relative to the number of paths.", - PathStatus.INVALID_LOGP: "Invalid logP values occur when a path's logP values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.", + PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.", PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.", } @@ -1423,7 +1424,7 @@ def multipath_pathfinder( num_elbo_draws: int, jitter: float, epsilon: float, - importance_sampling: Literal["psis", "psir", "identity", "none"] | None, + importance_sampling: Literal["psis", "psir", "identity"] | None, progressbar: bool, concurrent: Literal["thread", "process"] | None, random_seed: RandomSeed, @@ -1459,8 +1460,14 @@ Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value. epsilon: float value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8). - importance_sampling : str, optional - importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N). + importance_sampling : str or None, optional + Method to apply sampling based on log importance weights (logP - logQ). + "psis" : Pareto Smoothed Importance Sampling (default) + Recommended for more stable results. + "psir" : Pareto Smoothed Importance Resampling + Less stable than PSIS. + "identity" : Applies log importance weights directly without resampling. + None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is the number of model parameters. Other methods return samples of size (num_draws, N). progressbar : bool, optional Whether to display a progress bar (default is False). Setting this to True will likely increase the computation time. 
random_seed : RandomSeed, optional @@ -1482,12 +1489,6 @@ def multipath_pathfinder( The result containing samples and other information from the Multi-Path Pathfinder algorithm. """ - valid_importance_sampling = ["psis", "psir", "identity", "none", None] - if importance_sampling is None: - importance_sampling = "none" - if importance_sampling.lower() not in valid_importance_sampling: - raise ValueError(f"Invalid importance sampling method: {importance_sampling}") - *path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1) pathfinder_config = PathfinderConfig( @@ -1521,12 +1522,20 @@ def multipath_pathfinder( results = [] compute_start = time.time() try: - with CustomProgress( + desc = f"Paths Complete: {{path_idx}}/{num_paths}" + progress = CustomProgress( + "[progress.description]{task.description}", + BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + TimeRemainingColumn(), + TextColumn("/"), + TimeElapsedColumn(), console=Console(theme=default_progress_theme), disable=not progressbar, - ) as progress: - task = progress.add_task("Fitting", total=num_paths) - for result in generator: + ) + with progress: + task = progress.add_task(desc.format(path_idx=0), completed=0, total=num_paths) + for path_idx, result in enumerate(generator, start=1): try: if isinstance(result, Exception): raise result @@ -1552,7 +1561,14 @@ def multipath_pathfinder( lbfgs_status=LBFGSStatus.LBFGS_FAILED, ) ) - progress.update(task, advance=1) + finally: + # TODO: display LBFGS and Path Status in real time + progress.update( + task, + description=desc.format(path_idx=path_idx), + completed=path_idx, + refresh=True, + ) except (KeyboardInterrupt, StopIteration) as e: # if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData. if isinstance(e, StopIteration): @@ -1606,7 +1622,7 @@ def fit_pathfinder( num_elbo_draws: int = 10, # K jitter: float = 2.0, epsilon: float = 1e-8, - importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis", + importance_sampling: Literal["psis", "psir", "identity"] | None = "psis", progressbar: bool = True, concurrent: Literal["thread", "process"] | None = None, random_seed: RandomSeed | None = None, @@ -1646,8 +1662,15 @@ def fit_pathfinder( Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value. epsilon: float value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8). - importance_sampling : str, optional - importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N). 
+ importance_sampling : str or None, optional + Method to apply sampling based on log importance weights (logP - logQ). + Options are: + "psis" : Pareto Smoothed Importance Sampling (default) + Recommended for more stable results. + "psir" : Pareto Smoothed Importance Resampling + Less stable than PSIS. + "identity" : Applies log importance weights directly without resampling. + None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is the number of model parameters. Other methods return samples of size (num_draws, N). progressbar : bool, optional Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time. random_seed : RandomSeed, optional @@ -1674,6 +1697,15 @@ """ model = modelcontext(model) + + valid_importance_sampling = {"psis", "psir", "identity", None} + + if importance_sampling is not None: + importance_sampling = importance_sampling.lower() + + if importance_sampling not in valid_importance_sampling: + raise ValueError(f"Invalid importance sampling method: {importance_sampling}") + N = DictToArrayBijection.map(model.initial_point()).data.shape[0] if maxcor is None: diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 2590dd53a..778376720 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -28,7 +28,6 @@ ) from pymc_extras.statespace.filters.distributions import ( LinearGaussianStateSpace, - MvNormalSVD, SequenceMvNormal, ) from pymc_extras.statespace.filters.utilities import stabilize @@ -707,7 +706,7 @@ def _insert_random_variables(self): with pymc_model: for param_name in self.param_names: param = getattr(pymc_model, param_name, None) - if param: + if param is not None: found_params.append(param.name) missing_params = list(set(self.param_names) - set(found_params)) @@ -746,7 +745,7 @@ def _insert_data_variables(self): with pymc_model: for data_name in data_names: data = getattr(pymc_model, data_name, None) - if data: + if data is not None: found_data.append(data.name) missing_data = list(set(data_names) - set(found_data)) @@ -2233,7 +2232,9 @@ def impulse_response_function( if shock_trajectory is None: shock_trajectory = pt.zeros((n_steps, self.k_posdef)) if Q is not None: - init_shock = MvNormalSVD("initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM]) + init_shock = pm.MvNormal( + "initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method="svd" + ) else: init_shock = pm.Deterministic( "initial_shock", diff --git a/pymc_extras/statespace/filters/distributions.py b/pymc_extras/statespace/filters/distributions.py index d3b70c847..1e4f2b153 100644 --- a/pymc_extras/statespace/filters/distributions.py +++ b/pymc_extras/statespace/filters/distributions.py @@ -6,11 +6,9 @@ from pymc import intX from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import Continuous, SymbolicRandomVariable -from pymc.distributions.multivariate import MvNormal from pymc.distributions.shape_utils import get_support_shape_1d from pymc.logprob.abstract import _logprob from pytensor.graph.basic import Node -from pytensor.tensor.random.basic import MvNormalRV floatX = pytensor.config.floatX COV_ZERO_TOL = 0 @@ -49,44 +47,6 @@ def make_signature(sequence_names): return f"{signature},[rng]->[rng],({time},{state_and_obs})" -class MvNormalSVDRV(MvNormalRV): - name = "multivariate_normal" - signature = "(n),(n,n)->(n)" - dtype = "floatX" - _print_name = 
("MultivariateNormal", "\\operatorname{MultivariateNormal}") - - -class MvNormalSVD(MvNormal): - """Dummy distribution intended to be rewritten into a JAX multivariate_normal with method="svd". - - A JAX MvNormal robust to low-rank covariance matrices - """ - - rv_op = MvNormalSVDRV() - - -try: - import jax.random - - from pytensor.link.jax.dispatch.random import jax_sample_fn - - @jax_sample_fn.register(MvNormalSVDRV) - def jax_sample_fn_mvnormal_svd(op, node): - def sample_fn(rng, size, dtype, *parameters): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) - sample = jax.random.multivariate_normal( - sampling_key, *parameters, shape=size, dtype=dtype, method="svd" - ) - rng["jax_state"] = rng_key - return (rng, sample) - - return sample_fn - -except ImportError: - pass - - class LinearGaussianStateSpaceRV(SymbolicRandomVariable): default_output = 1 _print_name = ("LinearGuassianStateSpace", "\\operatorname{LinearGuassianStateSpace}") @@ -244,8 +204,12 @@ def step_fn(*args): k = T.shape[0] a = state[:k] - middle_rng, a_innovation = MvNormalSVD.dist(mu=0, cov=Q, rng=rng).owner.outputs - next_rng, y_innovation = MvNormalSVD.dist(mu=0, cov=H, rng=middle_rng).owner.outputs + middle_rng, a_innovation = pm.MvNormal.dist( + mu=0, cov=Q, rng=rng, method="svd" + ).owner.outputs + next_rng, y_innovation = pm.MvNormal.dist( + mu=0, cov=H, rng=middle_rng, method="svd" + ).owner.outputs a_mu = c + T @ a a_next = a_mu + R @ a_innovation @@ -260,8 +224,8 @@ def step_fn(*args): Z_init = Z_ if Z_ in non_sequences else Z_[0] H_init = H_ if H_ in non_sequences else H_[0] - init_x_ = MvNormalSVD.dist(a0_, P0_, rng=rng) - init_y_ = MvNormalSVD.dist(Z_init @ init_x_, H_init, rng=rng) + init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method="svd") + init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method="svd") init_dist_ = pt.concatenate([init_x_, init_y_], axis=0) @@ -421,7 +385,7 @@ def rv_op(cls, mus, covs, logp, size=None): rng = pytensor.shared(np.random.default_rng()) def step(mu, cov, rng): - new_rng, mvn = MvNormalSVD.dist(mu=mu, cov=cov, rng=rng).owner.outputs + new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs return mvn, {rng: new_rng} mvn_seq, updates = pytensor.scan( diff --git a/pymc_extras/version.txt b/pymc_extras/version.txt index 717903969..abd410582 100644 --- a/pymc_extras/version.txt +++ b/pymc_extras/version.txt @@ -1 +1 @@ -0.2.3 +0.2.4 diff --git a/pyproject.toml b/pyproject.toml index 17187a524..df7752cb1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,9 @@ filterwarnings =[ # Warning coming from blackjax 'ignore:jax\.tree_map is deprecated:DeprecationWarning', + + # PyMC uses numpy.core functions, which emits an warning as of numpy>2.0 + 'ignore:numpy\.core\.numeric is deprecated:DeprecationWarning', ] [tool.coverage.report] diff --git a/requirements.txt b/requirements.txt index a4f00ee21..49c7d88af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -pymc>=5.20 +pymc>=5.21.1 scikit-learn better-optimize diff --git a/tests/test_pathfinder.py b/tests/test_pathfinder.py index 1d5b2a9ec..af9213ff4 100644 --- a/tests/test_pathfinder.py +++ b/tests/test_pathfinder.py @@ -44,8 +44,8 @@ def reference_idata(): with model: idata = pmx.fit( method="pathfinder", - num_paths=50, - jitter=10.0, + num_paths=10, + jitter=12.0, random_seed=41, inference_backend="pymc", ) @@ -62,15 +62,15 @@ def test_pathfinder(inference_backend, reference_idata): with model: idata = pmx.fit( 
method="pathfinder", - num_paths=50, - jitter=10.0, + num_paths=10, + jitter=12.0, random_seed=41, inference_backend=inference_backend, ) else: idata = reference_idata - np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.6) - np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.5) + np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=0.95) + np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.35) assert idata.posterior["mu"].shape == (1, 1000) assert idata.posterior["tau"].shape == (1, 1000) @@ -83,8 +83,8 @@ def test_concurrent_results(reference_idata, concurrent): with model: idata_conc = pmx.fit( method="pathfinder", - num_paths=50, - jitter=10.0, + num_paths=10, + jitter=12.0, random_seed=41, inference_backend="pymc", concurrent=concurrent, @@ -108,7 +108,7 @@ def test_seed(reference_idata): with model: idata_41 = pmx.fit( method="pathfinder", - num_paths=50, + num_paths=4, jitter=10.0, random_seed=41, inference_backend="pymc", @@ -116,7 +116,7 @@ def test_seed(reference_idata): idata_123 = pmx.fit( method="pathfinder", - num_paths=50, + num_paths=4, jitter=10.0, random_seed=123, inference_backend="pymc", @@ -171,3 +171,33 @@ def test_bfgs_sample(): assert gamma.eval().shape == (L, 2 * J, 2 * J) assert phi.eval().shape == (L, num_samples, N) assert logq.eval().shape == (L, num_samples) + + +@pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None]) +def test_pathfinder_importance_sampling(importance_sampling): + model = eight_schools_model() + + num_paths = 4 + num_draws_per_path = 300 + num_draws = 750 + + with model: + idata = pmx.fit( + method="pathfinder", + num_paths=num_paths, + num_draws_per_path=num_draws_per_path, + num_draws=num_draws, + maxiter=5, + random_seed=41, + inference_backend="pymc", + importance_sampling=importance_sampling, + ) + + if importance_sampling is None: + assert idata.posterior["mu"].shape == (num_paths, num_draws_per_path) + assert idata.posterior["tau"].shape == (num_paths, num_draws_per_path) + assert idata.posterior["theta"].shape == (num_paths, num_draws_per_path, 8) + else: + assert idata.posterior["mu"].shape == (1, num_draws) + assert idata.posterior["tau"].shape == (1, num_draws) + assert idata.posterior["theta"].shape == (1, num_draws, 8)