From 6b2aa67bff903faef2518229bd039e673504804a Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Fri, 6 Jun 2025 00:34:48 +0200 Subject: [PATCH 1/6] Drop Python 3.10 support (#497) * Drop Python 3.10 support * Test on 3.11 and run `pip check` * Update pypi workflow to bump Python 3.10 -> 3.11 * Use Python 3.11 with readthedocs --- .github/workflows/pypi.yml | 2 +- .github/workflows/test.yml | 3 ++- .readthedocs.yaml | 2 +- pyproject.toml | 3 +-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 3b8a57e61..3c5c47d4e 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -17,7 +17,7 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.11" - name: Build the sdist and the wheel run: | pip install build diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8fc2a8cf5..aa8756aa8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -25,7 +25,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: ["3.10"] + python-version: ["3.11"] test-subset: - tests/model - tests/statespace/core/test_statespace.py @@ -55,6 +55,7 @@ jobs: run: | pip install -e ".[dev]" python --version + pip check - name: Run tests run: | python -m pytest --color=yes -vv --cov=pymc_extras --cov-append --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 304c31290..2b724c6f1 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -5,7 +5,7 @@ version: 2 build: os: ubuntu-20.04 tools: - python: "3.10" + python: "3.11" python: install: diff --git a/pyproject.toml b/pyproject.toml index 982699723..2f26932be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,6 @@ classifiers = [ "Development Status :: 5 - Production/Stable", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", @@ -23,7 +22,7 @@ classifiers = [ "Operating System :: OS Independent", ] readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" keywords = [ "probability", "machine learning", From 009b5acee3463efb3e5a30ac9f06b48151d29f1b Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 9 Jun 2025 18:34:42 +0200 Subject: [PATCH 2/6] Support marginalising through a MinibatchRandomVariable --- pymc_extras/model/marginal/graph_analysis.py | 4 ++++ tests/model/marginal/test_graph_analysis.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/pymc_extras/model/marginal/graph_analysis.py b/pymc_extras/model/marginal/graph_analysis.py index 422177dd4..6a7a7f874 100644 --- a/pymc_extras/model/marginal/graph_analysis.py +++ b/pymc_extras/model/marginal/graph_analysis.py @@ -5,6 +5,7 @@ from pymc import SymbolicRandomVariable from pymc.model.fgraph import ModelVar +from pymc.variational.minibatch_rv import MinibatchRandomVariable from pytensor.graph import Variable, ancestors from pytensor.graph.basic import io_toposort from pytensor.tensor import TensorType, TensorVariable @@ -313,6 +314,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) var_dims[node.outputs[0]] = output_dims + elif isinstance(node.op, MinibatchRandomVariable): + var_dims[node.outputs[0]] = inputs_dims[0] + else: raise NotImplementedError(f"Marginalization through operation {node} not supported.") diff --git a/tests/model/marginal/test_graph_analysis.py b/tests/model/marginal/test_graph_analysis.py index 57affd0ee..e835a4d03 100644 --- a/tests/model/marginal/test_graph_analysis.py +++ b/tests/model/marginal/test_graph_analysis.py @@ -2,6 +2,7 @@ import pytest from pymc.distributions import CustomDist +from pymc.variational.minibatch_rv import create_minibatch_rv from pytensor.tensor.type_other import NoneTypeT from pymc_extras.model.marginal.graph_analysis import ( @@ -160,6 +161,13 @@ def test_random_variable(self): with pytest.raises(ValueError, match="Use of known dimensions"): subgraph_batch_dim_connection(inp, [invalid_out]) + def test_minibatched_random_variable(self): + inp = pt.tensor(shape=(4, 3, 2)) + out1 = pt.random.normal(loc=inp) + out2 = create_minibatch_rv(out1, total_size=(10, 10, 10)) + [dims1] = subgraph_batch_dim_connection(inp, [out2]) + assert dims1 == (0, 1, 2) + def test_symbolic_random_variable(self): inp = pt.tensor(shape=(4, 3, 2)) From 88f0e07944e3802cb4378f3e35c90dab17852215 Mon Sep 17 00:00:00 2001 From: Will Dean <57733339+williambdean@users.noreply.github.com> Date: Tue, 10 Jun 2025 15:55:08 -0400 Subject: [PATCH 3/6] change var to variable (#512) --- pymc_extras/prior.py | 8 ++++---- tests/test_prior.py | 14 ++++++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index 12eb36a0d..8d390aaea 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -278,7 +278,7 @@ def create_variable(self, name: str) -> pt.TensorVariable: def sample_prior( factory: VariableFactory, coords=None, - name: str = "var", + name: str = "variable", wrap: bool = False, **sample_prior_predictive_kwargs, ) -> xr.Dataset: @@ -292,7 +292,7 @@ def sample_prior( The coordinates for the variable, by default None. Only required if the dims are specified. name : str, optional - The name of the variable, by default "var". + The name of the variable, by default "variable". wrap : bool, optional Whether to wrap the variable in a `pm.Deterministic` node, by default False. sample_prior_predictive_kwargs : dict @@ -900,7 +900,7 @@ def __eq__(self, other) -> bool: def sample_prior( self, coords=None, - name: str = "var", + name: str = "variable", **sample_prior_predictive_kwargs, ) -> xr.Dataset: """Sample the prior distribution for the variable. @@ -911,7 +911,7 @@ def sample_prior( The coordinates for the variable, by default None. Only required if the dims are specified. name : str, optional - The name of the variable, by default "var". + The name of the variable, by default "variable". sample_prior_predictive_kwargs : dict Additional arguments to pass to `pm.sample_prior_predictive`. diff --git a/tests/test_prior.py b/tests/test_prior.py index 70729b9f9..a7b630f83 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -616,7 +616,10 @@ def test_custom_transform() -> None: prior = dist.sample_prior(draws=10) df_prior = prior.to_dataframe() - np.testing.assert_array_equal(df_prior["var"].to_numpy(), df_prior["var_raw"].to_numpy() ** 2) + np.testing.assert_array_equal( + df_prior.variable.to_numpy(), + df_prior.variable_raw.to_numpy() ** 2, + ) def test_custom_transform_comes_first() -> None: @@ -627,7 +630,10 @@ def test_custom_transform_comes_first() -> None: prior = dist.sample_prior(draws=10) df_prior = prior.to_dataframe() - np.testing.assert_array_equal(df_prior["var"].to_numpy(), 2 * df_prior["var_raw"].to_numpy()) + np.testing.assert_array_equal( + df_prior.variable.to_numpy(), + 2 * df_prior.variable_raw.to_numpy(), + ) clear_custom_transforms() @@ -686,7 +692,7 @@ def test_sample_prior_arbitrary_no_name() -> None: prior = sample_prior(var, coords={"channel": ["A", "B", "C"]}, draws=25) assert isinstance(prior, xr.Dataset) - assert "var" not in prior + assert "variable" not in prior prior_with = sample_prior( var, @@ -696,7 +702,7 @@ def test_sample_prior_arbitrary_no_name() -> None: ) assert isinstance(prior_with, xr.Dataset) - assert "var" in prior_with + assert "variable" in prior_with def test_create_prior_with_arbitrary() -> None: From ce618a45fc04c57315948b0bbe2b0d666754c01f Mon Sep 17 00:00:00 2001 From: Will Dean <57733339+williambdean@users.noreply.github.com> Date: Tue, 10 Jun 2025 15:56:25 -0400 Subject: [PATCH 4/6] port previous PR (#511) --- pymc_extras/prior.py | 32 ++++++++++++++++++++++++++++++++ tests/test_prior.py | 21 +++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index 8d390aaea..9e523ddee 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -84,6 +84,7 @@ def custom_transform(x): import copy from collections.abc import Callable +from functools import partial from inspect import signature from typing import Any, Protocol, runtime_checkable @@ -1354,3 +1355,34 @@ def _is_censored_type(data: dict) -> bool: register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict) register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict) + + +def __getattr__(name: str): + """Get Prior class through the module. + + Examples + -------- + Create a normal distribution. + + .. code-block:: python + + from pymc_extras.prior import Normal + + dist = Normal(mu=1, sigma=2) + + Create a hierarchical normal distribution. + + .. code-block:: python + + import pymc_extras.prior as pr + + dist = pr.Normal(mu=pr.Normal(), sigma=pr.HalfNormal(), dims="channel") + samples = dist.sample_prior(coords={"channel": ["C1", "C2", "C3"]}) + + """ + # Protect against doctest + if name == "__wrapped__": + return + + _get_pymc_distribution(name) + return partial(Prior, distribution=name) diff --git a/tests/test_prior.py b/tests/test_prior.py index a7b630f83..f0201a7c4 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -12,6 +12,8 @@ from pydantic import ValidationError from pymc.model_graph import fast_eval +import pymc_extras.prior as pr + from pymc_extras.deserialize import ( DESERIALIZERS, deserialize, @@ -1147,3 +1149,22 @@ def test_scaled_sample_prior() -> None: assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} assert "scaled_var" in prior assert "scaled_var_unscaled" in prior + + +def test_getattr() -> None: + assert pr.Normal() == Prior("Normal") + + +def test_import_directly() -> None: + try: + from pymc_extras.prior import Normal + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + assert Normal() == Prior("Normal") + + +def test_import_incorrect_directly() -> None: + match = "PyMC doesn't have a distribution of name 'SomeIncorrectDistribution'" + with pytest.raises(UnsupportedDistributionError, match=match): + from pymc_extras.prior import SomeIncorrectDistribution # noqa: F401 From 120411f14e9741ecd5c093cd4803ea3048120f05 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian <39779176+Dekermanjian@users.noreply.github.com> Date: Wed, 11 Jun 2025 20:15:10 -0600 Subject: [PATCH 5/6] Use hermitian=True in pt.linalg.pinv in Kalman smoother (#515) --- pymc_extras/statespace/filters/kalman_smoother.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_extras/statespace/filters/kalman_smoother.py b/pymc_extras/statespace/filters/kalman_smoother.py index b22473391..d0b27ed07 100644 --- a/pymc_extras/statespace/filters/kalman_smoother.py +++ b/pymc_extras/statespace/filters/kalman_smoother.py @@ -105,7 +105,7 @@ def smoother_step(self, *args): a_hat, P_hat = self.predict(a, P, T, R, Q) # Use pinv, otherwise P_hat is singular when there is missing data - smoother_gain = matrix_dot(pt.linalg.pinv(P_hat), T, P).T + smoother_gain = matrix_dot(pt.linalg.pinv(P_hat, hermitian=True), T, P).T a_smooth_next = a + smoother_gain @ (a_smooth - a_hat) P_smooth_next = P + quad_form_sym(smoother_gain, P_smooth - P_hat) From c099fc4608951098344aecfa55b26881d2f6c8cd Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian <39779176+Dekermanjian@users.noreply.github.com> Date: Thu, 12 Jun 2025 15:50:50 -0600 Subject: [PATCH 6/6] Forecast exogenous vars bug fix (#510) * fixed bug in statespace forecast method when exogenous variables are present. * updated solution to handle input shapes correctly * simplified fix, renamed mu and cov for transparancy and added a check for the graph replacements * Refactor model builder logic out of `forecast` method * made slight change with _build_forecast_model and created a test case * made change to test_build_forecast_model() to ensure data is replaced with pm.set_data method * added additional checks to test_build_forecast_model * added mock_sample_setup_and_teardown to statespace tests --------- Co-authored-by: jessegrabowski --- pymc_extras/statespace/core/statespace.py | 130 +++++++++++++--------- tests/statespace/core/test_statespace.py | 101 ++++++++++++++++- 2 files changed, 174 insertions(+), 57 deletions(-) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 9342ff90d..96e1e9b52 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -2047,6 +2047,69 @@ def _finalize_scenario_initialization( return scenario + def _build_forecast_model( + self, time_index, t0, forecast_index, scenario, filter_output, mvn_method + ): + filter_time_dim = TIME_DIM + temp_coords = self._fit_coords.copy() + + dims = None + if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): + dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + + t0_idx = np.flatnonzero(time_index == t0)[0] + + temp_coords["data_time"] = time_index + temp_coords[TIME_DIM] = forecast_index + + mu_dims, cov_dims = None, None + if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]): + mu_dims = ["data_time", ALL_STATE_DIM] + cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM] + + with pm.Model(coords=temp_coords) as forecast_model: + (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph( + data_dims=["data_time", OBS_STATE_DIM], + ) + + group_idx = FILTER_OUTPUT_TYPES.index(filter_output) + mu, cov = grouped_outputs[group_idx] + + sub_dict = { + data_var: pt.as_tensor_variable(data_var.get_value(), name="data") + for data_var in forecast_model.data_vars + } + + missing_data_vars = np.setdiff1d( + ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()] + ) + if missing_data_vars.size > 0: + raise ValueError(f"{missing_data_vars} data used for fitting not found!") + + mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True) + + x0 = pm.Deterministic( + "x0_slice", mu_frozen[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None + ) + P0 = pm.Deterministic( + "P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None + ) + + _ = LinearGaussianStateSpace( + "forecast", + x0, + P0, + *matrices, + steps=len(forecast_index), + dims=dims, + sequence_names=self.kalman_filter.seq_names, + k_endog=self.k_endog, + append_x0=False, + method=mvn_method, + ) + + return forecast_model + def forecast( self, idata: InferenceData, @@ -2139,8 +2202,6 @@ def forecast( the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`. """ - filter_time_dim = TIME_DIM - _validate_filter_arg(filter_output) compile_kwargs = kwargs.pop("compile_kwargs", {}) @@ -2185,58 +2246,23 @@ def forecast( use_scenario_index=use_scenario_index, ) scenario = self._finalize_scenario_initialization(scenario, forecast_index) - temp_coords = self._fit_coords.copy() - - dims = None - if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): - dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] - - t0_idx = np.flatnonzero(time_index == t0)[0] - - temp_coords["data_time"] = time_index - temp_coords[TIME_DIM] = forecast_index - - mu_dims, cov_dims = None, None - if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]): - mu_dims = ["data_time", ALL_STATE_DIM] - cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM] - - with pm.Model(coords=temp_coords) as forecast_model: - (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph( - scenario=scenario, - data_dims=["data_time", OBS_STATE_DIM], - ) - - for name in self.data_names: - if name in scenario.keys(): - pm.set_data( - {"data": np.zeros((len(forecast_index), self.k_endog))}, - coords={"data_time": np.arange(len(forecast_index))}, - ) - break - group_idx = FILTER_OUTPUT_TYPES.index(filter_output) - mu, cov = grouped_outputs[group_idx] - - x0 = pm.Deterministic( - "x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None - ) - P0 = pm.Deterministic( - "P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None - ) + forecast_model = self._build_forecast_model( + time_index=time_index, + t0=t0, + forecast_index=forecast_index, + scenario=scenario, + filter_output=filter_output, + mvn_method=mvn_method, + ) - _ = LinearGaussianStateSpace( - "forecast", - x0, - P0, - *matrices, - steps=len(forecast_index), - dims=dims, - sequence_names=self.kalman_filter.seq_names, - k_endog=self.k_endog, - append_x0=False, - method=mvn_method, - ) + with forecast_model: + if scenario is not None: + dummy_obs_data = np.zeros((len(forecast_index), self.k_endog)) + pm.set_data( + scenario | {"data": dummy_obs_data}, + coords={"data_time": np.arange(len(forecast_index))}, + ) forecast_model.rvs_to_initial_values = { k: None for k in forecast_model.rvs_to_initial_values.keys() diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py index 6a77c1514..bfcd114ae 100644 --- a/tests/statespace/core/test_statespace.py +++ b/tests/statespace/core/test_statespace.py @@ -9,6 +9,9 @@ import pytest from numpy.testing import assert_allclose +from pymc.testing import mock_sample_setup_and_teardown +from pytensor.compile import SharedVariable +from pytensor.graph.basic import graph_inputs from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace from pymc_extras.statespace.models import structural as st @@ -30,6 +33,7 @@ floatX = pytensor.config.floatX nile = load_nile_test_data() ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES +mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown) def make_statespace_mod(k_endog, k_states, k_posdef, filter_type, verbose=False, data_info=None): @@ -170,7 +174,7 @@ def exog_pymc_mod(exog_ss_mod, exog_data): ) beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"]) - exog_ss_mod.build_statespace_graph(exog_data["y"]) + exog_ss_mod.build_statespace_graph(exog_data["y"], save_kalman_filter_outputs_in_idata=True) return struct_model @@ -212,7 +216,7 @@ def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng): @pytest.fixture(scope="session") -def idata(pymc_mod, rng): +def idata(pymc_mod, rng, mock_pymc_sample): with pymc_mod: idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng) @@ -222,7 +226,7 @@ def idata(pymc_mod, rng): @pytest.fixture(scope="session") -def idata_exog(exog_pymc_mod, rng): +def idata_exog(exog_pymc_mod, rng, mock_pymc_sample): with exog_pymc_mod: idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng) @@ -231,7 +235,7 @@ def idata_exog(exog_pymc_mod, rng): @pytest.fixture(scope="session") -def idata_no_exog(pymc_mod_no_exog, rng): +def idata_no_exog(pymc_mod_no_exog, rng, mock_pymc_sample): with pymc_mod_no_exog: idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng) @@ -240,7 +244,7 @@ def idata_no_exog(pymc_mod_no_exog, rng): @pytest.fixture(scope="session") -def idata_no_exog_dt(pymc_mod_no_exog_dt, rng): +def idata_no_exog_dt(pymc_mod_no_exog_dt, rng, mock_pymc_sample): with pymc_mod_no_exog_dt: idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng) @@ -895,6 +899,93 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start): assert_allclose(regression_effect, regression_effect_expected) +@pytest.mark.filterwarnings("ignore:Provided data contains missing values") +@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables") +@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.") +@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op") +@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.") +def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_exog): + data_before_build_forecast_model = {d.name: d.get_value() for d in exog_pymc_mod.data_vars} + + scenario = pd.DataFrame( + { + "date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"), + "x1": rng.choice(2, size=10, replace=True).astype(float), + } + ) + scenario.set_index("date", inplace=True) + + time_index = exog_ss_mod._get_fit_time_index() + t0, forecast_index = exog_ss_mod._build_forecast_index( + time_index=time_index, + start=exog_data.index[-1], + end=scenario.index[-1], + scenario=scenario, + ) + + test_forecast_model = exog_ss_mod._build_forecast_model( + time_index=time_index, + t0=t0, + forecast_index=forecast_index, + scenario=scenario, + filter_output="predicted", + mvn_method="svd", + ) + + frozen_shared_inputs = [ + inpt + for inpt in graph_inputs([test_forecast_model.x0_slice, test_forecast_model.P0_slice]) + if isinstance(inpt, SharedVariable) + and not isinstance(inpt.get_value(), np.random.Generator) + ] + + assert ( + len(frozen_shared_inputs) == 0 + ) # check there are no non-random generator SharedVariables in the frozen inputs + + unfrozen_shared_inputs = [ + inpt + for inpt in graph_inputs([test_forecast_model.forecast_combined]) + if isinstance(inpt, SharedVariable) + and not isinstance(inpt.get_value(), np.random.Generator) + ] + + # Check that there is one (in this case) unfrozen shared input and it corresponds to the exogenous data + assert len(unfrozen_shared_inputs) == 1 + assert unfrozen_shared_inputs[0].name == "data_exog" + + data_after_build_forecast_model = {d.name: d.get_value() for d in test_forecast_model.data_vars} + + with test_forecast_model: + dummy_obs_data = np.zeros((len(forecast_index), exog_ss_mod.k_endog)) + pm.set_data( + {"data_exog": scenario} | {"data": dummy_obs_data}, + coords={"data_time": np.arange(len(forecast_index))}, + ) + idata_forecast = pm.sample_posterior_predictive( + idata_exog, var_names=["x0_slice", "P0_slice"] + ) + + np.testing.assert_allclose( + unfrozen_shared_inputs[0].get_value(), scenario["x1"].values.reshape((-1, 1)) + ) # ensure the replaced data matches the exogenous data + + for k in data_before_build_forecast_model.keys(): + assert ( # check that the data needed to init the forecasts doesn't change + data_before_build_forecast_model[k].mean() == data_after_build_forecast_model[k].mean() + ) + + # Check that the frozen states and covariances correctly match the sliced index + np.testing.assert_allclose( + idata_exog.posterior["predicted_covariance"].sel(time=t0).mean(("chain", "draw")).values, + idata_forecast.posterior_predictive["P0_slice"].mean(("chain", "draw")).values, + ) + np.testing.assert_allclose( + idata_exog.posterior["predicted_state"].sel(time=t0).mean(("chain", "draw")).values, + idata_forecast.posterior_predictive["x0_slice"].mean(("chain", "draw")).values, + ) + + @pytest.mark.filterwarnings("ignore:Provided data contains missing values") @pytest.mark.filterwarnings("ignore:The RandomType SharedVariables") @pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")