From 99c04032d21d03c1ef5bda25ffe96777dfd8a9b4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Sep 2024 17:02:24 -0300 Subject: [PATCH 01/53] [pre-commit.ci] pre-commit autoupdate (#187) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.3 → v0.6.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.3...v0.6.4) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1a21dcf..f592534 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.3 + rev: v0.6.4 hooks: - id: ruff args: ["--fix", "--output-format=full"] From 57e8877f8ec73f765153b5b142e3e01c537bab44 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Sep 2024 15:05:40 -0300 Subject: [PATCH 02/53] [pre-commit.ci] pre-commit autoupdate (#188) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.4 → v0.6.5](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.4...v0.6.5) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f592534..484b717 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.4 + rev: v0.6.5 hooks: - id: ruff args: ["--fix", "--output-format=full"] From 4e9963bde246c49b7d8bac04150f17c619acf89a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Sep 2024 14:37:06 -0300 Subject: [PATCH 03/53] [pre-commit.ci] pre-commit autoupdate (#189) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.5 → v0.6.7](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.5...v0.6.7) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 484b717..5a405c6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.5 + rev: v0.6.7 hooks: - id: ruff args: ["--fix", "--output-format=full"] From c7144249fdf8c90bb19692e97dcbb455c4f7cc1a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 09:32:51 -0300 Subject: [PATCH 04/53] [pre-commit.ci] pre-commit autoupdate (#190) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.7 → v0.6.9](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.7...v0.6.9) - [github.com/pre-commit/pre-commit-hooks: v4.6.0 → v5.0.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.6.0...v5.0.0) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5a405c6..64b246a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.7 + rev: v0.6.9 hooks: - id: ruff args: ["--fix", "--output-format=full"] @@ -26,7 +26,7 @@ repos: files: ^pymc_bart/ additional_dependencies: [numpy, pandas-stubs] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer From 824b5824ac82525a8d65d22c297bd031a79b2fdc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 15:11:33 -0300 Subject: [PATCH 05/53] [pre-commit.ci] pre-commit autoupdate (#191) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.9 → v0.7.1](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.9...v0.7.1) - [github.com/pre-commit/mirrors-mypy: v1.11.2 → v1.13.0](https://github.com/pre-commit/mirrors-mypy/compare/v1.11.2...v1.13.0) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 64b246a..b879e2f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,14 +12,14 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.1 hooks: - id: ruff args: ["--fix", "--output-format=full"] - id: ruff-format args: ["--line-length=100"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.13.0 hooks: - id: mypy args: [--ignore-missing-imports] From 1741d7dbd02d39aa869a77887f984599737de514 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Nov 2024 10:38:43 +0100 Subject: [PATCH 06/53] [pre-commit.ci] pre-commit autoupdate (#192) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.7.1 → v0.7.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.7.1...v0.7.2) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b879e2f..a99773c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.1 + rev: v0.7.2 hooks: - id: ruff args: ["--fix", "--output-format=full"] From b9f4567ebe1c077f3760a6a08dde41bd62044cda Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Thu, 7 Nov 2024 17:00:04 -0300 Subject: [PATCH 07/53] Conform to recent changes in pymc (#194) * conform to recent changes in pymc * update version * fix shapes --- pymc_bart/__init__.py | 2 +- pymc_bart/bart.py | 14 +++++++------- pymc_bart/pgbart.py | 7 +++++-- requirements.txt | 2 +- tests/test_bart.py | 18 +++++++++--------- 5 files changed, 23 insertions(+), 20 deletions(-) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index c10b8f8..8774803 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -36,7 +36,7 @@ "plot_pdp", "plot_variable_importance", ] -__version__ = "0.7.0" +__version__ = "0.7.1" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 969baf4..a21bda5 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -37,19 +37,22 @@ class BARTRV(RandomVariable): """Base class for BART.""" name: str = "BART" - ndim_supp = 1 - ndims_params: List[int] = [2, 1, 0, 0, 0, 1] + signature = "(m,n),(m),(),(),() -> (m)" dtype: str = "floatX" _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}") all_trees = List[List[List[Tree]]] def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed - return dist_params[0].shape[:1] + idx = dist_params[0].ndim - 2 + return [dist_params[0].shape[idx]] @classmethod def rng_fn( # pylint: disable=W0237 - cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None + cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, size=None ): + if not size: + size = None + if not cls.all_trees: if size is not None: return np.full((size[0], cls.Y.shape[0]), cls.Y.mean()) @@ -96,9 +99,6 @@ class BART(Distribution): List of SplitRule objects, one per column in input data. Allows using different split rules for different columns. Default is ContinuousSplitRule. Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables. - shape: : Optional[Tuple], default None - Specify the output shape. If shape is different from (len(X)) (the default), train a - separate tree for each value in other dimensions. separate_trees : Optional[bool], default False When training multiple trees (by setting a shape parameter), the default behavior is to learn a joint tree structure and only have different leaf values for each. diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 91a9beb..6de7a53 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -114,7 +114,10 @@ class PGBART(ArrayStepShared): name = "pgbart" default_blocked = False generates_stats = True - stats_dtypes = [{"variable_inclusion": object, "tune": bool}] + stats_dtypes_shapes: dict[str, tuple[type, list]] = { + "variable_inclusion": (object, []), + "tune": (bool, []), + } def __init__( # noqa: PLR0915 self, @@ -227,7 +230,7 @@ def __init__( # noqa: PLR0915 def astep(self, _): variable_inclusion = np.zeros(self.num_variates, dtype="int") - upper = min(self.lower + self.batch[~self.tune], self.m) + upper = min(self.lower + self.batch[not self.tune], self.m) tree_ids = range(self.lower, upper) self.lower = upper if upper < self.m else 0 diff --git a/requirements.txt b/requirements.txt index e741cef..ac9bd07 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pymc<=5.16.2 +pymc>=5.16.2, <=5.18 arviz>=0.18.0 numba matplotlib diff --git a/tests/test_bart.py b/tests/test_bart.py index dfbd86f..e56735e 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -3,7 +3,7 @@ import pytest from numpy.testing import assert_almost_equal, assert_array_equal from pymc.initial_point import make_initial_point_fn -from pymc.logprob.basic import joint_logp +from pymc.logprob.basic import transformed_conditional_logp import pymc_bart as pmb @@ -12,7 +12,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): fn = make_initial_point_fn( model=model, return_transformed=False, - default_strategy="moment", + default_strategy="support_point", ) moment = fn(0)["x"] expected = np.asarray(expected) @@ -27,7 +27,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): if check_finite_logp: logp_moment = ( - joint_logp( + transformed_conditional_logp( (model["x"],), rvs_to_values={model["x"]: pm.math.constant(moment)}, rvs_to_transforms={}, @@ -53,7 +53,7 @@ def test_bart_vi(response): mu = pmb.BART("mu", X, Y, m=10, response=response) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(random_seed=3415) + idata = pm.sample(tune=200, draws=200, random_seed=3415) var_imp = ( idata.sample_stats["variable_inclusion"] .stack(samples=("chain", "draw")) @@ -77,8 +77,8 @@ def test_missing_data(response): with pm.Model() as model: mu = pmb.BART("mu", X, Y, m=10, response=response) sigma = pm.HalfNormal("sigma", 1) - y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(tune=100, draws=100, chains=1, random_seed=3415) + pm.Normal("y", mu, sigma, observed=Y) + pm.sample(tune=100, draws=100, chains=1, random_seed=3415) @pytest.mark.parametrize( @@ -91,7 +91,7 @@ def test_shared_variable(response): Y = np.random.normal(0, 1, size=50) with pm.Model() as model: - data_X = pm.MutableData("data_X", X) + data_X = pm.Data("data_X", X) mu = pmb.BART("mu", data_X, Y, m=2, response=response) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y, shape=mu.shape) @@ -116,7 +116,7 @@ def test_shape(response): with pm.Model() as model: w = pmb.BART("w", X, Y, m=2, response=response, shape=(2, 250)) y = pm.Normal("y", w[0], pm.math.abs(w[1]), observed=Y) - idata = pm.sample(random_seed=3415) + idata = pm.sample(tune=50, draws=10, random_seed=3415) assert model.initial_point()["w"].shape == (2, 250) assert idata.posterior.coords["w_dim_0"].data.size == 2 @@ -133,7 +133,7 @@ class TestUtils: mu = pmb.BART("mu", X, Y, m=10) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(random_seed=3415) + idata = pm.sample(tune=200, draws=200, random_seed=3415) def test_sample_posterior(self): all_trees = self.mu.owner.op.all_trees From 282b2ef230784290cf669d8949711834b4fcd873 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 08:25:48 -0300 Subject: [PATCH 08/53] [pre-commit.ci] pre-commit autoupdate (#195) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.7.2 → v0.7.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.7.2...v0.7.4) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a99773c..f05b6db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.2 + rev: v0.7.4 hooks: - id: ruff args: ["--fix", "--output-format=full"] From 4ef2dd0a74c985c634543b4b489e60e7bfc08d41 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Mon, 25 Nov 2024 12:32:41 -0300 Subject: [PATCH 09/53] Add new vi plots (#196) * add new vi plots * fix tests * add missing file --- pymc_bart/__init__.py | 6 + pymc_bart/utils.py | 371 +++++++++++++++++++++++++++++++----------- tests/test_bart.py | 9 +- 3 files changed, 285 insertions(+), 101 deletions(-) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index 8774803..18fe054 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -17,11 +17,14 @@ from pymc_bart.pgbart import PGBART from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule from pymc_bart.utils import ( + compute_variable_importance, plot_convergence, plot_dependence, plot_ice, plot_pdp, + plot_scatter_submodels, plot_variable_importance, + plot_variable_inclusion, ) __all__ = [ @@ -30,11 +33,14 @@ "ContinuousSplitRule", "OneHotSplitRule", "SubsetSplitRule", + "compute_variable_importance", "plot_convergence", "plot_dependence", "plot_ice", "plot_pdp", + "plot_scatter_submodels", "plot_variable_importance", + "plot_variable_inclusion", ] __version__ = "0.7.1" diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index a50f2d9..e8c60bb 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -1,3 +1,4 @@ +# pylint: disable=too-many-branches """Utility function for variable selection and bart interpretability.""" import warnings @@ -248,7 +249,7 @@ def identity(x): _, ) = _prepare_plot_data(X, Y, "linear", None, var_idx, var_discrete) - fig, axes, shape = _get_axes(bartrv, var_idx, grid, sharey, figsize, ax) + fig, axes, shape = _create_figure_axes(bartrv, var_idx, grid, sharey, figsize, ax) instances_ary = rng.choice(range(X.shape[0]), replace=False, size=instances) idx_s = list(range(X.shape[0])) @@ -270,7 +271,6 @@ def identity(x): new_x = fake_X[:, var] p_d = np.array(y_pred) - print(p_d.shape) for s_i in range(shape): if centered: @@ -398,7 +398,7 @@ def identity(x): xs_values, ) = _prepare_plot_data(X, Y, xs_interval, xs_values, var_idx, var_discrete) - fig, axes, shape = _get_axes(bartrv, var_idx, grid, sharey, figsize, ax) + fig, axes, shape = _create_figure_axes(bartrv, var_idx, grid, sharey, figsize, ax) count = 0 fake_X = _create_pdp_data(X, xs_interval, xs_values) @@ -447,7 +447,7 @@ def identity(x): return axes -def _get_axes( +def _create_figure_axes( bartrv: Variable, var_idx: List[int], grid: str = "long", @@ -492,29 +492,8 @@ def _get_axes( n_plots = len(var_idx) * shape if ax is None: - if grid == "long": - fig, axes = plt.subplots(n_plots, sharey=sharey, figsize=figsize) - if n_plots == 1: - axes = [axes] - elif grid == "wide": - fig, axes = plt.subplots(1, n_plots, sharey=sharey, figsize=figsize) - if n_plots == 1: - axes = [axes] - elif isinstance(grid, tuple): - grid_size = grid[0] * grid[1] - if n_plots > grid_size: - warnings.warn( - """The grid is smaller than the number of available variables to plot. - Automatically adjusting the grid size.""" - ) - grid = (n_plots // grid[1] + (n_plots % grid[1] > 0), grid[1]) - - fig, axes = plt.subplots(*grid, sharey=sharey, figsize=figsize) - axes = np.ravel(axes) + fig, axes = _get_axes(grid, n_plots, False, sharey, figsize) - for i in range(n_plots, len(axes)): - fig.delaxes(axes[i]) - axes = axes[:n_plots] elif isinstance(ax, np.ndarray): axes = ax fig = ax[0].get_figure() @@ -525,6 +504,33 @@ def _get_axes( return fig, axes, shape +def _get_axes(grid, n_plots, sharex, sharey, figsize): + if grid == "long": + fig, axes = plt.subplots(n_plots, sharex=sharex, sharey=sharey, figsize=figsize) + if n_plots == 1: + axes = [axes] + elif grid == "wide": + fig, axes = plt.subplots(1, n_plots, sharex=sharex, sharey=sharey, figsize=figsize) + if n_plots == 1: + axes = [axes] + elif isinstance(grid, tuple): + grid_size = grid[0] * grid[1] + if n_plots > grid_size: + warnings.warn( + """The grid is smaller than the number of available variables to plot. + Automatically adjusting the grid size.""" + ) + grid = (n_plots // grid[1] + (n_plots % grid[1] > 0), grid[1]) + + fig, axes = plt.subplots(*grid, sharey=sharey, figsize=figsize) + axes = np.ravel(axes) + + for i in range(n_plots, len(axes)): + fig.delaxes(axes[i]) + axes = axes[:n_plots] + return fig, axes + + def _prepare_plot_data( X: npt.NDArray[np.float64], Y: Optional[npt.NDArray[np.float64]] = None, @@ -693,18 +699,86 @@ def _smooth_mean( return x_data, y_data -def plot_variable_importance( # noqa: PLR0915 +def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=None, ax=None): + """ + Plot normalized variable inclusion from BART model. + + Parameters + ---------- + idata: InferenceData + InferenceData containing a collection of BART_trees in sample_stats group + X : npt.NDArray[np.float64] + The covariate matrix. + labels : Optional[List[str]] + List of the names of the covariates. If X is a DataFrame the names of the covariables will + be taken from it and this argument will be ignored. + figsize : tuple + Figure size. If None it will be defined automatically. + plot_kwargs : dict + Additional keyword arguments for the plot. Defaults to None. + Valid keys are: + - color: matplotlib valid color for VI + - marker: matplotlib valid marker for VI + - ls: matplotlib valid linestyle for the VI line + - rotation: float, rotation of the x-axis labels + ax : axes + Matplotlib axes. + + Returns + ------- + idxs: indexes of the covariates from higher to lower relative importance + axes: matplotlib axes + """ + if plot_kwargs is None: + plot_kwargs = {} + + VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values + VIs = VIs / VIs.sum() + idxs = np.argsort(VIs) + + indices = idxs[::-1] + n_vars = len(indices) + + if hasattr(X, "columns") and hasattr(X, "to_numpy"): + labels = X.columns + + if labels is None: + labels = np.arange(n_vars).astype(str) + + new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + + ticks = np.arange(n_vars, dtype=int) + + if figsize is None: + figsize = (8, 3) + + if ax is None: + _, ax = plt.subplots(1, 1, figsize=figsize) + + ax.plot( + VIs[indices], + color=plot_kwargs.get("color", "k"), + marker=plot_kwargs.get("marker", "o"), + ls=plot_kwargs.get("ls", "-"), + ) + + ax.set_xticks(ticks, new_labels, rotation=plot_kwargs.get("rotation", 0)) + + ax.axhline(1 / n_vars, color="0.5", linestyle="--") + ax.set_ylim(0, 1) + + return idxs, ax + + +def compute_variable_importance( # noqa: PLR0915 PLR0912 idata: az.InferenceData, bartrv: Variable, X: npt.NDArray[np.float64], - labels: Optional[List[str]] = None, method: str = "VI", - figsize: Optional[Tuple[float, float]] = None, + fixed: int = 0, samples: int = 50, random_seed: Optional[int] = None, - plot_kwargs: Optional[Dict[str, Any]] = None, - ax: Optional[plt.Axes] = None, -) -> Tuple[List[int], Union[List[plt.Axes], Any]]: +) -> dict[str, object]: """ Estimates variable importance from the BART-posterior. @@ -716,87 +790,74 @@ def plot_variable_importance( # noqa: PLR0915 BART variable once the model that include it has been fitted. X : npt.NDArray[np.float64] The covariate matrix. - labels : Optional[List[str]] - List of the names of the covariates. If X is a DataFrame the names of the covariables will - be taken from it and this argument will be ignored. method : str - Method used to rank variables. Available options are "VI" (default) and "backward". + Method used to rank variables. Available options are "VI" (default), "backward" + and "backward_VI". The R squared will be computed following this ranking. "VI" counts how many times each variable is included in the posterior distribution of trees. "backward" uses a backward search based on the R squared. - VI requieres less computation time. - figsize : tuple - Figure size. If None it will be defined automatically. + "backward_VI" combines both methods with the backward search excluding + the ``fixed`` number of variables with the lowest variable inclusion. + "VI" is the fastest method, while "backward" is the slowest. + fixed : Optional[int] + Number of variables to fix in the backward search. Defaults to None. + Must be greater than 0 and less than the number of variables. + Ignored if method is "VI" or "backward". samples : int - Number of predictions used to compute correlation for subsets of variables. Defaults to 100 + Number of predictions used to compute correlation for subsets of variables. Defaults to 50 random_seed : Optional[int] random_seed used to sample from the posterior. Defaults to None. - plot_kwargs : dict - Additional keyword arguments for the plot. Defaults to None. - Valid keys are: - - color_r2: matplotlib valid color for error bars - - marker_r2: matplotlib valid marker for the mean R squared - - marker_fc_r2: matplotlib valid marker face color for the mean R squared - - ls_ref: matplotlib valid linestyle for the reference line - - color_ref: matplotlib valid color for the reference line - ax : axes - Matplotlib axes. Returns ------- - idxs: indexes of the covariates from higher to lower relative importance - axes: matplotlib axes + vi_results: dictionary """ + if method not in ["VI", "backward", "backward_VI"]: + raise ValueError("method must be 'VI', 'backward' or 'backward_VI'") + rng = np.random.default_rng(random_seed) all_trees = bartrv.owner.op.all_trees - if plot_kwargs is None: - plot_kwargs = {} - if bartrv.ndim == 1: # type: ignore shape = 1 else: shape = bartrv.eval().shape[0] if hasattr(X, "columns") and hasattr(X, "to_numpy"): - labels = X.columns X = X.to_numpy() n_vars = X.shape[1] - - if figsize is None: - figsize = (8, 3) - - if ax is None: - _, ax = plt.subplots(1, 1, figsize=figsize) - - if labels is None: - labels_ary = np.arange(n_vars).astype(str) + r2_mean = np.zeros(n_vars) + r2_hdi = np.zeros((n_vars, 2)) + preds = np.zeros((n_vars, samples, bartrv.eval().shape[0])) + + if method == "backward_VI": + if fixed >= n_vars: + raise ValueError("fixed must be less than the number of variables") + elif fixed < 1: + raise ValueError("fixed must be greater than 0") + init = fixed + 1 else: - labels_ary = np.array(labels) - - ticks = np.arange(n_vars, dtype=int) + fixed = 0 + init = 0 predicted_all = _sample_posterior( all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape ) - r_2_ref = np.array( - [pearsonr2(predicted_all[j], predicted_all[j + 1]) for j in range(samples - 1)] - ) - - if method == "VI": + if method in ["VI", "backward_VI"]: idxs = np.argsort( idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values ) subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))] subsets.append(None) # type: ignore + if method == "backward_VI": + subsets = subsets[-init:] + indices: List[int] = list(idxs[::-1]) - r2_mean = np.zeros(n_vars) - r2_hdi = np.zeros((n_vars, 2)) for idx, subset in enumerate(subsets): predicted_subset = _sample_posterior( all_trees=all_trees, @@ -811,19 +872,24 @@ def plot_variable_importance( # noqa: PLR0915 ) r2_mean[idx] = np.mean(r_2) r2_hdi[idx] = az.hdi(r_2) - - elif method == "backward": - r2_mean = np.zeros(n_vars) - r2_hdi = np.zeros((n_vars, 2)) - - variables = set(range(n_vars)) - least_important_vars: List[int] = [] - indices = [] + preds[idx] = predicted_subset.squeeze() + + if method in ["backward", "backward_VI"]: + if method == "backward_VI": + least_important_vars: List[int] = indices[-fixed:] + r2_mean_vi = r2_mean[:init] + r2_hdi_vi = r2_hdi[:init] + preds_vi = preds[:init] + r2_mean = np.zeros(n_vars - fixed - 1) + r2_hdi = np.zeros((n_vars - fixed - 1, 2)) + preds = np.zeros((n_vars - fixed - 1, samples, bartrv.eval().shape[0])) + else: + least_important_vars = [] # Iterate over each variable to determine its contribution # least_important_vars tracks the variable with the lowest contribution - # at the current stage. One new varible is added at each iteration. - for i_var in range(n_vars): + # at the current stage. One new variable is added at each iteration. + for i_var in range(init, n_vars): # Generate all possible subsets by adding one variable at a time to # least_important_vars subsets = generate_sequences(n_vars, i_var, least_important_vars) @@ -851,30 +917,116 @@ def plot_variable_importance( # noqa: PLR0915 max_r_2 = mean_r_2 least_important_subset = subset r_2_without_least_important_vars = r_2 + least_important_samples = predicted_subset # Save values for plotting later - r2_mean[i_var] = max_r_2 - r2_hdi[i_var] = az.hdi(r_2_without_least_important_vars) + r2_mean[i_var - init] = max_r_2 + r2_hdi[i_var - init] = az.hdi(r_2_without_least_important_vars) + preds[i_var - init] = least_important_samples.squeeze() # extend current list of least important variable - least_important_vars += least_important_subset + for var_i in least_important_subset: + if var_i not in least_important_vars: + least_important_vars.append(var_i) + + # Add the remaining variables to the list of least important variables + for var_i in range(n_vars): + if var_i not in least_important_vars: + least_important_vars.append(var_i) + + if method == "backward_VI": + r2_mean = np.concatenate((r2_mean[::-1], r2_mean_vi)) + r2_hdi = np.concatenate((r2_hdi[::-1], r2_hdi_vi)) + preds = np.concatenate((preds[::-1], preds_vi)) + else: + r2_mean = r2_mean[::-1] + r2_hdi = r2_hdi[::-1] + preds = preds[::-1] + + indices = least_important_vars[::-1] + + vi_results = { + "indices": indices, + "r2_mean": r2_mean, + "r2_hdi": r2_hdi, + "preds": preds, + "preds_all": predicted_all.squeeze(), + } + return vi_results + + +def plot_variable_importance( + vi_results: dict, + X: npt.NDArray[np.float64], + labels=None, + figsize=None, + plot_kwargs: Optional[Dict[str, Any]] = None, + ax: Optional[plt.Axes] = None, +): + """ + Estimates variable importance from the BART-posterior. - # add index of removed variable - indices += list(set(least_important_subset) - set(indices)) + Parameters + ---------- + vi_results: Dictionary + Dictionary computed with `compute_variable_importance` + X : npt.NDArray[np.float64] + The covariate matrix. + labels : Optional[List[str]] + List of the names of the covariates. If X is a DataFrame the names of the covariables will + be taken from it and this argument will be ignored. + plot_kwargs : dict + Additional keyword arguments for the plot. Defaults to None. + Valid keys are: + - color_r2: matplotlib valid color for error bars + - marker_r2: matplotlib valid marker for the mean R squared + - marker_fc_r2: matplotlib valid marker face color for the mean R squared + - ls_ref: matplotlib valid linestyle for the reference line + - color_ref: matplotlib valid color for the reference line + - rotation: float, rotation angle of the x-axis labels. Defaults to 0. + ax : axes + Matplotlib axes. - # add remaining index - indices += list(set(variables) - set(least_important_vars)) + Returns + ------- + axes: matplotlib axes + """ - indices = indices[::-1] - r2_mean = r2_mean[::-1] - r2_hdi = r2_hdi[::-1] + indices = vi_results["indices"] + r2_mean = vi_results["r2_mean"] + r2_hdi = vi_results["r2_hdi"] + preds = vi_results["preds"] + preds_all = vi_results["preds_all"] + samples = preds.shape[1] - new_labels = [ - "+ " + ele if index != 0 else ele for index, ele in enumerate(labels_ary[indices]) - ] + n_vars = len(indices) + ticks = np.arange(n_vars, dtype=int) + + if plot_kwargs is None: + plot_kwargs = {} + + if figsize is None: + figsize = (8, 3) + + if hasattr(X, "columns") and hasattr(X, "to_numpy"): + labels = X.columns + X = X.to_numpy() + + if ax is None: + _, ax = plt.subplots(1, 1, figsize=figsize) + + if labels is None: + labels = np.arange(n_vars).astype(str) + else: + labels = np.asarray(labels) + + new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + + r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)]) r2_yerr_min = np.clip(r2_mean - r2_hdi[:, 0], 0, None) r2_yerr_max = np.clip(r2_hdi[:, 1] - r2_mean, 0, None) + ax.errorbar( ticks, r2_mean, @@ -903,7 +1055,28 @@ def plot_variable_importance( # noqa: PLR0915 ax.set_ylim(0, 1) ax.set_xlim(-0.5, n_vars - 0.5) - return indices, ax + return ax + + +def plot_scatter_submodels(vi_results, func=None, grid="long", axes=None): + indices = vi_results["indices"] + preds = vi_results["preds"] + preds_all = vi_results["preds_all"] + + if axes is None: + _, axes = _get_axes(grid, len(indices), False, True, None) + + func = None + if func is not None: + preds = func(preds) + preds_all = func(preds_all) + + min_ = min(np.min(preds), np.min(preds_all)) + max_ = max(np.max(preds), np.max(preds_all)) + + for pred, ax in zip(preds, axes.ravel()): + ax.plot(pred, preds_all, ".", color="C0", alpha=0.1) + ax.axline([min_, min_], [max_, max_], color="0.5") def generate_sequences(n_vars, i_var, include): diff --git a/tests/test_bart.py b/tests/test_bart.py index e56735e..c10fc94 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -184,12 +184,17 @@ def test_pdp(self, kwargs): @pytest.mark.parametrize( "kwargs", [ - {}, + {"samples": 50}, {"labels": ["A", "B", "C"], "samples": 2, "figsize": (6, 6)}, ], ) def test_vi(self, kwargs): - pmb.plot_variable_importance(self.idata, X=self.X, bartrv=self.mu, **kwargs) + samples = kwargs.pop("samples") + vi_results = pmb.compute_variable_importance( + self.idata, bartrv=self.mu, X=self.X, samples=samples + ) + pmb.plot_variable_importance(vi_results, X=self.X, **kwargs) + pmb.plot_scatter_submodels(vi_results) def test_pdp_pandas_labels(self): pd = pytest.importorskip("pandas") From 40f1220d66576d5b1a5ba18653d3756685f6e712 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Tue, 26 Nov 2024 21:19:30 -0300 Subject: [PATCH 10/53] improve docs, aesthetics and functionality (#198) * improve docs, aesthetics and functionality * remove X argument from plots --- pymc_bart/utils.py | 87 +++++++++++++++++++++++++++++++++++++--------- tests/test_bart.py | 4 +-- 2 files changed, 72 insertions(+), 19 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index e8c60bb..31cc28f 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -824,10 +824,14 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 else: shape = bartrv.eval().shape[0] + n_vars = X.shape[1] + if hasattr(X, "columns") and hasattr(X, "to_numpy"): + labels = X.columns X = X.to_numpy() + else: + labels = np.arange(n_vars).astype(str) - n_vars = X.shape[1] r2_mean = np.zeros(n_vars) r2_hdi = np.zeros((n_vars, 2)) preds = np.zeros((n_vars, samples, bartrv.eval().shape[0])) @@ -947,6 +951,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 vi_results = { "indices": indices, + "labels": labels[indices], "r2_mean": r2_mean, "r2_hdi": r2_hdi, "preds": preds, @@ -957,7 +962,6 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 def plot_variable_importance( vi_results: dict, - X: npt.NDArray[np.float64], labels=None, figsize=None, plot_kwargs: Optional[Dict[str, Any]] = None, @@ -1008,19 +1012,13 @@ def plot_variable_importance( if figsize is None: figsize = (8, 3) - if hasattr(X, "columns") and hasattr(X, "to_numpy"): - labels = X.columns - X = X.to_numpy() - if ax is None: _, ax = plt.subplots(1, 1, figsize=figsize) if labels is None: - labels = np.arange(n_vars).astype(str) - else: - labels = np.asarray(labels) + labels = vi_results["labels"] - new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)] r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)]) @@ -1048,7 +1046,7 @@ def plot_variable_importance( ) ax.set_xticks( ticks, - new_labels, + labels, rotation=plot_kwargs.get("rotation", 0), ) ax.set_ylabel("R²", rotation=0, labelpad=12) @@ -1058,15 +1056,57 @@ def plot_variable_importance( return ax -def plot_scatter_submodels(vi_results, func=None, grid="long", axes=None): +def plot_scatter_submodels( + vi_results: dict, + func: Optional[Callable] = None, + grid: str = "long", + labels=None, + figsize: Optional[Tuple[float, float]] = None, + plot_kwargs: Optional[Dict[str, Any]] = None, + axes: Optional[plt.Axes] = None, +): + """ + Plot submodel's predictions against reference-model's predictions. + + Parameters + ---------- + vi_results: Dictionary + Dictionary computed with `compute_variable_importance` + func : Optional[Callable], by default None. + Arbitrary function to apply to the predictions. Defaults to the identity function. + grid : str or tuple + How to arrange the subplots. Defaults to "long", one subplot below the other. + Other options are "wide", one subplot next to each other or a tuple indicating the number + of rows and columns. + labels : Optional[List[str]] + List of the names of the covariates. + plot_kwargs : dict + Additional keyword arguments for the plot. Defaults to None. + Valid keys are: + - color_ref: matplotlib valid color for the 45 degree line + - color_scatter: matplotlib valid color for the scatter plot + axes : axes + Matplotlib axes. + + Returns + ------- + axes: matplotlib axes + """ indices = vi_results["indices"] preds = vi_results["preds"] preds_all = vi_results["preds_all"] if axes is None: - _, axes = _get_axes(grid, len(indices), False, True, None) + _, axes = _get_axes(grid, len(indices), True, True, figsize) + + if plot_kwargs is None: + plot_kwargs = {} + + if labels is None: + labels = vi_results["labels"] + + labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)] - func = None if func is not None: preds = func(preds) preds_all = func(preds_all) @@ -1074,9 +1114,22 @@ def plot_scatter_submodels(vi_results, func=None, grid="long", axes=None): min_ = min(np.min(preds), np.min(preds_all)) max_ = max(np.max(preds), np.max(preds_all)) - for pred, ax in zip(preds, axes.ravel()): - ax.plot(pred, preds_all, ".", color="C0", alpha=0.1) - ax.axline([min_, min_], [max_, max_], color="0.5") + for pred, x_label, ax in zip(preds, labels, axes.ravel()): + ax.plot( + pred, + preds_all, + marker=plot_kwargs.get("marker_scatter", "."), + ls="", + color=plot_kwargs.get("color_scatter", "C0"), + alpha=plot_kwargs.get("alpha_scatter", 0.1), + ) + ax.set_xlabel(x_label) + ax.axline( + [min_, min_], + [max_, max_], + color=plot_kwargs.get("color_ref", "0.5"), + ls=plot_kwargs.get("ls_ref", "--"), + ) def generate_sequences(n_vars, i_var, include): diff --git a/tests/test_bart.py b/tests/test_bart.py index c10fc94..c64811a 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -193,8 +193,8 @@ def test_vi(self, kwargs): vi_results = pmb.compute_variable_importance( self.idata, bartrv=self.mu, X=self.X, samples=samples ) - pmb.plot_variable_importance(vi_results, X=self.X, **kwargs) - pmb.plot_scatter_submodels(vi_results) + pmb.plot_variable_importance(vi_results, **kwargs) + pmb.plot_scatter_submodels(vi_results, **kwargs) def test_pdp_pandas_labels(self): pd = pytest.importorskip("pandas") From 9ec4de89cd561944dae5bf9d7a34fddf37d6b224 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Thu, 28 Nov 2024 17:45:50 -0300 Subject: [PATCH 11/53] lint updates (#199) * lint updates * use built-in types --- .pre-commit-config.yaml | 2 +- pymc_bart/bart.py | 14 +++--- pymc_bart/pgbart.py | 36 +++++++------- pymc_bart/tree.py | 41 ++++++++-------- pymc_bart/utils.py | 106 ++++++++++++++++++++-------------------- pyproject.toml | 3 +- 6 files changed, 101 insertions(+), 101 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f05b6db..8832b06 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.4 + rev: v0.8.0 hooks: - id: ruff args: ["--fix", "--output-format=full"] diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index a21bda5..16a856c 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -16,7 +16,7 @@ import warnings from multiprocessing import Manager -from typing import List, Optional, Tuple +from typing import Optional import numpy as np import numpy.typing as npt @@ -39,8 +39,8 @@ class BARTRV(RandomVariable): name: str = "BART" signature = "(m,n),(m),(),(),() -> (m)" dtype: str = "floatX" - _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}") - all_trees = List[List[List[Tree]]] + _print_name: tuple[str, str] = ("BART", "\\operatorname{BART}") + all_trees = list[list[list[Tree]]] def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed idx = dist_params[0].ndim - 2 @@ -92,10 +92,10 @@ class BART(Distribution): beta : float Controls the prior probability over the number of leaves of the trees. Should be positive. - split_prior : Optional[List[float]], default None. + split_prior : Optional[list[float]], default None. List of positive numbers, one per column in input data. Defaults to None, all covariates have the same prior probability to be selected. - split_rules : Optional[List[SplitRule]], default None + split_rules : Optional[list[SplitRule]], default None List of SplitRule objects, one per column in input data. Allows using different split rules for different columns. Default is ContinuousSplitRule. Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables. @@ -126,7 +126,7 @@ def __new__( beta: float = 2.0, response: str = "constant", split_prior: Optional[npt.NDArray[np.float64]] = None, - split_rules: Optional[List[SplitRule]] = None, + split_rules: Optional[list[SplitRule]] = None, separate_trees: Optional[bool] = False, **kwargs, ): @@ -198,7 +198,7 @@ def get_moment(cls, rv, size, *rv_inputs): def preprocess_xy( X: TensorLike, Y: TensorLike -) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: +) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: if isinstance(Y, (Series, DataFrame)): Y = Y.to_numpy() if isinstance(X, (Series, DataFrame)): diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 6de7a53..6a7e26e 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import numpy as np import numpy.typing as npt @@ -43,7 +43,7 @@ class ParticleTree: def __init__(self, tree: Tree): self.tree: Tree = tree.copy() - self.expansion_nodes: List[int] = [0] + self.expansion_nodes: list[int] = [0] self.log_weight: float = 0 def copy(self) -> "ParticleTree": @@ -123,7 +123,7 @@ def __init__( # noqa: PLR0915 self, vars=None, # pylint: disable=redefined-builtin num_particles: int = 10, - batch: Tuple[float, float] = (0.1, 0.1), + batch: tuple[float, float] = (0.1, 0.1), model: Optional[Model] = None, ): model = modelcontext(model) @@ -310,7 +310,7 @@ def astep(self, _): stats = {"variable_inclusion": variable_inclusion, "tune": self.tune} return self.sum_trees, [stats] - def normalize(self, particles: List[ParticleTree]) -> float: + def normalize(self, particles: list[ParticleTree]) -> float: """ Use softmax to get normalized_weights. """ @@ -321,16 +321,16 @@ def normalize(self, particles: List[ParticleTree]) -> float: return wei / wei.sum() def resample( - self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64] - ) -> List[ParticleTree]: + self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64] + ) -> list[ParticleTree]: """ Use systematic resample for all but the first particle Ensure particles are copied only if needed. """ new_indices = self.systematic(normalized_weights) + 1 - seen: List[int] = [] - new_particles: List[ParticleTree] = [] + seen: list[int] = [] + new_particles: list[ParticleTree] = [] for idx in new_indices: if idx in seen: new_particles.append(particles[idx].copy()) @@ -343,8 +343,8 @@ def resample( return particles def get_particle_tree( - self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64] - ) -> Tuple[ParticleTree, Tree]: + self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64] + ) -> tuple[ParticleTree, Tree]: """ Sample a new particle and associated tree """ @@ -367,12 +367,12 @@ def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray single_uniform = (self.uniform.rvs() + np.arange(lnw)) / lnw return inverse_cdf(single_uniform, normalized_weights) - def init_particles(self, tree_id: int, odim: int) -> List[ParticleTree]: + def init_particles(self, tree_id: int, odim: int) -> list[ParticleTree]: """Initialize particles.""" p0: ParticleTree = self.all_particles[odim][tree_id] # The old tree does not grow so we update the weight only once self.update_weight(p0, odim) - particles: List[ParticleTree] = [p0] + particles: list[ParticleTree] = [p0] particles.extend(ParticleTree(self.a_tree) for _ in self.indices) return particles @@ -419,7 +419,7 @@ def _update( mean: npt.NDArray[np.float64], m_2: npt.NDArray[np.float64], new_value: npt.NDArray[np.float64], -) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]: +) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]: delta = new_value - mean mean += delta / count delta2 = new_value - mean @@ -439,7 +439,7 @@ def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None: """ self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum()))) - def rvs(self) -> Union[int, Tuple[int, float]]: + def rvs(self) -> Union[int, tuple[int, float]]: rnd: float = np.random.random() for i, val in self.enu: if rnd <= val: @@ -447,7 +447,7 @@ def rvs(self) -> Union[int, Tuple[int, float]]: return self.enu[-1] -def compute_prior_probability(alpha: int, beta: int) -> List[float]: +def compute_prior_probability(alpha: int, beta: int) -> list[float]: """ Calculate the probability of the node being a leaf node (1 - p(being split node)). @@ -460,7 +460,7 @@ def compute_prior_probability(alpha: int, beta: int) -> List[float]: ------- list with probabilities for leaf nodes """ - prior_leaf_prob: List[float] = [0] + prior_leaf_prob: list[float] = [0] depth = 0 while prior_leaf_prob[-1] < 0.9999: prior_leaf_prob.append(1 - (alpha * ((1 + depth) ** (-beta)))) @@ -549,7 +549,7 @@ def draw_leaf_value( norm: npt.NDArray[np.float64], shape: int, response: str, -) -> Tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]: +) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]: """Draw Gaussian distributed leaf values.""" linear_params = None mu_mean = np.empty(shape) @@ -590,7 +590,7 @@ def fast_linear_fit( y: npt.NDArray[np.float64], m: int, norm: npt.NDArray[np.float64], -) -> Tuple[npt.NDArray[np.float64], List[npt.NDArray[np.float64]]]: +) -> tuple[npt.NDArray[np.float64], list[npt.NDArray[np.float64]]]: n = len(x) y = y / m + np.expand_dims(norm, axis=1) diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 0e0a35c..7655175 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Generator from functools import lru_cache -from typing import Dict, Generator, List, Optional, Tuple, Union +from typing import Optional, Union import numpy as np import numpy.typing as npt @@ -30,7 +31,7 @@ class Node: value : npt.NDArray[np.float64] idx_data_points : Optional[npt.NDArray[np.int_]] idx_split_variable : int - linear_params: Optional[List[float]] = None + linear_params: Optional[list[float]] = None """ __slots__ = "value", "nvalue", "idx_split_variable", "idx_data_points", "linear_params" @@ -41,7 +42,7 @@ def __init__( nvalue: int = 0, idx_data_points: Optional[npt.NDArray[np.int_]] = None, idx_split_variable: int = -1, - linear_params: Optional[List[npt.NDArray[np.float64]]] = None, + linear_params: Optional[list[npt.NDArray[np.float64]]] = None, ) -> None: self.value = value self.nvalue = nvalue @@ -56,7 +57,7 @@ def new_leaf_node( nvalue: int = 0, idx_data_points: Optional[npt.NDArray[np.int_]] = None, idx_split_variable: int = -1, - linear_params: Optional[List[npt.NDArray[np.float64]]] = None, + linear_params: Optional[list[npt.NDArray[np.float64]]] = None, ) -> "Node": return cls( value=value, @@ -94,7 +95,7 @@ class Tree: Attributes ---------- - tree_structure : Dict[int, Node] + tree_structure : dict[int, Node] A dictionary that represents the nodes stored in breadth-first order, based in the array method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays). The dictionary's keys are integers that represent the nodes position. @@ -102,11 +103,11 @@ class Tree: of the tree itself. output: Optional[npt.NDArray[np.float64]] Array of shape number of observations, shape - split_rules : List[SplitRule] + split_rules : list[SplitRule] List of SplitRule objects, one per column in input data. Allows using different split rules for different columns. Default is ContinuousSplitRule. Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables. - idx_leaf_nodes : Optional[List[int]], by default None. + idx_leaf_nodes : Optional[list[int]], by default None. Array with the index of the leaf nodes of the tree. Parameters @@ -120,10 +121,10 @@ class Tree: def __init__( self, - tree_structure: Dict[int, Node], + tree_structure: dict[int, Node], output: npt.NDArray[np.float64], - split_rules: List[SplitRule], - idx_leaf_nodes: Optional[List[int]] = None, + split_rules: list[SplitRule], + idx_leaf_nodes: Optional[list[int]] = None, ) -> None: self.tree_structure = tree_structure self.idx_leaf_nodes = idx_leaf_nodes @@ -137,7 +138,7 @@ def new_tree( idx_data_points: Optional[npt.NDArray[np.int_]], num_observations: int, shape: int, - split_rules: List[SplitRule], + split_rules: list[SplitRule], ) -> "Tree": return cls( tree_structure={ @@ -159,7 +160,7 @@ def __setitem__(self, index, node) -> None: self.set_node(index, node) def copy(self) -> "Tree": - tree: Dict[int, Node] = { + tree: dict[int, Node] = { k: Node( value=v.value, nvalue=v.nvalue, @@ -199,7 +200,7 @@ def grow_leaf_node( self.idx_leaf_nodes.remove(index_leaf_node) def trim(self) -> "Tree": - tree: Dict[int, Node] = { + tree: dict[int, Node] = { k: Node( value=v.value, nvalue=v.nvalue, @@ -233,7 +234,7 @@ def _predict(self) -> npt.NDArray[np.float64]: def predict( self, x: npt.NDArray[np.float64], - excluded: Optional[List[int]] = None, + excluded: Optional[list[int]] = None, shape: int = 1, ) -> npt.NDArray[np.float64]: """ @@ -243,7 +244,7 @@ def predict( ---------- x : npt.NDArray[np.float64] Unobserved point - excluded: Optional[List[int]] + excluded: Optional[list[int]] Indexes of the variables to exclude when computing predictions Returns @@ -259,8 +260,8 @@ def predict( def _traverse_tree( self, X: npt.NDArray[np.float64], - excluded: Optional[List[int]] = None, - shape: Union[int, Tuple[int, ...]] = 1, + excluded: Optional[list[int]] = None, + shape: Union[int, tuple[int, ...]] = 1, ) -> npt.NDArray[np.float64]: """ Traverse the tree starting from the root node given an (un)observed point. @@ -273,7 +274,7 @@ def _traverse_tree( Index of the node to start the traversal from split_variable : int Index of the variable used to split the node - excluded: Optional[List[int]] + excluded: Optional[list[int]] Indexes of the variables to exclude when computing predictions Returns @@ -327,14 +328,14 @@ def _traverse_tree( return p_d def _traverse_leaf_values( - self, leaf_values: List[npt.NDArray[np.float64]], leaf_n_values: List[int], node_index: int + self, leaf_values: list[npt.NDArray[np.float64]], leaf_n_values: list[int], node_index: int ) -> None: """ Traverse the tree appending leaf values starting from a particular node. Parameters ---------- - leaf_values : List[npt.NDArray[np.float64]] + leaf_values : list[npt.NDArray[np.float64]] node_index : int """ node = self.get_node(node_index) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 31cc28f..10b5dfd 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -2,7 +2,7 @@ """Utility function for variable selection and bart interpretability.""" import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import arviz as az import matplotlib.pyplot as plt @@ -21,11 +21,11 @@ def _sample_posterior( - all_trees: List[List[Tree]], + all_trees: list[list[Tree]], X: TensorLike, rng: np.random.Generator, - size: Optional[Union[int, Tuple[int, ...]]] = None, - excluded: Optional[List[int]] = None, + size: Optional[Union[int, tuple[int, ...]]] = None, + excluded: Optional[list[int]] = None, shape: int = 1, ) -> npt.NDArray[np.float64]: """ @@ -50,7 +50,7 @@ def _sample_posterior( X = X.eval() if size is None: - size_iter: Union[List, Tuple] = (1,) + size_iter: Union[list, tuple] = (1,) elif isinstance(size, int): size_iter = [size] else: @@ -79,9 +79,9 @@ def plot_convergence( idata: az.InferenceData, var_name: Optional[str] = None, kind: str = "ecdf", - figsize: Optional[Tuple[float, float]] = None, + figsize: Optional[tuple[float, float]] = None, ax=None, -) -> List[plt.Axes]: +) -> list[plt.Axes]: """ Plot convergence diagnostics. @@ -93,14 +93,14 @@ def plot_convergence( Name of the BART variable to plot. Defaults to None. kind : str Type of plot to display. Options are "ecdf" (default) and "kde". - figsize : Optional[Tuple[float, float]], by default None. + figsize : Optional[tuple[float, float]], by default None. Figure size. Defaults to None. ax : matplotlib axes Axes on which to plot. Defaults to None. Returns ------- - List[ax] : matplotlib axes + list[ax] : matplotlib axes """ ess_threshold = idata["posterior"]["chain"].size * 100 ess = np.atleast_2d(az.ess(idata, method="bulk", var_names=var_name)[var_name].values) @@ -157,8 +157,8 @@ def plot_ice( bartrv: Variable, X: npt.NDArray[np.float64], Y: Optional[npt.NDArray[np.float64]] = None, - var_idx: Optional[List[int]] = None, - var_discrete: Optional[List[int]] = None, + var_idx: Optional[list[int]] = None, + var_discrete: Optional[list[int]] = None, func: Optional[Callable] = None, centered: Optional[bool] = True, samples: int = 100, @@ -170,10 +170,10 @@ def plot_ice( color="C0", color_mean: str = "C0", alpha: float = 0.1, - figsize: Optional[Tuple[float, float]] = None, - smooth_kwargs: Optional[Dict[str, Any]] = None, + figsize: Optional[tuple[float, float]] = None, + smooth_kwargs: Optional[dict[str, Any]] = None, ax: Optional[plt.Axes] = None, -) -> List[plt.Axes]: +) -> list[plt.Axes]: """ Individual conditional expectation plot. @@ -185,9 +185,9 @@ def plot_ice( The covariate matrix. Y : Optional[npt.NDArray[np.float64]], by default None. The response vector. - var_idx : Optional[List[int]], by default None. + var_idx : Optional[list[int]], by default None. List of the indices of the covariate for which to compute the pdp or ice. - var_discrete : Optional[List[int]], by default None. + var_discrete : Optional[list[int]], by default None. List of the indices of the covariate treated as discrete. func : Optional[Callable], by default None. Arbitrary function to apply to the predictions. Defaults to the identity function. @@ -302,9 +302,9 @@ def plot_pdp( X: npt.NDArray[np.float64], Y: Optional[npt.NDArray[np.float64]] = None, xs_interval: str = "quantiles", - xs_values: Optional[Union[int, List[float]]] = None, - var_idx: Optional[List[int]] = None, - var_discrete: Optional[List[int]] = None, + xs_values: Optional[Union[int, list[float]]] = None, + var_idx: Optional[list[int]] = None, + var_discrete: Optional[list[int]] = None, func: Optional[Callable] = None, samples: int = 200, random_seed: Optional[int] = None, @@ -314,10 +314,10 @@ def plot_pdp( color="C0", color_mean: str = "C0", alpha: float = 0.1, - figsize: Optional[Tuple[float, float]] = None, - smooth_kwargs: Optional[Dict[str, Any]] = None, + figsize: Optional[tuple[float, float]] = None, + smooth_kwargs: Optional[dict[str, Any]] = None, ax: Optional[plt.Axes] = None, -) -> List[plt.Axes]: +) -> list[plt.Axes]: """ Partial dependence plot. @@ -334,14 +334,14 @@ def plot_pdp( evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified quantiles of X. "insample", the evaluation is done at the values of X. For discrete variables these options are ommited. - xs_values : Optional[Union[int, List[float]]], by default None. + xs_values : Optional[Union[int, list[float]]], by default None. Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of points in the evenly spaced grid. If ``xs_interval="quantiles"`` quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive. Ignored when ``xs_interval="insample"``. - var_idx : Optional[List[int]], by default None. + var_idx : Optional[list[int]], by default None. List of the indices of the covariate for which to compute the pdp or ice. - var_discrete : Optional[List[int]], by default None. + var_discrete : Optional[list[int]], by default None. List of the indices of the covariate treated as discrete. func : Optional[Callable], by default None. Arbitrary function to apply to the predictions. Defaults to the identity function. @@ -449,12 +449,12 @@ def identity(x): def _create_figure_axes( bartrv: Variable, - var_idx: List[int], + var_idx: list[int], grid: str = "long", sharey: bool = True, - figsize: Optional[Tuple[float, float]] = None, + figsize: Optional[tuple[float, float]] = None, ax: Optional[plt.Axes] = None, -) -> Tuple[plt.Figure, List[plt.Axes], int]: +) -> tuple[plt.Figure, list[plt.Axes], int]: """ Create and return the figure and axes objects for plotting the variables. @@ -464,9 +464,9 @@ def _create_figure_axes( ---------- bartrv : BART Random Variable BART variable once the model that include it has been fitted. - var_idx : Optional[List[int]], by default None. + var_idx : Optional[list[int]], by default None. List of the indices of the covariate for which to compute the pdp or ice. - var_discrete : Optional[List[int]], by default None. + var_discrete : Optional[list[int]], by default None. grid : str or tuple How to arrange the subplots. Defaults to "long", one subplot below the other. Other options are "wide", one subplot next to each other or a tuple indicating the number of @@ -481,7 +481,7 @@ def _create_figure_axes( Returns ------- - Tuple[plt.Figure, List[plt.Axes], int] + tuple[plt.Figure, list[plt.Axes], int] A tuple containing the figure object, list of axes objects, and the shape value. """ if bartrv.ndim == 1: # type: ignore @@ -535,18 +535,18 @@ def _prepare_plot_data( X: npt.NDArray[np.float64], Y: Optional[npt.NDArray[np.float64]] = None, xs_interval: str = "quantiles", - xs_values: Optional[Union[int, List[float]]] = None, - var_idx: Optional[List[int]] = None, - var_discrete: Optional[List[int]] = None, -) -> Tuple[ + xs_values: Optional[Union[int, list[float]]] = None, + var_idx: Optional[list[int]] = None, + var_discrete: Optional[list[int]] = None, +) -> tuple[ npt.NDArray[np.float64], - List[str], + list[str], str, - List[int], - List[int], - List[int], + list[int], + list[int], + list[int], str, - Union[int, None, List[float]], + Union[int, None, list[float]], ]: """ Prepare data for plotting. @@ -627,7 +627,7 @@ def _prepare_plot_data( def _create_pdp_data( X: npt.NDArray[np.float64], xs_interval: str, - xs_values: Optional[Union[int, List[float]]] = None, + xs_values: Optional[Union[int, list[float]]] = None, ) -> npt.NDArray[np.float64]: """ Create data for partial dependence plot. @@ -663,8 +663,8 @@ def _smooth_mean( new_x: npt.NDArray[np.float64], p_di: npt.NDArray[np.float64], kind: str = "pdp", - smooth_kwargs: Optional[Dict[str, Any]] = None, -) -> Tuple[np.ndarray, np.ndarray]: + smooth_kwargs: Optional[dict[str, Any]] = None, +) -> tuple[np.ndarray, np.ndarray]: """ Smooth the mean data for plotting. @@ -676,12 +676,12 @@ def _smooth_mean( The distribution of partial dependence from which to comptue the smoothed mean. kind : str, optional The type of plot. Possible values are "pdp" or "ice". - smooth_kwargs : Optional[Dict[str, Any]], optional + smooth_kwargs : Optional[dict[str, Any]], optional Additional keyword arguments for the smoothing function. Defaults to None. Returns ------- - Tuple[np.ndarray, np.ndarray] + tuple[np.ndarray, np.ndarray] A tuple containing a grid for the x-axis data and the corresponding smoothed y-axis data. """ @@ -709,7 +709,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non InferenceData containing a collection of BART_trees in sample_stats group X : npt.NDArray[np.float64] The covariate matrix. - labels : Optional[List[str]] + labels : Optional[list[str]] List of the names of the covariates. If X is a DataFrame the names of the covariables will be taken from it and this argument will be ignored. figsize : tuple @@ -860,7 +860,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 if method == "backward_VI": subsets = subsets[-init:] - indices: List[int] = list(idxs[::-1]) + indices: list[int] = list(idxs[::-1]) for idx, subset in enumerate(subsets): predicted_subset = _sample_posterior( @@ -880,7 +880,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 if method in ["backward", "backward_VI"]: if method == "backward_VI": - least_important_vars: List[int] = indices[-fixed:] + least_important_vars: list[int] = indices[-fixed:] r2_mean_vi = r2_mean[:init] r2_hdi_vi = r2_hdi[:init] preds_vi = preds[:init] @@ -964,7 +964,7 @@ def plot_variable_importance( vi_results: dict, labels=None, figsize=None, - plot_kwargs: Optional[Dict[str, Any]] = None, + plot_kwargs: Optional[dict[str, Any]] = None, ax: Optional[plt.Axes] = None, ): """ @@ -976,7 +976,7 @@ def plot_variable_importance( Dictionary computed with `compute_variable_importance` X : npt.NDArray[np.float64] The covariate matrix. - labels : Optional[List[str]] + labels : Optional[list[str]] List of the names of the covariates. If X is a DataFrame the names of the covariables will be taken from it and this argument will be ignored. plot_kwargs : dict @@ -1061,8 +1061,8 @@ def plot_scatter_submodels( func: Optional[Callable] = None, grid: str = "long", labels=None, - figsize: Optional[Tuple[float, float]] = None, - plot_kwargs: Optional[Dict[str, Any]] = None, + figsize: Optional[tuple[float, float]] = None, + plot_kwargs: Optional[dict[str, Any]] = None, axes: Optional[plt.Axes] = None, ): """ @@ -1078,7 +1078,7 @@ def plot_scatter_submodels( How to arrange the subplots. Defaults to "long", one subplot below the other. Other options are "wide", one subplot next to each other or a tuple indicating the number of rows and columns. - labels : Optional[List[str]] + labels : Optional[list[str]] List of the names of the covariates. plot_kwargs : dict Additional keyword arguments for the plot. Defaults to None. diff --git a/pyproject.toml b/pyproject.toml index 165ed67..bc94137 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,6 @@ line-length = 100 [tool.ruff.lint] select = ["E", "F", "I", "PL", "UP", "W"] -ignore-init-module-imports = true ignore = [ "PLR2004", # Checks for the use of unnamed numerical constants ("magic") values in comparisons. ] @@ -17,7 +16,7 @@ ignore = [ max-args = 19 max-branches = 15 -[tool.ruff.extend-per-file-ignores] +[tool.ruff.lint.extend-per-file-ignores] "docs/conf.py" = ["E501", "F541"] "tests/test_*.py" = ["F841"] From b20d074cdbea4c06d1a599a759e107506418e7ae Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Fri, 29 Nov 2024 10:09:19 -0300 Subject: [PATCH 12/53] add submodels arguments to plot subsets (#200) --- pymc_bart/utils.py | 76 ++++++++++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 29 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 10b5dfd..e10a511 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -705,7 +705,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non Parameters ---------- - idata: InferenceData + idata : InferenceData InferenceData containing a collection of BART_trees in sample_stats group X : npt.NDArray[np.float64] The covariate matrix. @@ -784,7 +784,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 Parameters ---------- - idata: InferenceData + idata : InferenceData InferenceData containing a collection of BART_trees in sample_stats group bartrv : BART Random Variable BART variable once the model that include it has been fitted. @@ -949,8 +949,10 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 indices = least_important_vars[::-1] + labels = np.array(["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]) + vi_results = { - "indices": indices, + "indices": np.asarray(indices), "labels": labels[indices], "r2_mean": r2_mean, "r2_hdi": r2_hdi, @@ -962,8 +964,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 def plot_variable_importance( vi_results: dict, - labels=None, - figsize=None, + submodels: Optional[Union[list[int], np.ndarray, tuple[int, ...]]] = None, + labels: Optional[list[str]] = None, + figsize: Optional[tuple[float, float]] = None, plot_kwargs: Optional[dict[str, Any]] = None, ax: Optional[plt.Axes] = None, ): @@ -974,8 +977,11 @@ def plot_variable_importance( ---------- vi_results: Dictionary Dictionary computed with `compute_variable_importance` - X : npt.NDArray[np.float64] - The covariate matrix. + submodels : Optional[Union[list[int], np.ndarray]] + List of the indices of the submodels to plot. Defaults to None, all variables are ploted. + The indices correspond to order computed by `compute_variable_importance`. + For example `submodels=[0,1]` will plot the two most important variables. + `submodels=[1,0]` is equivalent as values are sorted before use. labels : Optional[list[str]] List of the names of the covariates. If X is a DataFrame the names of the covariables will be taken from it and this argument will be ignored. @@ -995,11 +1001,15 @@ def plot_variable_importance( ------- axes: matplotlib axes """ + if submodels is None: + submodels = np.sort(vi_results["indices"]) + else: + submodels = np.sort(submodels) - indices = vi_results["indices"] - r2_mean = vi_results["r2_mean"] - r2_hdi = vi_results["r2_hdi"] - preds = vi_results["preds"] + indices = vi_results["indices"][submodels] + r2_mean = vi_results["r2_mean"][submodels] + r2_hdi = vi_results["r2_hdi"][submodels] + preds = vi_results["preds"][submodels] preds_all = vi_results["preds_all"] samples = preds.shape[1] @@ -1016,9 +1026,7 @@ def plot_variable_importance( _, ax = plt.subplots(1, 1, figsize=figsize) if labels is None: - labels = vi_results["labels"] - - labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)] + labels = vi_results["labels"][submodels] r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)]) @@ -1059,21 +1067,27 @@ def plot_variable_importance( def plot_scatter_submodels( vi_results: dict, func: Optional[Callable] = None, + submodels: Optional[Union[list[int], np.ndarray]] = None, grid: str = "long", - labels=None, + labels: Optional[list[str]] = None, figsize: Optional[tuple[float, float]] = None, plot_kwargs: Optional[dict[str, Any]] = None, - axes: Optional[plt.Axes] = None, -): + ax: Optional[plt.Axes] = None, +) -> list[plt.Axes]: """ Plot submodel's predictions against reference-model's predictions. Parameters ---------- - vi_results: Dictionary + vi_results : Dictionary Dictionary computed with `compute_variable_importance` func : Optional[Callable], by default None. Arbitrary function to apply to the predictions. Defaults to the identity function. + submodels : Optional[Union[list[int], np.ndarray]] + List of the indices of the submodels to plot. Defaults to None, all variables are ploted. + The indices correspond to order computed by `compute_variable_importance`. + For example `submodels=[0,1]` will plot the two most important variables. + `submodels=[1,0]` is equivalent as values are sorted before use. grid : str or tuple How to arrange the subplots. Defaults to "long", one subplot below the other. Other options are "wide", one subplot next to each other or a tuple indicating the number @@ -1092,20 +1106,23 @@ def plot_scatter_submodels( ------- axes: matplotlib axes """ - indices = vi_results["indices"] - preds = vi_results["preds"] + if submodels is None: + submodels = np.sort(vi_results["indices"]) + else: + submodels = np.sort(submodels) + + indices = vi_results["indices"][submodels] + preds = vi_results["preds"][submodels] preds_all = vi_results["preds_all"] - if axes is None: - _, axes = _get_axes(grid, len(indices), True, True, figsize) + if ax is None: + _, ax = _get_axes(grid, len(indices), True, True, figsize) if plot_kwargs is None: plot_kwargs = {} if labels is None: - labels = vi_results["labels"] - - labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)] + labels = vi_results["labels"][submodels] if func is not None: preds = func(preds) @@ -1114,8 +1131,8 @@ def plot_scatter_submodels( min_ = min(np.min(preds), np.min(preds_all)) max_ = max(np.max(preds), np.max(preds_all)) - for pred, x_label, ax in zip(preds, labels, axes.ravel()): - ax.plot( + for pred, x_label, axi in zip(preds, labels, ax.ravel()): + axi.plot( pred, preds_all, marker=plot_kwargs.get("marker_scatter", "."), @@ -1123,13 +1140,14 @@ def plot_scatter_submodels( color=plot_kwargs.get("color_scatter", "C0"), alpha=plot_kwargs.get("alpha_scatter", 0.1), ) - ax.set_xlabel(x_label) - ax.axline( + axi.set_xlabel(x_label) + axi.axline( [min_, min_], [max_, max_], color=plot_kwargs.get("color_ref", "0.5"), ls=plot_kwargs.get("ls_ref", "--"), ) + return ax def generate_sequences(n_vars, i_var, include): From 07f55d4a46563f874b0bd0a2cdae45882f3c7cdb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Dec 2024 14:48:57 -0300 Subject: [PATCH 13/53] [pre-commit.ci] pre-commit autoupdate (#201) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.8.0 → v0.8.1](https://github.com/astral-sh/ruff-pre-commit/compare/v0.8.0...v0.8.1) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8832b06..2c92b49 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.0 + rev: v0.8.1 hooks: - id: ruff args: ["--fix", "--output-format=full"] From 3364ed3ab57095e3b24e7a9ae7946c134817e7f8 Mon Sep 17 00:00:00 2001 From: Alexandre Andorra Date: Wed, 4 Dec 2024 19:21:57 -0500 Subject: [PATCH 14/53] Check if Y is a shared var in rng_fn (#202) --- pymc_bart/bart.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 16a856c..94b91c3 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -25,6 +25,7 @@ from pymc.distributions.distribution import Distribution, _support_point from pymc.logprob.abstract import _logprob from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.sharedvar import TensorSharedVariable from .split_rules import SplitRule from .tree import Tree @@ -53,11 +54,16 @@ def rng_fn( # pylint: disable=W0237 if not size: size = None + if isinstance(cls.Y, TensorSharedVariable): + Y = cls.Y.eval() + else: + Y = cls.Y + if not cls.all_trees: if size is not None: - return np.full((size[0], cls.Y.shape[0]), cls.Y.mean()) + return np.full((size[0], Y.shape[0]), Y.mean()) else: - return np.full(cls.Y.shape[0], cls.Y.mean()) + return np.full(Y.shape[0], Y.mean()) else: if size is not None: shape = size[0] From bcdf77d8956ab3d9d523410506ca55ac7ed31c87 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Dec 2024 09:10:01 -0300 Subject: [PATCH 15/53] [pre-commit.ci] pre-commit autoupdate (#203) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.8.1 → v0.8.3](https://github.com/astral-sh/ruff-pre-commit/compare/v0.8.1...v0.8.3) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c92b49..fe00024 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.1 + rev: v0.8.3 hooks: - id: ruff args: ["--fix", "--output-format=full"] From d4e8cadaf84cde6417390dc18c06d83c9b9114c4 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Tue, 17 Dec 2024 10:39:00 -0300 Subject: [PATCH 16/53] bump release (#205) * bump release * fix zip and new args --- pymc_bart/__init__.py | 2 +- pymc_bart/bart.py | 2 +- pymc_bart/pgbart.py | 12 ++++++++---- requirements.txt | 2 +- tests/test_bart.py | 6 ++++-- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index 18fe054..eee1881 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -42,7 +42,7 @@ "plot_variable_importance", "plot_variable_inclusion", ] -__version__ = "0.7.1" +__version__ = "0.8.0" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 94b91c3..eb869d2 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -175,7 +175,7 @@ def get_moment(rv, size, *rv_inputs): return cls.get_moment(rv, size, *rv_inputs) cls.rv_op = bart_op - params = [X, Y, m, alpha, beta, split_prior] + params = [X, Y, m, alpha, beta] return super().__new__(cls, name, *params, **kwargs) @classmethod diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 6a7e26e..1505f15 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -17,6 +17,7 @@ import numpy as np import numpy.typing as npt from numba import njit +from pymc.initial_point import PointType from pymc.model import Model, modelcontext from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements from pymc.step_methods.arraystep import ArrayStepShared @@ -125,9 +126,12 @@ def __init__( # noqa: PLR0915 num_particles: int = 10, batch: tuple[float, float] = (0.1, 0.1), model: Optional[Model] = None, + initial_point: PointType | None = None, + compile_kwargs: dict | None = None, # pylint: disable=unused-argument ): model = modelcontext(model) - initial_values = model.initial_point() + if initial_point is None: + initial_point = model.initial_point() if vars is None: vars = model.value_vars else: @@ -150,7 +154,7 @@ def __init__( # noqa: PLR0915 self.m = self.bart.m self.response = self.bart.response - shape = initial_values[value_bart.name].shape + shape = initial_point[value_bart.name].shape self.shape = 1 if len(shape) == 1 else shape[0] @@ -217,8 +221,8 @@ def __init__( # noqa: PLR0915 self.num_particles = num_particles self.indices = list(range(1, num_particles)) - shared = make_shared_replacements(initial_values, vars, model) - self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared) + shared = make_shared_replacements(initial_point, vars, model) + self.likelihood_logp = logp(initial_point, [model.datalogp], vars, shared) self.all_particles = [ [ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape) ] diff --git a/requirements.txt b/requirements.txt index ac9bd07..da634d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pymc>=5.16.2, <=5.18 +pymc>=5.16.2, <=5.19.1 arviz>=0.18.0 numba matplotlib diff --git a/tests/test_bart.py b/tests/test_bart.py index c64811a..a003363 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -248,8 +248,10 @@ def test_categorical_model(separate_trees, split_rule): separate_trees=separate_trees, ) y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y) - idata = pm.sample(random_seed=3415, tune=300, draws=300) - idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True) + idata = pm.sample(tune=300, draws=300, random_seed=3415) + idata = pm.sample_posterior_predictive( + idata, predictions=True, extend_inferencedata=True, random_seed=3415 + ) # Fit should be good enough so right category is selected over 50% of time assert (idata.predictions.y.median(["chain", "draw"]) == Y).all() From 77116d1289c433849333c47bdccc786773da1f77 Mon Sep 17 00:00:00 2001 From: Alexandre Andorra Date: Thu, 19 Dec 2024 13:02:13 -0500 Subject: [PATCH 17/53] Patch for case when Y is a TensorVariable (#206) * add case tensor var for Y * Improve `isinstance` statement Co-authored-by: Osvaldo A Martin --------- Co-authored-by: Osvaldo A Martin --- pymc_bart/bart.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index eb869d2..decb499 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -26,6 +26,7 @@ from pymc.logprob.abstract import _logprob from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.sharedvar import TensorSharedVariable +from pytensor.tensor.variable import TensorVariable from .split_rules import SplitRule from .tree import Tree @@ -54,7 +55,7 @@ def rng_fn( # pylint: disable=W0237 if not size: size = None - if isinstance(cls.Y, TensorSharedVariable): + if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)): Y = cls.Y.eval() else: Y = cls.Y From 1ec251b86831a36c480c200f541c2bfe1154ea32 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Fri, 20 Dec 2024 12:39:43 -0300 Subject: [PATCH 18/53] Fix bug with labels in variable importance, add reference line, remove deprecation warning (#207) * fix bug labels variable importance, add reference line * revert change --- pymc_bart/__init__.py | 4 +--- pymc_bart/utils.py | 32 ++++++++++++++------------------ pyproject.toml | 2 ++ 3 files changed, 17 insertions(+), 21 deletions(-) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index eee1881..440f7f2 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -19,7 +19,6 @@ from pymc_bart.utils import ( compute_variable_importance, plot_convergence, - plot_dependence, plot_ice, plot_pdp, plot_scatter_submodels, @@ -35,14 +34,13 @@ "SubsetSplitRule", "compute_variable_importance", "plot_convergence", - "plot_dependence", "plot_ice", "plot_pdp", "plot_scatter_submodels", "plot_variable_importance", "plot_variable_inclusion", ] -__version__ = "0.8.0" +__version__ = "0.8.1" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index e10a511..d9738dd 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -137,22 +137,6 @@ def plot_convergence( return ax -def plot_dependence(*args, kind="pdp", **kwargs): # pylint: disable=unused-argument - """ - Partial dependence or individual conditional expectation plot. - """ - if kind == "pdp": - warnings.warn( - "This function has been deprecated. Use plot_pdp instead.", - FutureWarning, - ) - elif kind == "ice": - warnings.warn( - "This function has been deprecated. Use plot_ice instead.", - FutureWarning, - ) - - def plot_ice( bartrv: Variable, X: npt.NDArray[np.float64], @@ -307,6 +291,7 @@ def plot_pdp( var_discrete: Optional[list[int]] = None, func: Optional[Callable] = None, samples: int = 200, + ref_line: bool = True, random_seed: Optional[int] = None, sharey: bool = True, smooth: bool = True, @@ -347,6 +332,8 @@ def plot_pdp( Arbitrary function to apply to the predictions. Defaults to the identity function. samples : int Number of posterior samples used in the predictions. Defaults to 200 + ref_line : bool + If True a reference line is plotted at the mean of the partial dependence. Defaults to True. random_seed : Optional[int], by default None. Seed used to sample from the posterior. Defaults to None. sharey : bool @@ -402,6 +389,7 @@ def identity(x): count = 0 fake_X = _create_pdp_data(X, xs_interval, xs_values) + null_pd = [] for var in range(len(var_idx)): excluded = indices[:] excluded.remove(var) @@ -413,6 +401,7 @@ def identity(x): new_x = fake_X[:, var] for s_i in range(shape): p_di = func(p_d[:, :, s_i]) + null_pd.append(p_di.mean()) if var in var_discrete: _, idx_uni = np.unique(new_x, return_index=True) y_means = p_di.mean(0)[idx_uni] @@ -442,6 +431,11 @@ def identity(x): count += 1 + if ref_line: + ref_val = sum(null_pd) / len(null_pd) + for ax_ in np.ravel(axes): + ax_.axhline(ref_val, color="0.7", linestyle="--") + fig.text(-0.05, 0.5, y_label, va="center", rotation="vertical", fontsize=15) return axes @@ -949,11 +943,13 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 indices = least_important_vars[::-1] - labels = np.array(["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]) + labels = np.array( + ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + ) vi_results = { "indices": np.asarray(indices), - "labels": labels[indices], + "labels": labels, "r2_mean": r2_mean, "r2_hdi": r2_hdi, "preds": preds, diff --git a/pyproject.toml b/pyproject.toml index bc94137..f8f3e7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,8 @@ line-length = 100 select = ["E", "F", "I", "PL", "UP", "W"] ignore = [ "PLR2004", # Checks for the use of unnamed numerical constants ("magic") values in comparisons. + "PLR0913", #Too many arguments in function definition + ] [tool.ruff.lint.pylint] From 2f0b3aa1697096a62f307eaea9a374c6bb4ac977 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Mon, 23 Dec 2024 11:04:59 -0300 Subject: [PATCH 19/53] fix bug with shapes (#208) --- pymc_bart/utils.py | 2 +- tests/test_bart.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index d9738dd..d9d5241 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -828,7 +828,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 r2_mean = np.zeros(n_vars) r2_hdi = np.zeros((n_vars, 2)) - preds = np.zeros((n_vars, samples, bartrv.eval().shape[0])) + preds = np.zeros((n_vars, samples, *bartrv.eval().T.shape)) if method == "backward_VI": if fixed >= n_vars: diff --git a/tests/test_bart.py b/tests/test_bart.py index a003363..226d938 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -255,3 +255,4 @@ def test_categorical_model(separate_trees, split_rule): # Fit should be good enough so right category is selected over 50% of time assert (idata.predictions.y.median(["chain", "draw"]) == Y).all() + assert pmb.compute_variable_importance(idata, bartrv=lo, X=X)["preds"].shape == (5, 50, 9, 3) From 8b536b9950e6bd9ec7fb08261bf71977b1b15eb0 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Mon, 23 Dec 2024 11:12:11 -0300 Subject: [PATCH 20/53] Update __init__.py --- pymc_bart/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index 440f7f2..361be83 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -40,7 +40,7 @@ "plot_variable_importance", "plot_variable_inclusion", ] -__version__ = "0.8.1" +__version__ = "0.8.2" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] From b84ba1c45222c089d2d58ab73066f234f08f033b Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Thu, 26 Dec 2024 21:39:55 +0100 Subject: [PATCH 21/53] Update MyPy 14 (#210) * move mypy config * some fixes * some fixes * some fixes * some fixes * some fixes * some fixes * remove reference np.float64 * remove unnesserary casting * fix type * fix import --- .pre-commit-config.yaml | 4 +-- mypy.ini | 15 --------- pymc_bart/bart.py | 6 ++-- pymc_bart/pgbart.py | 73 ++++++++++++++++++++++++----------------- pymc_bart/tree.py | 58 +++++++++++++++++++------------- pymc_bart/utils.py | 50 ++++++++++++++-------------- pyproject.toml | 17 ++++++++++ 7 files changed, 123 insertions(+), 100 deletions(-) delete mode 100644 mypy.ini diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fe00024..4f55bc1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,14 +12,14 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.3 + rev: v0.8.4 hooks: - id: ruff args: ["--fix", "--output-format=full"] - id: ruff-format args: ["--line-length=100"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.13.0 + rev: v1.14.0 hooks: - id: mypy args: [--ignore-missing-imports] diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 56088d7..0000000 --- a/mypy.ini +++ /dev/null @@ -1,15 +0,0 @@ -[mypy] -files = pymc_bart/*.py -plugins = numpy.typing.mypy_plugin - -[mypy-matplotlib.*] -ignore_missing_imports = True - -[mypy-numba.*] -ignore_missing_imports = True - -[mypy-pymc.*] -ignore_missing_imports = True - -[mypy-scipy.*] -ignore_missing_imports = True diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index decb499..ac2be35 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -132,7 +132,7 @@ def __new__( alpha: float = 0.95, beta: float = 2.0, response: str = "constant", - split_prior: Optional[npt.NDArray[np.float64]] = None, + split_prior: Optional[npt.NDArray] = None, split_rules: Optional[list[SplitRule]] = None, separate_trees: Optional[bool] = False, **kwargs, @@ -203,9 +203,7 @@ def get_moment(cls, rv, size, *rv_inputs): return mean -def preprocess_xy( - X: TensorLike, Y: TensorLike -) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: +def preprocess_xy(X: TensorLike, Y: TensorLike) -> tuple[npt.NDArray, npt.NDArray]: if isinstance(Y, (Series, DataFrame)): Y = Y.to_numpy() if isinstance(X, (Series, DataFrame)): diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 1505f15..014313a 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -16,6 +16,8 @@ import numpy as np import numpy.typing as npt +import pymc as pm +import pytensor.tensor as pt from numba import njit from pymc.initial_point import PointType from pymc.model import Model, modelcontext @@ -120,15 +122,15 @@ class PGBART(ArrayStepShared): "tune": (bool, []), } - def __init__( # noqa: PLR0915 + def __init__( # noqa: PLR0912, PLR0915 self, - vars=None, # pylint: disable=redefined-builtin + vars: list[pm.Distribution] | None = None, num_particles: int = 10, batch: tuple[float, float] = (0.1, 0.1), model: Optional[Model] = None, initial_point: PointType | None = None, - compile_kwargs: dict | None = None, # pylint: disable=unused-argument - ): + compile_kwargs: dict | None = None, + ) -> None: model = modelcontext(model) if initial_point is None: initial_point = model.initial_point() @@ -137,6 +139,10 @@ def __init__( # noqa: PLR0915 else: vars = [model.rvs_to_values.get(var, var) for var in vars] vars = inputvars(vars) + + if vars is None: + raise ValueError("Unable to find variables to sample") + value_bart = vars[0] self.bart = model.values_to_rvs[value_bart].owner.op @@ -325,7 +331,7 @@ def normalize(self, particles: list[ParticleTree]) -> float: return wei / wei.sum() def resample( - self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64] + self, particles: list[ParticleTree], normalized_weights: npt.NDArray ) -> list[ParticleTree]: """ Use systematic resample for all but the first particle @@ -347,7 +353,7 @@ def resample( return particles def get_particle_tree( - self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64] + self, particles: list[ParticleTree], normalized_weights: npt.NDArray ) -> tuple[ParticleTree, Tree]: """ Sample a new particle and associated tree @@ -359,7 +365,7 @@ def get_particle_tree( return new_particle, new_particle.tree - def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]: + def systematic(self, normalized_weights: npt.NDArray) -> npt.NDArray[np.int_]: """ Systematic resampling. @@ -395,7 +401,7 @@ def update_weight(self, particle: ParticleTree, odim: int) -> None: particle.log_weight = new_likelihood @staticmethod - def competence(var, has_grad): + def competence(var: pm.Distribution, has_grad: bool) -> Competence: """PGBART is only suitable for BART distributions.""" dist = getattr(var.owner, "op", None) if isinstance(dist, BARTRV): @@ -406,12 +412,12 @@ def competence(var, has_grad): class RunningSd: """Welford's online algorithm for computing the variance/standard deviation""" - def __init__(self, shape: tuple) -> None: + def __init__(self, shape: tuple[int, ...]) -> None: self.count = 0 # number of data points self.mean = np.zeros(shape) # running mean self.m_2 = np.zeros(shape) # running second moment - def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]: + def update(self, new_value: npt.NDArray) -> Union[float, npt.NDArray]: self.count = self.count + 1 self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value) return fast_mean(std) @@ -420,10 +426,10 @@ def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray @njit def _update( count: int, - mean: npt.NDArray[np.float64], - m_2: npt.NDArray[np.float64], - new_value: npt.NDArray[np.float64], -) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]: + mean: npt.NDArray, + m_2: npt.NDArray, + new_value: npt.NDArray, +) -> tuple[npt.NDArray, npt.NDArray, Union[float, npt.NDArray]]: delta = new_value - mean mean += delta / count delta2 = new_value - mean @@ -434,7 +440,7 @@ def _update( class SampleSplittingVariable: - def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None: + def __init__(self, alpha_vec: npt.NDArray) -> None: """ Sample splitting variables proportional to `alpha_vec`. @@ -547,16 +553,16 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d def draw_leaf_value( - y_mu_pred: npt.NDArray[np.float64], - x_mu: npt.NDArray[np.float64], + y_mu_pred: npt.NDArray, + x_mu: npt.NDArray, m: int, - norm: npt.NDArray[np.float64], + norm: npt.NDArray, shape: int, response: str, -) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]: +) -> tuple[npt.NDArray, Optional[npt.NDArray]]: """Draw Gaussian distributed leaf values.""" linear_params = None - mu_mean = np.empty(shape) + mu_mean: npt.NDArray if y_mu_pred.size == 0: return np.zeros(shape), linear_params @@ -571,7 +577,7 @@ def draw_leaf_value( @njit -def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]: +def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]: """Use Numba to speed up the computation of the mean.""" if ari.ndim == 1: count = ari.shape[0] @@ -590,11 +596,11 @@ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float @njit def fast_linear_fit( - x: npt.NDArray[np.float64], - y: npt.NDArray[np.float64], + x: npt.NDArray, + y: npt.NDArray, m: int, - norm: npt.NDArray[np.float64], -) -> tuple[npt.NDArray[np.float64], list[npt.NDArray[np.float64]]]: + norm: npt.NDArray, +) -> tuple[npt.NDArray, list[npt.NDArray]]: n = len(x) y = y / m + np.expand_dims(norm, axis=1) @@ -678,17 +684,17 @@ def update(self): @njit def inverse_cdf( - single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64] + single_uniform: npt.NDArray, normalized_weights: npt.NDArray ) -> npt.NDArray[np.int_]: """ Inverse CDF algorithm for a finite distribution. Parameters ---------- - single_uniform: npt.NDArray[np.float64] + single_uniform: npt.NDArray Ordered points in [0,1] - normalized_weights: npt.NDArray[np.float64]) + normalized_weights: npt.NDArray) Normalized weights Returns @@ -711,7 +717,7 @@ def inverse_cdf( @njit -def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]: +def jitter_duplicated(array: npt.NDArray, std: float) -> npt.NDArray: """ Jitter duplicated values. """ @@ -727,12 +733,17 @@ def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray @njit -def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_: +def are_whole_number(array: npt.NDArray) -> np.bool_: """Check if all values in array are whole numbers""" return np.all(np.mod(array[~np.isnan(array)], 1) == 0) -def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin +def logp( + point, + out_vars: list[pm.Distribution], + vars: list[pm.Distribution], + shared: list[pt.TensorVariable], +): """Compile PyTensor function of the model and the input and output variables. Parameters diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 7655175..61e5050 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -28,7 +28,7 @@ class Node: Attributes ---------- - value : npt.NDArray[np.float64] + value : npt.NDArray idx_data_points : Optional[npt.NDArray[np.int_]] idx_split_variable : int linear_params: Optional[list[float]] = None @@ -38,11 +38,11 @@ class Node: def __init__( self, - value: npt.NDArray[np.float64] = np.array([-1.0]), + value: npt.NDArray = np.array([-1.0]), nvalue: int = 0, idx_data_points: Optional[npt.NDArray[np.int_]] = None, idx_split_variable: int = -1, - linear_params: Optional[list[npt.NDArray[np.float64]]] = None, + linear_params: Optional[list[npt.NDArray]] = None, ) -> None: self.value = value self.nvalue = nvalue @@ -53,11 +53,11 @@ def __init__( @classmethod def new_leaf_node( cls, - value: npt.NDArray[np.float64], + value: npt.NDArray, nvalue: int = 0, idx_data_points: Optional[npt.NDArray[np.int_]] = None, idx_split_variable: int = -1, - linear_params: Optional[list[npt.NDArray[np.float64]]] = None, + linear_params: Optional[list[npt.NDArray]] = None, ) -> "Node": return cls( value=value, @@ -101,7 +101,7 @@ class Tree: The dictionary's keys are integers that represent the nodes position. The dictionary's values are objects of type Node that represent the split and leaf nodes of the tree itself. - output: Optional[npt.NDArray[np.float64]] + output: Optional[npt.NDArray] Array of shape number of observations, shape split_rules : list[SplitRule] List of SplitRule objects, one per column in input data. @@ -122,7 +122,7 @@ class Tree: def __init__( self, tree_structure: dict[int, Node], - output: npt.NDArray[np.float64], + output: npt.NDArray, split_rules: list[SplitRule], idx_leaf_nodes: Optional[list[int]] = None, ) -> None: @@ -134,7 +134,7 @@ def __init__( @classmethod def new_tree( cls, - leaf_node_value: npt.NDArray[np.float64], + leaf_node_value: npt.NDArray, idx_data_points: Optional[npt.NDArray[np.int_]], num_observations: int, shape: int, @@ -190,7 +190,7 @@ def grow_leaf_node( self, current_node: Node, selected_predictor: int, - split_value: npt.NDArray[np.float64], + split_value: npt.NDArray, index_leaf_node: int, ) -> None: current_node.value = split_value @@ -222,7 +222,7 @@ def get_split_variables(self) -> Generator[int, None, None]: if node.is_split_node(): yield node.idx_split_variable - def _predict(self) -> npt.NDArray[np.float64]: + def _predict(self) -> npt.NDArray: output = self.output if self.idx_leaf_nodes is not None: @@ -233,23 +233,23 @@ def _predict(self) -> npt.NDArray[np.float64]: def predict( self, - x: npt.NDArray[np.float64], + x: npt.NDArray, excluded: Optional[list[int]] = None, shape: int = 1, - ) -> npt.NDArray[np.float64]: + ) -> npt.NDArray: """ Predict output of tree for an (un)observed point x. Parameters ---------- - x : npt.NDArray[np.float64] + x : npt.NDArray Unobserved point excluded: Optional[list[int]] Indexes of the variables to exclude when computing predictions Returns ------- - npt.NDArray[np.float64] + npt.NDArray Value of the leaf value where the unobserved point lies. """ if excluded is None: @@ -259,16 +259,16 @@ def predict( def _traverse_tree( self, - X: npt.NDArray[np.float64], + X: npt.NDArray, excluded: Optional[list[int]] = None, shape: Union[int, tuple[int, ...]] = 1, - ) -> npt.NDArray[np.float64]: + ) -> npt.NDArray: """ Traverse the tree starting from the root node given an (un)observed point. Parameters ---------- - X : npt.NDArray[np.float64] + X : npt.NDArray (Un)observed point(s) node_index : int Index of the node to start the traversal from @@ -279,14 +279,16 @@ def _traverse_tree( Returns ------- - npt.NDArray[np.float64] + npt.NDArray Leaf node value or mean of leaf node values """ x_shape = (1,) if len(X.shape) == 1 else X.shape[:-1] nd_dims = (...,) + (None,) * len(x_shape) - stack = [(0, np.ones(x_shape), 0)] # (node_index, weight, idx_split_variable) initial state + stack: list[tuple[int, npt.NDArray, int]] = [ + (0, np.ones(x_shape), 0) + ] # (node_index, weight, idx_split_variable) initial state p_d = ( np.zeros(shape + x_shape) if isinstance(shape, tuple) else np.zeros((shape,) + x_shape) ) @@ -309,9 +311,19 @@ def _traverse_tree( ) if excluded is not None and idx_split_variable in excluded: prop_nvalue_left = self.get_node(left_node_index).nvalue / node.nvalue - stack.append((left_node_index, weights * prop_nvalue_left, idx_split_variable)) stack.append( - (right_node_index, weights * (1 - prop_nvalue_left), idx_split_variable) + ( + left_node_index, + weights * prop_nvalue_left, + idx_split_variable, + ) + ) + stack.append( + ( + right_node_index, + weights * (1 - prop_nvalue_left), + idx_split_variable, + ) ) else: to_left = ( @@ -328,14 +340,14 @@ def _traverse_tree( return p_d def _traverse_leaf_values( - self, leaf_values: list[npt.NDArray[np.float64]], leaf_n_values: list[int], node_index: int + self, leaf_values: list[npt.NDArray], leaf_n_values: list[int], node_index: int ) -> None: """ Traverse the tree appending leaf values starting from a particular node. Parameters ---------- - leaf_values : list[npt.NDArray[np.float64]] + leaf_values : list[npt.NDArray] node_index : int """ node = self.get_node(node_index) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index d9d5241..58d14b8 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -17,7 +17,7 @@ from .tree import Tree -TensorLike = Union[npt.NDArray[np.float64], pt.TensorVariable] +TensorLike = Union[npt.NDArray, pt.TensorVariable] def _sample_posterior( @@ -27,7 +27,7 @@ def _sample_posterior( size: Optional[Union[int, tuple[int, ...]]] = None, excluded: Optional[list[int]] = None, shape: int = 1, -) -> npt.NDArray[np.float64]: +) -> npt.NDArray: """ Generate samples from the BART-posterior. @@ -139,8 +139,8 @@ def plot_convergence( def plot_ice( bartrv: Variable, - X: npt.NDArray[np.float64], - Y: Optional[npt.NDArray[np.float64]] = None, + X: npt.NDArray, + Y: Optional[npt.NDArray] = None, var_idx: Optional[list[int]] = None, var_discrete: Optional[list[int]] = None, func: Optional[Callable] = None, @@ -165,9 +165,9 @@ def plot_ice( ---------- bartrv : BART Random Variable BART variable once the model that include it has been fitted. - X : npt.NDArray[np.float64] + X : npt.NDArray The covariate matrix. - Y : Optional[npt.NDArray[np.float64]], by default None. + Y : Optional[npt.NDArray], by default None. The response vector. var_idx : Optional[list[int]], by default None. List of the indices of the covariate for which to compute the pdp or ice. @@ -283,8 +283,8 @@ def identity(x): def plot_pdp( bartrv: Variable, - X: npt.NDArray[np.float64], - Y: Optional[npt.NDArray[np.float64]] = None, + X: npt.NDArray, + Y: Optional[npt.NDArray] = None, xs_interval: str = "quantiles", xs_values: Optional[Union[int, list[float]]] = None, var_idx: Optional[list[int]] = None, @@ -310,9 +310,9 @@ def plot_pdp( ---------- bartrv : BART Random Variable BART variable once the model that include it has been fitted. - X : npt.NDArray[np.float64] + X : npt.NDArray The covariate matrix. - Y : Optional[npt.NDArray[np.float64]], by default None. + Y : Optional[npt.NDArray], by default None. The response vector. xs_interval : str Method used to compute the values X used to evaluate the predicted function. "linear", @@ -526,14 +526,14 @@ def _get_axes(grid, n_plots, sharex, sharey, figsize): def _prepare_plot_data( - X: npt.NDArray[np.float64], - Y: Optional[npt.NDArray[np.float64]] = None, + X: npt.NDArray, + Y: Optional[npt.NDArray] = None, xs_interval: str = "quantiles", xs_values: Optional[Union[int, list[float]]] = None, var_idx: Optional[list[int]] = None, var_discrete: Optional[list[int]] = None, ) -> tuple[ - npt.NDArray[np.float64], + npt.NDArray, list[str], str, list[int], @@ -619,10 +619,10 @@ def _prepare_plot_data( def _create_pdp_data( - X: npt.NDArray[np.float64], + X: npt.NDArray, xs_interval: str, xs_values: Optional[Union[int, list[float]]] = None, -) -> npt.NDArray[np.float64]: +) -> npt.NDArray: """ Create data for partial dependence plot. @@ -637,7 +637,7 @@ def _create_pdp_data( Returns ------- - npt.NDArray[np.float64] + npt.NDArray A 2D array for the fake_X data. """ if xs_interval == "insample": @@ -654,8 +654,8 @@ def _create_pdp_data( def _smooth_mean( - new_x: npt.NDArray[np.float64], - p_di: npt.NDArray[np.float64], + new_x: npt.NDArray, + p_di: npt.NDArray, kind: str = "pdp", smooth_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[np.ndarray, np.ndarray]: @@ -701,7 +701,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non ---------- idata : InferenceData InferenceData containing a collection of BART_trees in sample_stats group - X : npt.NDArray[np.float64] + X : npt.NDArray The covariate matrix. labels : Optional[list[str]] List of the names of the covariates. If X is a DataFrame the names of the covariables will @@ -767,7 +767,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non def compute_variable_importance( # noqa: PLR0915 PLR0912 idata: az.InferenceData, bartrv: Variable, - X: npt.NDArray[np.float64], + X: npt.NDArray, method: str = "VI", fixed: int = 0, samples: int = 50, @@ -782,7 +782,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 InferenceData containing a collection of BART_trees in sample_stats group bartrv : BART Random Variable BART variable once the model that include it has been fitted. - X : npt.NDArray[np.float64] + X : npt.NDArray The covariate matrix. method : str Method used to rank variables. Available options are "VI" (default), "backward" @@ -826,9 +826,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 else: labels = np.arange(n_vars).astype(str) - r2_mean = np.zeros(n_vars) - r2_hdi = np.zeros((n_vars, 2)) - preds = np.zeros((n_vars, samples, *bartrv.eval().T.shape)) + r2_mean: npt.NDArray = np.zeros(n_vars) + r2_hdi: npt.NDArray = np.zeros((n_vars, 2)) + preds: npt.NDArray = np.zeros((n_vars, samples, *bartrv.eval().T.shape)) if method == "backward_VI": if fixed >= n_vars: @@ -848,7 +848,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 idxs = np.argsort( idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values ) - subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))] + subsets: list[list[int]] = [list(idxs[:-i]) for i in range(1, len(idxs))] subsets.append(None) # type: ignore if method == "backward_VI": diff --git a/pyproject.toml b/pyproject.toml index f8f3e7a..4a2273d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,3 +33,20 @@ exclude_lines = [ isort = 1 black = 1 pyupgrade = 1 + + +[tool.mypy] +files = "pymc_bart/*.py" +plugins = "numpy.typing.mypy_plugin" + +[tool.mypy-matplotlib] +ignore_missing_imports = true + +[tool.mypy-numba] +ignore_missing_imports = true + +[tool.mypy-pymc] +ignore_missing_imports = true + +[tool.mypy-scipy] +ignore_missing_imports = true From 139aeacc360914ef21cf42db2534a4512437198c Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Sat, 28 Dec 2024 15:17:47 -0300 Subject: [PATCH 22/53] Automatic Changelog (#213) * automatic changelog * add changelog to docs --- .github/workflows/post-release.yml | 19 +++++++++++++++++++ CHANGELOG.md | 0 docs/changelog.rst | 5 +++++ docs/index.rst | 12 +++++++----- 4 files changed, 31 insertions(+), 5 deletions(-) create mode 100644 .github/workflows/post-release.yml create mode 100644 CHANGELOG.md create mode 100644 docs/changelog.rst diff --git a/.github/workflows/post-release.yml b/.github/workflows/post-release.yml new file mode 100644 index 0000000..5526a27 --- /dev/null +++ b/.github/workflows/post-release.yml @@ -0,0 +1,19 @@ +name: Post-release +on: + release: + types: [published, released] + workflow_dispatch: + +jobs: + changelog: + name: Update changelog + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: main + - uses: rhysd/changelog-from-release/action@v3 + with: + file: CHANGELOG.md + github_token: ${{ secrets.GITHUB_TOKEN }} + commit_summary_template: 'update changelog for %s changes' diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/changelog.rst b/docs/changelog.rst new file mode 100644 index 0000000..f83d445 --- /dev/null +++ b/docs/changelog.rst @@ -0,0 +1,5 @@ +Changelog +********* + +.. include:: ../CHANGELOG.md + :parser: myst_parser.sphinx_ diff --git a/docs/index.rst b/docs/index.rst index 4b1dd0e..c73500c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -93,10 +93,12 @@ Contents :maxdepth: 2 examples - api_reference -Indices -======= +References +========== + +.. toctree:: + :maxdepth: 1 -* :ref:`genindex` -* :ref:`modindex` + api_reference + changelog From 3bad2c68df2e766ff65091ad64d1fe8472eca143 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Sat, 28 Dec 2024 18:37:17 -0300 Subject: [PATCH 23/53] Update index.rst --- docs/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index c73500c..78a59fb 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -29,7 +29,7 @@ interpretation of those models and perform variable selection. Installation ============ -PyMC-BART requires a working Python interpreter (3.8+). We recommend installing Python and key numerical libraries using the `Anaconda distribution `_, which has one-click installers available on all major platforms. +PyMC-BART requires a working Python interpreter (3.10+). We recommend installing Python and key numerical libraries using the `Anaconda distribution `_, which has one-click installers available on all major platforms. Assuming a standard Python environment is installed on your machine, PyMC-BART itself can be installed either using pip or conda-forge. From 064457e34d3041bc3886b66a2707b94f5554aac4 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Sun, 29 Dec 2024 08:11:29 -0300 Subject: [PATCH 24/53] Adds get_variable_inclusion function (#214) * add get_variable_inclusion function * add elements to API reference --- docs/api_reference.rst | 2 +- pymc_bart/__init__.py | 2 ++ pymc_bart/utils.py | 68 +++++++++++++++++++++++++++++++----------- 3 files changed, 53 insertions(+), 19 deletions(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 93afde1..b6fb8a5 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -13,4 +13,4 @@ methods in the current release of PyMC-BART. ============================= .. automodule:: pymc_bart - :members: BART, PGBART, plot_pdp, plot_ice, plot_variable_importance, plot_convergence, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule + :members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_convergence, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index 361be83..f4a1f7a 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -18,6 +18,7 @@ from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule from pymc_bart.utils import ( compute_variable_importance, + get_variable_inclusion, plot_convergence, plot_ice, plot_pdp, @@ -33,6 +34,7 @@ "OneHotSplitRule", "SubsetSplitRule", "compute_variable_importance", + "get_variable_inclusion", "plot_convergence", "plot_ice", "plot_pdp", diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 58d14b8..df8f76f 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -693,6 +693,50 @@ def _smooth_mean( return x_data, y_data +def get_variable_inclusion(idata, X, labels=None, to_kulprit=False): + """ + Get the normalized variable inclusion from BART model. + + Parameters + ---------- + idata : InferenceData + InferenceData containing a collection of BART_trees in sample_stats group + X : npt.NDArray + The covariate matrix. + labels : Optional[list[str]] + List of the names of the covariates. If X is a DataFrame the names of the covariables will + be taken from it and this argument will be ignored. + to_kulprit : bool + If True, the function will return a list of list with the variables names. + This list can be passed as a path to Kulprit's project method. Defaults to False. + Returns + ------- + VI_norm : npt.NDArray + Normalized variable inclusion. + labels : list[str] + List of the names of the covariates. + """ + VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values + VI_norm = VIs / VIs.sum() + idxs = np.argsort(VI_norm) + + indices = idxs[::-1] + n_vars = len(indices) + + if hasattr(X, "columns") and hasattr(X, "to_numpy"): + labels = X.columns + + if labels is None: + labels = np.arange(n_vars).astype(str) + + label_list = labels.to_list() + + if to_kulprit: + return [label_list[:idx] for idx in range(n_vars)] + else: + return VI_norm[indices], label_list + + def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=None, ax=None): """ Plot normalized variable inclusion from BART model. @@ -720,26 +764,15 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non Returns ------- - idxs: indexes of the covariates from higher to lower relative importance axes: matplotlib axes """ if plot_kwargs is None: plot_kwargs = {} - VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values - VIs = VIs / VIs.sum() - idxs = np.argsort(VIs) - - indices = idxs[::-1] - n_vars = len(indices) - - if hasattr(X, "columns") and hasattr(X, "to_numpy"): - labels = X.columns + VI_norm, labels = get_variable_inclusion(idata, X, labels) + n_vars = len(labels) - if labels is None: - labels = np.arange(n_vars).astype(str) - - new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)] ticks = np.arange(n_vars, dtype=int) @@ -749,19 +782,18 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non if ax is None: _, ax = plt.subplots(1, 1, figsize=figsize) + ax.axhline(1 / n_vars, color="0.5", linestyle="--") ax.plot( - VIs[indices], + VI_norm, color=plot_kwargs.get("color", "k"), marker=plot_kwargs.get("marker", "o"), ls=plot_kwargs.get("ls", "-"), ) ax.set_xticks(ticks, new_labels, rotation=plot_kwargs.get("rotation", 0)) - - ax.axhline(1 / n_vars, color="0.5", linestyle="--") ax.set_ylim(0, 1) - return idxs, ax + return ax def compute_variable_importance( # noqa: PLR0915 PLR0912 From cd5dfbe4e09e3e450b384eacbc2d3292734ea9e7 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Sun, 29 Dec 2024 08:11:49 -0300 Subject: [PATCH 25/53] refactor rng_fn method (#212) --- pymc_bart/bart.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index ac2be35..5114b6e 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -55,12 +55,12 @@ def rng_fn( # pylint: disable=W0237 if not size: size = None - if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)): - Y = cls.Y.eval() - else: - Y = cls.Y - if not cls.all_trees: + if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)): + Y = cls.Y.eval() + else: + Y = cls.Y + if size is not None: return np.full((size[0], Y.shape[0]), Y.mean()) else: From 0d4d6f55a077f05fe93c5973d9512e4cafa374f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 12:18:59 +0100 Subject: [PATCH 26/53] [pre-commit.ci] pre-commit autoupdate (#215) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.8.4 → v0.8.6](https://github.com/astral-sh/ruff-pre-commit/compare/v0.8.4...v0.8.6) - [github.com/pre-commit/mirrors-mypy: v1.14.0 → v1.14.1](https://github.com/pre-commit/mirrors-mypy/compare/v1.14.0...v1.14.1) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4f55bc1..8a5992a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,14 +12,14 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.4 + rev: v0.8.6 hooks: - id: ruff args: ["--fix", "--output-format=full"] - id: ruff-format args: ["--line-length=100"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.14.0 + rev: v1.14.1 hooks: - id: mypy args: [--ignore-missing-imports] From 44c787cc7b2a8473ca6c1f9fb62171004a27167f Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Tue, 11 Feb 2025 07:54:38 +0100 Subject: [PATCH 27/53] Fix docs by adding path of config (#217) * pre-commit update * add conf.py path --- .pre-commit-config.yaml | 4 ++-- .readthedocs.yaml | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8a5992a..1bc3739 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,14 +12,14 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.6 + rev: v0.9.6 hooks: - id: ruff args: ["--fix", "--output-format=full"] - id: ruff-format args: ["--line-length=100"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.14.1 + rev: v1.15.0 hooks: - id: mypy args: [--ignore-missing-imports] diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 6e5cef0..0ce9313 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,6 +1,9 @@ # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details version: 2 +sphinx: + # Path to your Sphinx configuration file. + configuration: docs/conf.py build: os: ubuntu-20.04 From 16a78df60b874005ee3b6dd06a0d2a2e892f0946 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Mar 2025 10:38:37 +0200 Subject: [PATCH 28/53] [pre-commit.ci] pre-commit autoupdate (#219) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.9.6 → v0.9.9](https://github.com/astral-sh/ruff-pre-commit/compare/v0.9.6...v0.9.9) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1bc3739..6a3b804 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.6 + rev: v0.9.9 hooks: - id: ruff args: ["--fix", "--output-format=full"] From 73813e1308163d0571676600d1a2af2b45719592 Mon Sep 17 00:00:00 2001 From: Alexandre Andorra Date: Mon, 10 Mar 2025 03:23:11 -0400 Subject: [PATCH 29/53] Enhance `plot_pdp` and fix `plot_scatter_submodels` (#218) * Add YML env files * Expand scatter_submodels to categorical likelihood * Add softmax option to plot_pdp * Remove comments * Use func for softmax * handle func upstream * move func upstream * ensure p_d is an array --------- Co-authored-by: aloctavodia --- env-dev.yml | 23 +++++++++++ env.yml | 14 +++++++ pymc_bart/utils.py | 96 +++++++++++++++++++++++++++++++--------------- 3 files changed, 102 insertions(+), 31 deletions(-) create mode 100644 env-dev.yml create mode 100644 env.yml diff --git a/env-dev.yml b/env-dev.yml new file mode 100644 index 0000000..1e28429 --- /dev/null +++ b/env-dev.yml @@ -0,0 +1,23 @@ +name: pymc-bart-dev +channels: + - conda-forge + - defaults +dependencies: + - pymc>=5.16.2,<=5.19.1 + - arviz>=0.18.0 + - numba + - matplotlib + - numpy + - pytensor + # Development dependencies + - pytest>=4.4.0 + - pytest-cov>=2.6.1 + - click==8.0.4 + - pylint==2.17.4 + - pre-commit + - black + - isort + - flake8 + - pip + - pip: + - -e . diff --git a/env.yml b/env.yml new file mode 100644 index 0000000..bd814ae --- /dev/null +++ b/env.yml @@ -0,0 +1,14 @@ +name: pymc-bart +channels: + - conda-forge + - defaults +dependencies: + - pymc>=5.16.2,<=5.19.1 + - arviz>=0.18.0 + - numba + - matplotlib + - numpy + - pytensor + - pip + - pip: + - pymc-bart diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index df8f76f..3ba6e58 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -254,13 +254,13 @@ def identity(x): ) new_x = fake_X[:, var] - p_d = np.array(y_pred) + p_d = func(np.array(y_pred)) for s_i in range(shape): if centered: - p_di = func(p_d[:, :, s_i]) - func(p_d[:, :, s_i][:, 0][:, None]) + p_di = p_d[:, :, s_i] - p_d[:, :, s_i][:, 0][:, None] else: - p_di = func(p_d[:, :, s_i]) + p_di = p_d[:, :, s_i] if var in var_discrete: axes[count].plot(new_x, p_di.mean(0), "o", color=color_mean) axes[count].plot(new_x, p_di.T, ".", color=color, alpha=alpha) @@ -393,14 +393,17 @@ def identity(x): for var in range(len(var_idx)): excluded = indices[:] excluded.remove(var) - p_d = _sample_posterior( - all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape + p_d = func( + _sample_posterior( + all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape + ) ) + with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="hdi currently interprets 2d data") new_x = fake_X[:, var] for s_i in range(shape): - p_di = func(p_d[:, :, s_i]) + p_di = p_d[:, :, s_i] null_pd.append(p_di.mean()) if var in var_discrete: _, idx_uni = np.unique(new_x, return_index=True) @@ -1125,8 +1128,11 @@ def plot_scatter_submodels( plot_kwargs : dict Additional keyword arguments for the plot. Defaults to None. Valid keys are: - - color_ref: matplotlib valid color for the 45 degree line + - marker_scatter: matplotlib valid marker for the scatter plot - color_scatter: matplotlib valid color for the scatter plot + - alpha_scatter: matplotlib valid alpha for the scatter plot + - color_ref: matplotlib valid color for the 45 degree line + - ls_ref: matplotlib valid linestyle for the reference line axes : axes Matplotlib axes. @@ -1140,41 +1146,69 @@ def plot_scatter_submodels( submodels = np.sort(submodels) indices = vi_results["indices"][submodels] - preds = vi_results["preds"][submodels] + preds_sub = vi_results["preds"][submodels] preds_all = vi_results["preds_all"] + if labels is None: + labels = vi_results["labels"][submodels] + + # handle categorical regression case: + n_cats = None + if preds_all.ndim > 2: + n_cats = preds_all.shape[-1] + indices = np.tile(indices, n_cats) + if ax is None: _, ax = _get_axes(grid, len(indices), True, True, figsize) if plot_kwargs is None: plot_kwargs = {} - if labels is None: - labels = vi_results["labels"][submodels] - if func is not None: - preds = func(preds) + preds_sub = func(preds_sub) preds_all = func(preds_all) - min_ = min(np.min(preds), np.min(preds_all)) - max_ = max(np.max(preds), np.max(preds_all)) - - for pred, x_label, axi in zip(preds, labels, ax.ravel()): - axi.plot( - pred, - preds_all, - marker=plot_kwargs.get("marker_scatter", "."), - ls="", - color=plot_kwargs.get("color_scatter", "C0"), - alpha=plot_kwargs.get("alpha_scatter", 0.1), - ) - axi.set_xlabel(x_label) - axi.axline( - [min_, min_], - [max_, max_], - color=plot_kwargs.get("color_ref", "0.5"), - ls=plot_kwargs.get("ls_ref", "--"), - ) + min_ = min(np.min(preds_sub), np.min(preds_all)) + max_ = max(np.max(preds_sub), np.max(preds_all)) + + # handle categorical regression case: + if n_cats is not None: + i = 0 + for cat in range(n_cats): + for pred_sub, x_label in zip(preds_sub, labels): + ax[i].plot( + pred_sub[..., cat], + preds_all[..., cat], + marker=plot_kwargs.get("marker_scatter", "."), + ls="", + color=plot_kwargs.get("color_scatter", f"C{cat}"), + alpha=plot_kwargs.get("alpha_scatter", 0.1), + ) + ax[i].set(xlabel=x_label, ylabel="ref model", title=f"Category {cat}") + ax[i].axline( + [min_, min_], + [max_, max_], + color=plot_kwargs.get("color_ref", "0.5"), + ls=plot_kwargs.get("ls_ref", "--"), + ) + i += 1 + else: + for pred_sub, x_label, axi in zip(preds_sub, labels, ax.ravel()): + axi.plot( + pred_sub, + preds_all, + marker=plot_kwargs.get("marker_scatter", "."), + ls="", + color=plot_kwargs.get("color_scatter", "C0"), + alpha=plot_kwargs.get("alpha_scatter", 0.1), + ) + axi.set(xlabel=x_label, ylabel="ref model") + axi.axline( + [min_, min_], + [max_, max_], + color=plot_kwargs.get("color_ref", "0.5"), + ls=plot_kwargs.get("ls_ref", "--"), + ) return ax From 7986e2325aa4f20b558d5c223b685f0fef6bb986 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Mon, 10 Mar 2025 09:38:58 +0200 Subject: [PATCH 30/53] bump release (#220) --- pymc_bart/__init__.py | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index f4a1f7a..ed1a29a 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -42,7 +42,7 @@ "plot_variable_importance", "plot_variable_inclusion", ] -__version__ = "0.8.2" +__version__ = "0.9.0" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] diff --git a/requirements.txt b/requirements.txt index da634d4..785de62 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pymc>=5.16.2, <=5.19.1 +pymc>=5.16.2, <=5.20.1 arviz>=0.18.0 numba matplotlib From 6b756c91a64fdec1d694aba37480dbadf9c63f90 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 10 Mar 2025 07:40:45 +0000 Subject: [PATCH 31/53] update changelog for 0.9.0 changes This commit was created by changelog-from-release in 'Post-release' CI workflow --- CHANGELOG.md | 449 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 449 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e69de29..6577689 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -0,0 +1,449 @@ + +# [0.9.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.9.0) - 2025-03-10 + +## What's Changed +* Update MyPy 14 by [@juanitorduz](https://github.com/juanitorduz) in [#210](https://github.com/pymc-devs/pymc-bart/pull/210) +* Automatic Changelog by [@aloctavodia](https://github.com/aloctavodia) in [#213](https://github.com/pymc-devs/pymc-bart/pull/213) +* Adds get_variable_inclusion function by [@aloctavodia](https://github.com/aloctavodia) in [#214](https://github.com/pymc-devs/pymc-bart/pull/214) +* Refactor rng_fn method by [@aloctavodia](https://github.com/aloctavodia) in [#212](https://github.com/pymc-devs/pymc-bart/pull/212) +* Fix docs by adding path of config by [@juanitorduz](https://github.com/juanitorduz) in [#217](https://github.com/pymc-devs/pymc-bart/pull/217) +* Enhance `plot_pdp` and fix `plot_scatter_submodels` by [@AlexAndorra](https://github.com/AlexAndorra) in [#218](https://github.com/pymc-devs/pymc-bart/pull/218) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.8.2...0.9.0 + +[Changes][0.9.0] + + + +# [0.8.2](https://github.com/pymc-devs/pymc-bart/releases/tag/0.8.2) - 2024-12-23 + +## What's Changed +* Compute_variable_importance: fix bug with non-default shapes by [@aloctavodia](https://github.com/aloctavodia) in [#208](https://github.com/pymc-devs/pymc-bart/pull/208) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.8.1...0.8.2 + +[Changes][0.8.2] + + + +# [0.8.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.8.1) - 2024-12-20 + +## What's Changed +* Patch for case when Y is a TensorVariable by [@AlexAndorra](https://github.com/AlexAndorra) in [#206](https://github.com/pymc-devs/pymc-bart/pull/206) +* Fix bug with labels in variable importance, add reference line, remove deprecation warning by [@aloctavodia](https://github.com/aloctavodia) in [#207](https://github.com/pymc-devs/pymc-bart/pull/207) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.8.0...0.8.1 + +[Changes][0.8.1] + + + +# [0.8.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.8.0) - 2024-12-17 + +## What's Changed + +* Add new vi plots by [@aloctavodia](https://github.com/aloctavodia) in [#196](https://github.com/pymc-devs/pymc-bart/pull/196) +* Allows plotting a subset of the variables once the variable's importance has been computed by [@aloctavodia](https://github.com/aloctavodia) in [#200](https://github.com/pymc-devs/pymc-bart/pull/200) +* Enable passing `Y` as a `SharedVariable` to `pm.Bart` by [@AlexAndorra](https://github.com/AlexAndorra) in [#202](https://github.com/pymc-devs/pymc-bart/pull/202) +* Improve docs, aesthetics and functionality by [@aloctavodia](https://github.com/aloctavodia) in [#198](https://github.com/pymc-devs/pymc-bart/pull/198) + + +## New Contributors +* [@AlexAndorra](https://github.com/AlexAndorra) made their first contribution in [#202](https://github.com/pymc-devs/pymc-bart/pull/202) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.7.1...0.8.0 + +[Changes][0.8.0] + + + +# [0.7.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.7.1) - 2024-11-07 + +## What's Changed +* Conform to recent changes in pymc by [@aloctavodia](https://github.com/aloctavodia) in [#194](https://github.com/pymc-devs/pymc-bart/pull/194) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.7.0...0.7.1 + +[Changes][0.7.1] + + + +# [0.7.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.7.0) - 2024-09-05 + +## What's Changed +* Allow Y to be a tensor by [@aloctavodia](https://github.com/aloctavodia) in [#180](https://github.com/pymc-devs/pymc-bart/pull/180) +* improve plot_variable_importance by [@aloctavodia](https://github.com/aloctavodia) in [#182](https://github.com/pymc-devs/pymc-bart/pull/182) +* move x_angle to plot_kwargs by [@aloctavodia](https://github.com/aloctavodia) in [#185](https://github.com/pymc-devs/pymc-bart/pull/185) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.6.0...0.7.0 + +[Changes][0.7.0] + + + +# [0.6.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.6.0) - 2024-08-16 + +## What's Changed +* Add categorical example by [@PabloGGaray](https://github.com/PabloGGaray) in [#167](https://github.com/pymc-devs/pymc-bart/pull/167) +* Fix np.float_ type by [@juanitorduz](https://github.com/juanitorduz) in [#171](https://github.com/pymc-devs/pymc-bart/pull/171) +* Support Polars by [@aloctavodia](https://github.com/aloctavodia) in [#179](https://github.com/pymc-devs/pymc-bart/pull/179) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.14...0.6.0 + +[Changes][0.6.0] + + + +# [0.5.14](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.14) - 2024-05-14 + +## What's Changed +* Less than equal PyMC Version by [@juanitorduz](https://github.com/juanitorduz) in [#164](https://github.com/pymc-devs/pymc-bart/pull/164) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.13...0.5.14 + +[Changes][0.5.14] + + + +# [0.5.13](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.13) - 2024-05-13 + +## What's Changed +* Update pymc version requirements.txt by [@juanitorduz](https://github.com/juanitorduz) in [#163](https://github.com/pymc-devs/pymc-bart/pull/163) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.12...0.5.13 + +[Changes][0.5.13] + + + +# [0.5.12](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.12) - 2024-04-18 + +## What's Changed +* Unpin numpy by [@maresb](https://github.com/maresb) in [#156](https://github.com/pymc-devs/pymc-bart/pull/156) +* Resolve deprecation warning for `pytensor`'s `Variable` by [@RyanAugust](https://github.com/RyanAugust) in [#159](https://github.com/pymc-devs/pymc-bart/pull/159) + +## New Contributors +* [@RyanAugust](https://github.com/RyanAugust) made their first contribution in [#159](https://github.com/pymc-devs/pymc-bart/pull/159) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.11...0.5.12 + +[Changes][0.5.12] + + + +# [0.5.11](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.11) - 2024-03-15 + +## What's Changed +* Add citation file by [@PabloGGaray](https://github.com/PabloGGaray) in [#151](https://github.com/pymc-devs/pymc-bart/pull/151) +* Rename moment to support_point by [@PabloGGaray](https://github.com/PabloGGaray) in [#154](https://github.com/pymc-devs/pymc-bart/pull/154) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.10...0.5.11 + +[Changes][0.5.11] + + + +# [0.5.10](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.10) - 2024-03-14 + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.8...0.5.10 + +[Changes][0.5.10] + + + +# [0.5.9](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.9) - 2024-03-14 + +## What's Changed +* Ruff linter + pre-commit integration by [@juanitorduz](https://github.com/juanitorduz) in [#140](https://github.com/pymc-devs/pymc-bart/pull/140) +* Improve CONTRIBUTING guidelines by [@juanitorduz](https://github.com/juanitorduz) in [#141](https://github.com/pymc-devs/pymc-bart/pull/141) +* Add Usage and Table of Contents, to the README file, enhance Installation section, and fix top header by [@NicholasLindner](https://github.com/NicholasLindner) in [#143](https://github.com/pymc-devs/pymc-bart/pull/143) + +## New Contributors +* [@NicholasLindner](https://github.com/NicholasLindner) made their first contribution in [#143](https://github.com/pymc-devs/pymc-bart/pull/143) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.7...0.5.9 + +[Changes][0.5.9] + + + +# [0.5.8](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.8) - 2024-03-14 + +## What's Changed +* Ruff linter + pre-commit integration by [@juanitorduz](https://github.com/juanitorduz) in [#140](https://github.com/pymc-devs/pymc-bart/pull/140) +* Improve CONTRIBUTING guidelines by [@juanitorduz](https://github.com/juanitorduz) in [#141](https://github.com/pymc-devs/pymc-bart/pull/141) +* Add Usage and Table of Contents, to the README file, enhance Installation section, and fix top header by [@NicholasLindner](https://github.com/NicholasLindner) in [#143](https://github.com/pymc-devs/pymc-bart/pull/143) + + +## New Contributors +* [@NicholasLindner](https://github.com/NicholasLindner) made their first contribution in [#143](https://github.com/pymc-devs/pymc-bart/pull/143) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.7...0.5.8 + +[Changes][0.5.8] + + + +# [0.5.7](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.7) - 2023-12-29 + +## What's Changed +* Properly handle nans when jittering by [@aloctavodia](https://github.com/aloctavodia) in [#136](https://github.com/pymc-devs/pymc-bart/pull/136) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.6...0.5.7 + +[Changes][0.5.7] + + + +# [0.5.6](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.6) - 2023-12-23 + +## What's Changed +* Fix bug in plot_ice, and clean docstring of plot_ice and plot_pdp by [@aloctavodia](https://github.com/aloctavodia) in [#135](https://github.com/pymc-devs/pymc-bart/pull/135) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.5...0.5.6 + +[Changes][0.5.6] + + + +# [0.5.5](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.5) - 2023-12-22 + +## What's Changed +* add jitter to duplicated values for continuous splitting rule by [@aloctavodia](https://github.com/aloctavodia) in [#129](https://github.com/pymc-devs/pymc-bart/pull/129) +* link GitHub icon to pymc-bart repo by [@aloctavodia](https://github.com/aloctavodia) in [#131](https://github.com/pymc-devs/pymc-bart/pull/131) +* VI remove unnecessary evaluations for the backward method by [@aloctavodia](https://github.com/aloctavodia) in [#132](https://github.com/pymc-devs/pymc-bart/pull/132) +* jitter only arrays of whole numbers by [@aloctavodia](https://github.com/aloctavodia) in [#133](https://github.com/pymc-devs/pymc-bart/pull/133) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.3...0.5.5 + +[Changes][0.5.5] + + + +# [0.5.4](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.4) - 2023-11-21 + +## What's Changed +* add jitter to duplicated values for continuous splitting rule by [@aloctavodia](https://github.com/aloctavodia) in [#129](https://github.com/pymc-devs/pymc-bart/pull/129) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.3...0.5.4 + +[Changes][0.5.4] + + + +# [0.5.3](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.3) - 2023-11-18 + +## What's Changed +* improve variable importance computation by adding backward method by [@aloctavodia](https://github.com/aloctavodia) in [#125](https://github.com/pymc-devs/pymc-bart/pull/125) +* set new paths to notebooks by [@aloctavodia](https://github.com/aloctavodia) in [#126](https://github.com/pymc-devs/pymc-bart/pull/126) +* fix case examples by [@aloctavodia](https://github.com/aloctavodia) in [#127](https://github.com/pymc-devs/pymc-bart/pull/127) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.2...0.5.3 + +[Changes][0.5.3] + + + +# [0.5.2](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.2) - 2023-10-27 + +## What's Changed +* Minor doctrings and types improvements by [@juanitorduz](https://github.com/juanitorduz) in [#108](https://github.com/pymc-devs/pymc-bart/pull/108) +* Fix ICE plot when there is a discrete variable by [@juanitorduz](https://github.com/juanitorduz) in [#107](https://github.com/pymc-devs/pymc-bart/pull/107) +* Add support python 3.11 by [@juanitorduz](https://github.com/juanitorduz) in [#109](https://github.com/pymc-devs/pymc-bart/pull/109) +* Add issue templates by [@PabloGGaray](https://github.com/PabloGGaray) in [#113](https://github.com/pymc-devs/pymc-bart/pull/113) +* Add conda option by [@PabloGGaray](https://github.com/PabloGGaray) in [#114](https://github.com/pymc-devs/pymc-bart/pull/114) +* fix split_prior bug by [@aloctavodia](https://github.com/aloctavodia) in [#115](https://github.com/pymc-devs/pymc-bart/pull/115) +* Add logo by [@aloctavodia](https://github.com/aloctavodia) in [#116](https://github.com/pymc-devs/pymc-bart/pull/116) +* clean logo by [@aloctavodia](https://github.com/aloctavodia) in [#117](https://github.com/pymc-devs/pymc-bart/pull/117) +* Add plot_ice to API description on the webpage by [@PabloGGaray](https://github.com/PabloGGaray) in [#119](https://github.com/pymc-devs/pymc-bart/pull/119) +* Better handling of discrete variables and other minor fixes by [@aloctavodia](https://github.com/aloctavodia) in [#121](https://github.com/pymc-devs/pymc-bart/pull/121) + + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.0...0.5.2 + +[Changes][0.5.2] + + + +# [O.5.1](https://github.com/pymc-devs/pymc-bart/releases/tag/O.5.1) - 2023-07-12 + +## What's Changed +* Minor doctrings and types improvements by [@juanitorduz](https://github.com/juanitorduz) in [#108](https://github.com/pymc-devs/pymc-bart/pull/108) +* Fix ICE plot when there is a discrete variable by [@juanitorduz](https://github.com/juanitorduz) in [#107](https://github.com/pymc-devs/pymc-bart/pull/107) +* Add support python 3.11 by [@juanitorduz](https://github.com/juanitorduz) in [#109](https://github.com/pymc-devs/pymc-bart/pull/109) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.0...O.5.1 + +[Changes][O.5.1] + + + +# [0.5.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.0) - 2023-07-10 + +## What's Changed +* Add pre-commit hooks by [@juanitorduz](https://github.com/juanitorduz) in [#75](https://github.com/pymc-devs/pymc-bart/pull/75) +* Add mypy init by [@juanitorduz](https://github.com/juanitorduz) in [#78](https://github.com/pymc-devs/pymc-bart/pull/78) +* Do not store index at each node. by [@howsiyu](https://github.com/howsiyu) in [#80](https://github.com/pymc-devs/pymc-bart/pull/80) +* Add linear response [@juanitorduz](https://github.com/juanitorduz) in [#79](https://github.com/pymc-devs/pymc-bart/pull/79) +* Do weighted mean when pruning by [@aloctavodia](https://github.com/aloctavodia) in [#83](https://github.com/pymc-devs/pymc-bart/pull/83) +* Implement fast version of pdp by [@aloctavodia](https://github.com/aloctavodia) in [#85](https://github.com/pymc-devs/pymc-bart/pull/85) +* Add error bars to variable importance by [@aloctavodia](https://github.com/aloctavodia) in [#90](https://github.com/pymc-devs/pymc-bart/pull/90) +* Compute running variance for leaf nodes by [@aloctavodia](https://github.com/aloctavodia) in [#91](https://github.com/pymc-devs/pymc-bart/pull/91) +* Improve doc style and add missing examples by [@aloctavodia](https://github.com/aloctavodia) in [#92](https://github.com/pymc-devs/pymc-bart/pull/92) +* Make the Repo more welcoming with a clear title by [@juanitorduz](https://github.com/juanitorduz) in [#94](https://github.com/pymc-devs/pymc-bart/pull/94) +* Improve docstrings new alpha and beta parameters by [@juanitorduz](https://github.com/juanitorduz) in [#95](https://github.com/pymc-devs/pymc-bart/pull/95) +* Allow different splitting rules by [@velochy](https://github.com/velochy) in [#96](https://github.com/pymc-devs/pymc-bart/pull/96) +* Allow training separate tree structures if training multiple trees by [@velochy](https://github.com/velochy) in [#98](https://github.com/pymc-devs/pymc-bart/pull/98) + +## New Contributors +* [@howsiyu](https://github.com/howsiyu) made their first contribution in [#80](https://github.com/pymc-devs/pymc-bart/pull/80) +* [@velochy](https://github.com/velochy) made their first contribution in [#96](https://github.com/pymc-devs/pymc-bart/pull/96) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.4.0...0.5.0 + +[Changes][0.5.0] + + + +# [0.4.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.4.0) - 2023-03-17 + +## What's Changed +* fig bug systematic resampling and add func argument by [@aloctavodia](https://github.com/aloctavodia) in [#61](https://github.com/pymc-devs/pymc-bart/pull/61) and [#66](https://github.com/pymc-devs/pymc-bart/pull/66) +* add tests for individual functions/methods in PGBART by [@aloctavodia](https://github.com/aloctavodia) in [#64](https://github.com/pymc-devs/pymc-bart/pull/64) +* Modify resampling schema and refactor by [@aloctavodia](https://github.com/aloctavodia) in [#65](https://github.com/pymc-devs/pymc-bart/pull/65) +* add plot_convergence by [@aloctavodia](https://github.com/aloctavodia) in [#67](https://github.com/pymc-devs/pymc-bart/pull/67) and [@aloctavodia](https://github.com/aloctavodia) in [#68](https://github.com/pymc-devs/pymc-bart/pull/68) +* Improve plot_dependence by [@PabloGGaray](https://github.com/PabloGGaray) in [#70](https://github.com/pymc-devs/pymc-bart/pull/70) and [@aloctavodia](https://github.com/aloctavodia) in [#71](https://github.com/pymc-devs/pymc-bart/pull/71) and in [#73](https://github.com/pymc-devs/pymc-bart/pull/73) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.3.2...0.4.0 + +[Changes][0.4.0] + + + +# [0.3.2](https://github.com/pymc-devs/pymc-bart/releases/tag/0.3.2) - 2023-02-03 + +## What's Changed +* Refactor and [@njit](https://github.com/njit) on methods by [@fjloyola](https://github.com/fjloyola) in [#54](https://github.com/pymc-devs/pymc-bart/pull/54) +* Fix shape error [@aloctavodia](https://github.com/aloctavodia) in [#57](https://github.com/pymc-devs/pymc-bart/pull/57) and [#59](https://github.com/pymc-devs/pymc-bart/pull/59) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.3.1...0.3.2 + +[Changes][0.3.2] + + + +# [0.3.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.3.1) - 2023-01-26 + +## What's Changed +* Fix Url pymc-bart on documentation by [@fjloyola](https://github.com/fjloyola) in [#34](https://github.com/pymc-devs/pymc-bart/pull/34) +* Fixing issue ThemeError for read the docs by [@fjloyola](https://github.com/fjloyola) in [#37](https://github.com/pymc-devs/pymc-bart/pull/37) +* Refactor to avoid inheritance in BaseNode by [@fjloyola](https://github.com/fjloyola) in [#35](https://github.com/pymc-devs/pymc-bart/pull/35) +* Add link to license by [@PabloGGaray](https://github.com/PabloGGaray) in [#39](https://github.com/pymc-devs/pymc-bart/pull/39) +* Improvements over Tree implementation by [@fjloyola](https://github.com/fjloyola) in [#40](https://github.com/pymc-devs/pymc-bart/pull/40) +* fix import error from pymc 5.0.2 by [@juanitorduz](https://github.com/juanitorduz) in [#43](https://github.com/pymc-devs/pymc-bart/pull/43) +* Update pymc minimum version by [@aloctavodia](https://github.com/aloctavodia) in [#45](https://github.com/pymc-devs/pymc-bart/pull/45) +* Avoid Deepcopy on Tree and ParticleTree by [@fjloyola](https://github.com/fjloyola) in [#47](https://github.com/pymc-devs/pymc-bart/pull/47) + +## New Contributors +* [@fjloyola](https://github.com/fjloyola) made their first contribution in [#34](https://github.com/pymc-devs/pymc-bart/pull/34) +* [@juanitorduz](https://github.com/juanitorduz) made their first contribution in [#43](https://github.com/pymc-devs/pymc-bart/pull/43) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.3.0...0.3.1 + +[Changes][0.3.1] + + + +# [0.3.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.3.0) - 2022-12-22 + +## What's Changed +* Update README with conda installation by [@maresb](https://github.com/maresb) in [#26](https://github.com/pymc-devs/pymc-bart/pull/26) +* Fix broken URL by [@maresb](https://github.com/maresb) in [#27](https://github.com/pymc-devs/pymc-bart/pull/27) +* Update to PyMC 5 and PyTensor by [@aloctavodia](https://github.com/aloctavodia) in [#29](https://github.com/pymc-devs/pymc-bart/pull/29) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.2.1...0.3.0 + +[Changes][0.3.0] + + + +# [0.2.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.2.1) - 2022-11-07 + + + +[Changes][0.2.1] + + + +# [0.2.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.2.0) - 2022-11-03 + + + +[Changes][0.2.0] + + + +# [0.1.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.1.0) - 2022-10-26 + + + +[Changes][0.1.0] + + + +# [0.0.3](https://github.com/pymc-devs/pymc-bart/releases/tag/0.0.3) - 2022-09-13 + + + +[Changes][0.0.3] + + +[0.9.0]: https://github.com/pymc-devs/pymc-bart/compare/0.8.2...0.9.0 +[0.8.2]: https://github.com/pymc-devs/pymc-bart/compare/0.8.1...0.8.2 +[0.8.1]: https://github.com/pymc-devs/pymc-bart/compare/0.8.0...0.8.1 +[0.8.0]: https://github.com/pymc-devs/pymc-bart/compare/0.7.1...0.8.0 +[0.7.1]: https://github.com/pymc-devs/pymc-bart/compare/0.7.0...0.7.1 +[0.7.0]: https://github.com/pymc-devs/pymc-bart/compare/0.6.0...0.7.0 +[0.6.0]: https://github.com/pymc-devs/pymc-bart/compare/0.5.14...0.6.0 +[0.5.14]: https://github.com/pymc-devs/pymc-bart/compare/0.5.13...0.5.14 +[0.5.13]: https://github.com/pymc-devs/pymc-bart/compare/0.5.12...0.5.13 +[0.5.12]: https://github.com/pymc-devs/pymc-bart/compare/0.5.11...0.5.12 +[0.5.11]: https://github.com/pymc-devs/pymc-bart/compare/0.5.10...0.5.11 +[0.5.10]: https://github.com/pymc-devs/pymc-bart/compare/0.5.9...0.5.10 +[0.5.9]: https://github.com/pymc-devs/pymc-bart/compare/0.5.8...0.5.9 +[0.5.8]: https://github.com/pymc-devs/pymc-bart/compare/0.5.7...0.5.8 +[0.5.7]: https://github.com/pymc-devs/pymc-bart/compare/0.5.6...0.5.7 +[0.5.6]: https://github.com/pymc-devs/pymc-bart/compare/0.5.5...0.5.6 +[0.5.5]: https://github.com/pymc-devs/pymc-bart/compare/0.5.4...0.5.5 +[0.5.4]: https://github.com/pymc-devs/pymc-bart/compare/0.5.3...0.5.4 +[0.5.3]: https://github.com/pymc-devs/pymc-bart/compare/0.5.2...0.5.3 +[0.5.2]: https://github.com/pymc-devs/pymc-bart/compare/O.5.1...0.5.2 +[O.5.1]: https://github.com/pymc-devs/pymc-bart/compare/0.5.0...O.5.1 +[0.5.0]: https://github.com/pymc-devs/pymc-bart/compare/0.4.0...0.5.0 +[0.4.0]: https://github.com/pymc-devs/pymc-bart/compare/0.3.2...0.4.0 +[0.3.2]: https://github.com/pymc-devs/pymc-bart/compare/0.3.1...0.3.2 +[0.3.1]: https://github.com/pymc-devs/pymc-bart/compare/0.3.0...0.3.1 +[0.3.0]: https://github.com/pymc-devs/pymc-bart/compare/0.2.1...0.3.0 +[0.2.1]: https://github.com/pymc-devs/pymc-bart/compare/0.2.0...0.2.1 +[0.2.0]: https://github.com/pymc-devs/pymc-bart/compare/0.1.0...0.2.0 +[0.1.0]: https://github.com/pymc-devs/pymc-bart/compare/0.0.3...0.1.0 +[0.0.3]: https://github.com/pymc-devs/pymc-bart/tree/0.0.3 + + From 5ebfba8750327c24d4547fce4ad45ac7d2061bfa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 31 Mar 2025 09:21:27 +0300 Subject: [PATCH 32/53] [pre-commit.ci] pre-commit autoupdate (#221) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.9.9 → v0.11.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.9.9...v0.11.2) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6a3b804..38c6e50 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.9 + rev: v0.11.2 hooks: - id: ruff args: ["--fix", "--output-format=full"] From 4ffa01c3bd0ad9b6819bd8faa68da8231c763106 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Apr 2025 10:09:34 +0300 Subject: [PATCH 33/53] [pre-commit.ci] pre-commit autoupdate (#222) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.11.2 → v0.11.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.11.2...v0.11.4) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 38c6e50..d7407f9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.2 + rev: v0.11.4 hooks: - id: ruff args: ["--fix", "--output-format=full"] From bdd999cd2c2cc5d5cf83bbe90ade99b191834365 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Apr 2025 11:27:13 +0200 Subject: [PATCH 34/53] [pre-commit.ci] pre-commit autoupdate (#223) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.11.4 → v0.11.5](https://github.com/astral-sh/ruff-pre-commit/compare/v0.11.4...v0.11.5) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d7407f9..fe8cf35 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.4 + rev: v0.11.5 hooks: - id: ruff args: ["--fix", "--output-format=full"] From 536e5171887a7951e029c7caa04dfc8c2dbb1971 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 May 2025 08:38:31 +0300 Subject: [PATCH 35/53] [pre-commit.ci] pre-commit autoupdate (#224) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.11.5 → v0.11.9](https://github.com/astral-sh/ruff-pre-commit/compare/v0.11.5...v0.11.9) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fe8cf35..7a55cde 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.5 + rev: v0.11.9 hooks: - id: ruff args: ["--fix", "--output-format=full"] From ff3a81ef3e39e9aab0c3d1f13f6caf4317e81037 Mon Sep 17 00:00:00 2001 From: Oriol Abril-Pla Date: Tue, 13 May 2025 11:31:11 +0200 Subject: [PATCH 36/53] misc doc improvements and theme update (#225) * misc doc improvements and theme update * make local search bar text specific * configure search and use pypi release * ignore example notebooks in bart docs from global search --- .readthedocs.yaml | 16 ++++++++++++++-- docs/conf.py | 30 +----------------------------- requirements-docs.txt | 6 ++---- 3 files changed, 17 insertions(+), 35 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 0ce9313..691fce7 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -6,9 +6,9 @@ sphinx: configuration: docs/conf.py build: - os: ubuntu-20.04 + os: ubuntu-24.04 tools: - python: "3.10" + python: "3.12" python: install: @@ -16,3 +16,15 @@ python: - requirements: requirements.txt - method: pip path: . + +search: + ranking: + _sources/*: -10 + _modules/*: -5 + genindex.html: -9 + + ignore: + - 404.html + - search.html + - index.html + - 'examples/*' diff --git a/docs/conf.py b/docs/conf.py index ba89cb1..8945cef 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -21,7 +21,6 @@ "sphinx_design", "sphinxcontrib.bibtex", "sphinx_codeautolink", - "sphinx_remove_toctrees", ] # List of patterns, relative to source directory, that match files and @@ -73,6 +72,7 @@ html_theme = "pymc_sphinx_theme" html_theme_options = { "secondary_sidebar_items": ["page-toc", "edit-this-page", "sourcelink", "donate"], + "search_bar_text": "Search within PyMC-BART...", "navbar_start": ["navbar-logo"], "icon_links": [ { @@ -80,17 +80,6 @@ "icon": "fa-brands fa-github", "name": "GitHub", }, - { - "url": "https://twitter.com/pymc_devs/", - "icon": "fa-brands fa-twitter", - "name": "Twitter", - }, - { - "url": "https://www.youtube.com/c/PyMCDevelopers", - "icon": "fa-brands fa-youtube", - "name": "YouTube", - }, - {"url": "https://discourse.pymc.io", "icon": "fa-brands fa-discourse", "name": "Discourse"}, ], } @@ -144,23 +133,6 @@ nb_execution_mode = "off" -remove_from_toctrees = [ - "BART/*", - "case_studies/*", - "causal_inference/*", - "diagnostics_and_criticism/*", - "gaussian_processes/*", - "generalized_linear_models/*", - "mixture_models/*", - "ode_models/*", - "howto/*", - "samplers/*", - "splines/*", - "survival_analysis/*", - "time_series/*", - "variational_inference/*", -] - # bibtex config bibtex_bibfiles = ["references.bib"] bibtex_default_style = "unsrt" diff --git a/requirements-docs.txt b/requirements-docs.txt index 5074a06..214c399 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -1,8 +1,6 @@ myst-nb -sphinx==5.0.2 # see https://github.com/pymc-devs/pymc-examples/issues/409 -git+https://github.com/pymc-devs/pymc-sphinx-theme +sphinx +pymc-sphinx-theme>=0.16 sphinxcontrib-bibtex -nbsphinx sphinx_design sphinx_codeautolink -sphinx_remove_toctrees From 5e0ec291c7d8a07db4ea88046fa7f60eae720ade Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Mon, 19 May 2025 13:36:59 +0300 Subject: [PATCH 37/53] use last pymc version (#227) --- pymc_bart/__init__.py | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index ed1a29a..36972b3 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -42,7 +42,7 @@ "plot_variable_importance", "plot_variable_inclusion", ] -__version__ = "0.9.0" +__version__ = "0.9.1" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] diff --git a/requirements.txt b/requirements.txt index 785de62..e3a38da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pymc>=5.16.2, <=5.20.1 +pymc>=5.16.2, <=5.22.0 arviz>=0.18.0 numba matplotlib From f067a4fbf929bd5f7486feadf7ba2d7bc93e5b16 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 19 May 2025 10:38:00 +0000 Subject: [PATCH 38/53] update changelog for 0.9.1 changes This commit was created by changelog-from-release in 'Post-release' CI workflow --- CHANGELOG.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6577689..99d410f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ + +# [0.9.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.9.1) - 2025-05-19 + +## What's Changed +* misc doc improvements and theme update by [@OriolAbril](https://github.com/OriolAbril) in [#225](https://github.com/pymc-devs/pymc-bart/pull/225) +* Use last pymc version by [@aloctavodia](https://github.com/aloctavodia) in [#227](https://github.com/pymc-devs/pymc-bart/pull/227) + +## New Contributors +* [@OriolAbril](https://github.com/OriolAbril) made their first contribution in [#225](https://github.com/pymc-devs/pymc-bart/pull/225) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.9.0...0.9.1 + +[Changes][0.9.1] + + # [0.9.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.9.0) - 2025-03-10 @@ -415,6 +430,7 @@ [Changes][0.0.3] +[0.9.1]: https://github.com/pymc-devs/pymc-bart/compare/0.9.0...0.9.1 [0.9.0]: https://github.com/pymc-devs/pymc-bart/compare/0.8.2...0.9.0 [0.8.2]: https://github.com/pymc-devs/pymc-bart/compare/0.8.1...0.8.2 [0.8.1]: https://github.com/pymc-devs/pymc-bart/compare/0.8.0...0.8.1 From b83161838c6d8b0fee96f7c3e5c8b924390fe07a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 May 2025 10:17:52 +0200 Subject: [PATCH 39/53] [pre-commit.ci] pre-commit autoupdate (#228) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.11.9 → v0.11.11](https://github.com/astral-sh/ruff-pre-commit/compare/v0.11.9...v0.11.11) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a55cde..7bd307d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.9 + rev: v0.11.11 hooks: - id: ruff args: ["--fix", "--output-format=full"] From 9d0a079c41e0089773d021cd88a6be34cafcac23 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Thu, 12 Jun 2025 14:38:30 +0200 Subject: [PATCH 40/53] Update requirements.txt (#230) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e3a38da..95fce57 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pymc>=5.16.2, <=5.22.0 +pymc>=5.16.2, <=5.23.0 arviz>=0.18.0 numba matplotlib From 880cb57bc18fe171b2f8eb20209b7fdf12f665d7 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Thu, 12 Jun 2025 15:39:20 +0300 Subject: [PATCH 41/53] Update __init__.py --- pymc_bart/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index 36972b3..cfa1648 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -42,7 +42,7 @@ "plot_variable_importance", "plot_variable_inclusion", ] -__version__ = "0.9.1" +__version__ = "0.9.2" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] From cb2aab314efc0748262d9eba51307ccd16f7c6b0 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 12 Jun 2025 12:40:22 +0000 Subject: [PATCH 42/53] update changelog for 0.9.2 changes This commit was created by changelog-from-release in 'Post-release' CI workflow --- CHANGELOG.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 99d410f..fc38b99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,15 @@ + +# [0.9.2](https://github.com/pymc-devs/pymc-bart/releases/tag/0.9.2) - 2025-06-12 + +## What's Changed +* Update requirements.txt by [@juanitorduz](https://github.com/juanitorduz) in [#230](https://github.com/pymc-devs/pymc-bart/pull/230) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.9.1...0.9.2 + +[Changes][0.9.2] + + # [0.9.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.9.1) - 2025-05-19 @@ -430,6 +442,7 @@ [Changes][0.0.3] +[0.9.2]: https://github.com/pymc-devs/pymc-bart/compare/0.9.1...0.9.2 [0.9.1]: https://github.com/pymc-devs/pymc-bart/compare/0.9.0...0.9.1 [0.9.0]: https://github.com/pymc-devs/pymc-bart/compare/0.8.2...0.9.0 [0.8.2]: https://github.com/pymc-devs/pymc-bart/compare/0.8.1...0.8.2 From 0000b8c6b63a594e2c7d7db64e4337b4c0698b45 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Sat, 21 Jun 2025 18:29:14 +0300 Subject: [PATCH 43/53] Use ArviZ-stats (#232) * use arviz_stats * fix imports * update python versions --- .github/workflows/test.yml | 2 +- docs/api_reference.rst | 2 +- docs/index.rst | 2 +- env-dev.yml | 4 +- env.yml | 4 +- pymc_bart/pgbart.py | 2 +- pymc_bart/utils.py | 89 ++++++++++++++++++-------------------- requirements.txt | 6 +-- setup.py | 4 +- 9 files changed, 56 insertions(+), 59 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8817d27..3fe0779 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10", "3.11"] + python-version: ["3.11", "3.12", "3.13"] name: Set up Python ${{ matrix.python-version }} steps: diff --git a/docs/api_reference.rst b/docs/api_reference.rst index b6fb8a5..88b910c 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -13,4 +13,4 @@ methods in the current release of PyMC-BART. ============================= .. automodule:: pymc_bart - :members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_convergence, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule + :members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule diff --git a/docs/index.rst b/docs/index.rst index 78a59fb..e390c3c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -29,7 +29,7 @@ interpretation of those models and perform variable selection. Installation ============ -PyMC-BART requires a working Python interpreter (3.10+). We recommend installing Python and key numerical libraries using the `Anaconda distribution `_, which has one-click installers available on all major platforms. +PyMC-BART requires a working Python interpreter (3.11+). We recommend installing Python and key numerical libraries using the `Anaconda distribution `_, which has one-click installers available on all major platforms. Assuming a standard Python environment is installed on your machine, PyMC-BART itself can be installed either using pip or conda-forge. diff --git a/env-dev.yml b/env-dev.yml index 1e28429..fae1398 100644 --- a/env-dev.yml +++ b/env-dev.yml @@ -3,8 +3,7 @@ channels: - conda-forge - defaults dependencies: - - pymc>=5.16.2,<=5.19.1 - - arviz>=0.18.0 + - pymc>=5.16.2,<=5.23.0 - numba - matplotlib - numpy @@ -20,4 +19,5 @@ dependencies: - flake8 - pip - pip: + - arviz-stats[xarray]>=0.6.0 - -e . diff --git a/env.yml b/env.yml index bd814ae..77f6c13 100644 --- a/env.yml +++ b/env.yml @@ -3,8 +3,7 @@ channels: - conda-forge - defaults dependencies: - - pymc>=5.16.2,<=5.19.1 - - arviz>=0.18.0 + - pymc>=5.16.2,<=5.23.0 - numba - matplotlib - numpy @@ -12,3 +11,4 @@ dependencies: - pip - pip: - pymc-bart + - arviz-stats[xarray]>=0.6.0 diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 014313a..b76c40c 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -346,7 +346,7 @@ def resample( new_particles.append(particles[idx].copy()) else: new_particles.append(particles[idx]) - seen.append(idx) + seen.append(int(idx)) particles[1:] = new_particles diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 3ba6e58..ab10467 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -4,16 +4,16 @@ import warnings from typing import Any, Callable, Optional, Union -import arviz as az import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt import pytensor.tensor as pt +from arviz_base import rcParams +from arviz_stats.base import array_stats from numba import jit from pytensor.tensor.variable import Variable from scipy.interpolate import griddata from scipy.signal import savgol_filter -from scipy.stats import norm from .tree import Tree @@ -76,12 +76,12 @@ def _sample_posterior( def plot_convergence( - idata: az.InferenceData, + idata: Any, var_name: Optional[str] = None, kind: str = "ecdf", figsize: Optional[tuple[float, float]] = None, ax=None, -) -> list[plt.Axes]: +) -> None: """ Plot convergence diagnostics. @@ -102,39 +102,12 @@ def plot_convergence( ------- list[ax] : matplotlib axes """ - ess_threshold = idata["posterior"]["chain"].size * 100 - ess = np.atleast_2d(az.ess(idata, method="bulk", var_names=var_name)[var_name].values) - rhat = np.atleast_2d(az.rhat(idata, var_names=var_name)[var_name].values) - - if figsize is None: - figsize = (10, 3) - - if kind == "ecdf": - kind_func: Callable[..., Any] = az.plot_ecdf - sharey = True - elif kind == "kde": - kind_func = az.plot_kde - sharey = False - - if ax is None: - _, ax = plt.subplots(1, 2, figsize=figsize, sharex="col", sharey=sharey) - - for idx, (essi, rhati) in enumerate(zip(ess, rhat)): - kind_func(essi, ax=ax[0], plot_kwargs={"color": f"C{idx}"}) - kind_func(rhati, ax=ax[1], plot_kwargs={"color": f"C{idx}"}) - - ax[0].axvline(ess_threshold, color="0.7", ls="--") - # Assume Rhats are N(1, 0.005) iid. Then compute the 0.99 quantile - # scaled by the sample size and use it as a threshold. - ax[1].axvline(norm(1, 0.005).ppf(0.99 ** (1 / ess.size)), color="0.7", ls="--") - - ax[0].set_xlabel("ESS") - ax[1].set_xlabel("R-hat") - if kind == "kde": - ax[0].set_yticks([]) - ax[1].set_yticks([]) - - return ax + warnings.warn( + "This function has been deprecated" + "Use az.plot_convergence_dist() instead." + "https://arviz-plots.readthedocs.io/en/latest/api/generated/arviz_plots.plot_convergence_dist.html", + FutureWarning, + ) def plot_ice( @@ -408,7 +381,7 @@ def identity(x): if var in var_discrete: _, idx_uni = np.unique(new_x, return_index=True) y_means = p_di.mean(0)[idx_uni] - hdi = az.hdi(p_di)[idx_uni] + hdi = array_stats.hdi(p_di, prob=rcParams["stats.ci_prob"], axis=0)[idx_uni] axes[count].errorbar( new_x[idx_uni], y_means, @@ -418,11 +391,13 @@ def identity(x): ) axes[count].set_xticks(new_x[idx_uni]) else: - az.plot_hdi( + _plot_hdi( new_x, p_di, smooth=smooth, - fill_kwargs={"alpha": alpha, "color": color}, + alpha=alpha, + color=color, + smooth_kwargs=smooth_kwargs, ax=axes[count], ) if smooth: @@ -659,7 +634,7 @@ def _create_pdp_data( def _smooth_mean( new_x: npt.NDArray, p_di: npt.NDArray, - kind: str = "pdp", + kind: str = "neutral", smooth_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[np.ndarray, np.ndarray]: """ @@ -688,7 +663,10 @@ def _smooth_mean( smooth_kwargs.setdefault("polyorder", 2) x_data = np.linspace(np.nanmin(new_x), np.nanmax(new_x), 200) x_data[0] = (x_data[0] + x_data[1]) / 2 - if kind == "pdp": + + if kind == "neutral": + interp = griddata(new_x, p_di, x_data) + elif kind == "pdp": interp = griddata(new_x, p_di.mean(0), x_data) else: interp = griddata(new_x, p_di.T, x_data) @@ -800,7 +778,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non def compute_variable_importance( # noqa: PLR0915 PLR0912 - idata: az.InferenceData, + idata: Any, bartrv: Variable, X: npt.NDArray, method: str = "VI", @@ -904,7 +882,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 [pearsonr2(predicted_all[j], predicted_subset[j]) for j in range(samples)] ) r2_mean[idx] = np.mean(r_2) - r2_hdi[idx] = az.hdi(r_2) + r2_hdi[idx] = array_stats.hdi(r_2, prob=rcParams["stats.ci_prob"]) preds[idx] = predicted_subset.squeeze() if method in ["backward", "backward_VI"]: @@ -954,7 +932,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 # Save values for plotting later r2_mean[i_var - init] = max_r_2 - r2_hdi[i_var - init] = az.hdi(r_2_without_least_important_vars) + r2_hdi[i_var - init] = array_stats.hdi(r_2_without_least_important_vars) preds[i_var - init] = least_important_samples.squeeze() # extend current list of least important variable @@ -1079,7 +1057,7 @@ def plot_variable_importance( ) ax.fill_between( [-0.5, n_vars - 0.5], - *az.hdi(r_2_ref), + *array_stats.hdi(r_2_ref, prob=rcParams["stats.ci_prob"]), alpha=0.1, color=plot_kwargs.get("color_ref", "grey"), ) @@ -1229,3 +1207,22 @@ def pearsonr2(A, B): am = A - np.mean(A) bm = B - np.mean(B) return (am @ bm) ** 2 / (np.sum(am**2) * np.sum(bm**2)) + + +def _plot_hdi(x, y, smooth, color, alpha, smooth_kwargs, ax): + x = np.asarray(x) + y = np.asarray(y) + hdi_prob = rcParams["stats.ci_prob"] + hdi_data = array_stats.hdi(y, hdi_prob, axis=0) + if smooth: + if isinstance(x[0], np.datetime64): + raise TypeError("Cannot deal with x as type datetime. Recommend setting smooth=False.") + + x_data, y_data = _smooth_mean(x, hdi_data, smooth_kwargs=smooth_kwargs) + else: + idx = np.argsort(x) + x_data = x[idx] + y_data = hdi_data[idx] + + ax.fill_between(x_data, y_data[:, 0], y_data[:, 1], color=color, alpha=alpha) + return ax diff --git a/requirements.txt b/requirements.txt index 95fce57..5e6713e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -pymc>=5.16.2, <=5.23.0 -arviz>=0.18.0 +pymc>=5.16.2,<=5.23.0 +arviz-stats[xarray]>=0.6.0 numba matplotlib -numpy +numpy>=2.0 diff --git a/setup.py b/setup.py index e934ae2..0ae76b2 100644 --- a/setup.py +++ b/setup.py @@ -29,9 +29,9 @@ "Development Status :: 5 - Production/Stable", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "License :: OSI Approved :: Apache Software License", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering", From d9095d922d45bb3b49afdb4f1487ea30d522a663 Mon Sep 17 00:00:00 2001 From: Derek Powell Date: Thu, 17 Jul 2025 01:41:25 -0700 Subject: [PATCH 44/53] Add support for multiple BART random variables per model. (#231) * Add support for multiple BART random variables per model. * add potential fix (conditional of upstream change) * remove int * update pymc version --------- Co-authored-by: aloctavodia --- env-dev.yml | 2 +- env.yml | 2 +- pymc_bart/bart.py | 9 +++--- pymc_bart/pgbart.py | 33 +++++++++++++++++++--- requirements.txt | 2 +- tests/test_bart.py | 67 +++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 103 insertions(+), 12 deletions(-) diff --git a/env-dev.yml b/env-dev.yml index fae1398..014979c 100644 --- a/env-dev.yml +++ b/env-dev.yml @@ -3,7 +3,7 @@ channels: - conda-forge - defaults dependencies: - - pymc>=5.16.2,<=5.23.0 + - pymc==5.24.0 - numba - matplotlib - numpy diff --git a/env.yml b/env.yml index 77f6c13..f5ebf01 100644 --- a/env.yml +++ b/env.yml @@ -3,7 +3,7 @@ channels: - conda-forge - defaults dependencies: - - pymc>=5.16.2,<=5.23.0 + - pymc==5.24.0 - numba - matplotlib - numpy diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 5114b6e..233d33e 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -29,7 +29,6 @@ from pytensor.tensor.variable import TensorVariable from .split_rules import SplitRule -from .tree import Tree from .utils import TensorLike, _sample_posterior __all__ = ["BART"] @@ -42,7 +41,6 @@ class BARTRV(RandomVariable): signature = "(m,n),(m),(),(),() -> (m)" dtype: str = "floatX" _print_name: tuple[str, str] = ("BART", "\\operatorname{BART}") - all_trees = list[list[list[Tree]]] def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed idx = dist_params[0].ndim - 2 @@ -55,7 +53,7 @@ def rng_fn( # pylint: disable=W0237 if not size: size = None - if not cls.all_trees: + if not hasattr(cls, "all_trees") or not cls.all_trees: if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)): Y = cls.Y.eval() else: @@ -142,8 +140,9 @@ def __new__( "Options linear and mix are experimental and still not well tested\n" + "Use with caution." ) + # Create a unique manager list for each BART instance manager = Manager() - cls.all_trees = manager.list() + instance_all_trees = manager.list() X, Y = preprocess_xy(X, Y) @@ -154,7 +153,7 @@ def __new__( (BARTRV,), { "name": "BART", - "all_trees": cls.all_trees, + "all_trees": instance_all_trees, # Instance-specific tree storage "inplace": False, "initval": Y.mean(), "X": X, diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index b76c40c..92f0e21 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -130,6 +130,7 @@ def __init__( # noqa: PLR0912, PLR0915 model: Optional[Model] = None, initial_point: PointType | None = None, compile_kwargs: dict | None = None, + **kwargs, # Accept additional kwargs for compound sampling ) -> None: model = modelcontext(model) if initial_point is None: @@ -143,7 +144,24 @@ def __init__( # noqa: PLR0912, PLR0915 if vars is None: raise ValueError("Unable to find variables to sample") - value_bart = vars[0] + # Filter to only BART variables + bart_vars = [] + for var in vars: + rv = model.values_to_rvs.get(var) + if rv is not None and isinstance(rv.owner.op, BARTRV): + bart_vars.append(var) + + if not bart_vars: + raise ValueError("No BART variables found in the provided variables") + + if len(bart_vars) > 1: + raise ValueError( + "PGBART can only handle one BART variable at a time. " + "For multiple BART variables, PyMC will automatically create " + "separate PGBART samplers for each variable." + ) + + value_bart = bart_vars[0] self.bart = model.values_to_rvs[value_bart].owner.op if isinstance(self.bart.X, Variable): @@ -227,15 +245,15 @@ def __init__( # noqa: PLR0912, PLR0915 self.num_particles = num_particles self.indices = list(range(1, num_particles)) - shared = make_shared_replacements(initial_point, vars, model) - self.likelihood_logp = logp(initial_point, [model.datalogp], vars, shared) + shared = make_shared_replacements(initial_point, [value_bart], model) + self.likelihood_logp = logp(initial_point, [model.datalogp], [value_bart], shared) self.all_particles = [ [ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape) ] self.all_trees = np.array([[p.tree for p in pl] for pl in self.all_particles]) self.lower = 0 self.iter = 0 - super().__init__(vars, shared) + super().__init__([value_bart], shared) def astep(self, _): variable_inclusion = np.zeros(self.num_variates, dtype="int") @@ -408,6 +426,13 @@ def competence(var: pm.Distribution, has_grad: bool) -> Competence: return Competence.IDEAL return Competence.INCOMPATIBLE + @staticmethod + def _make_update_stats_functions(): + def update_stats(step_stats): + return {key: step_stats[key] for key in ("variable_inclusion", "tune")} + + return (update_stats,) + class RunningSd: """Welford's online algorithm for computing the variance/standard deviation""" diff --git a/requirements.txt b/requirements.txt index 5e6713e..2a053a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pymc>=5.16.2,<=5.23.0 +pymc==5.24.0 arviz-stats[xarray]>=0.6.0 numba matplotlib diff --git a/tests/test_bart.py b/tests/test_bart.py index 226d938..8311c2a 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -256,3 +256,70 @@ def test_categorical_model(separate_trees, split_rule): # Fit should be good enough so right category is selected over 50% of time assert (idata.predictions.y.median(["chain", "draw"]) == Y).all() assert pmb.compute_variable_importance(idata, bartrv=lo, X=X)["preds"].shape == (5, 50, 9, 3) + + +def test_multiple_bart_variables(): + """Test that multiple BART variables can coexist in a single model.""" + X1 = np.random.normal(0, 1, size=(50, 2)) + X2 = np.random.normal(0, 1, size=(50, 3)) + Y = np.random.normal(0, 1, size=50) + + # Create correlated responses + Y1 = X1[:, 0] + np.random.normal(0, 0.1, size=50) + Y2 = X2[:, 0] + X2[:, 1] + np.random.normal(0, 0.1, size=50) + + with pm.Model() as model: + # Two separate BART variables with different covariates + mu1 = pmb.BART("mu1", X1, Y1, m=5) + mu2 = pmb.BART("mu2", X2, Y2, m=5) + + # Combined model + sigma = pm.HalfNormal("sigma", 1) + y = pm.Normal("y", mu1 + mu2, sigma, observed=Y) + + # Sample with automatic assignment of BART samplers + idata = pm.sample(tune=50, draws=50, chains=1, random_seed=3415) + + # Verify both BART variables have their own tree collections + assert hasattr(mu1.owner.op, "all_trees") + assert hasattr(mu2.owner.op, "all_trees") + + # Verify trees are stored separately (different object references) + assert mu1.owner.op.all_trees is not mu2.owner.op.all_trees + + # Verify sampling worked + assert idata.posterior["mu1"].shape == (1, 50, 50) + assert idata.posterior["mu2"].shape == (1, 50, 50) + + +def test_multiple_bart_variables_manual_step(): + """Test that multiple BART variables work with manually assigned PGBART samplers.""" + X1 = np.random.normal(0, 1, size=(30, 2)) + X2 = np.random.normal(0, 1, size=(30, 2)) + Y = np.random.normal(0, 1, size=30) + + # Create simple responses + Y1 = X1[:, 0] + np.random.normal(0, 0.1, size=30) + Y2 = X2[:, 1] + np.random.normal(0, 0.1, size=30) + + with pm.Model() as model: + # Two separate BART variables + mu1 = pmb.BART("mu1", X1, Y1, m=3) + mu2 = pmb.BART("mu2", X2, Y2, m=3) + + # Non-BART variable + sigma = pm.HalfNormal("sigma", 1) + y = pm.Normal("y", mu1 + mu2, sigma, observed=Y) + + # Manually create PGBART samplers for each BART variable + step1 = pmb.PGBART([mu1], num_particles=5) + step2 = pmb.PGBART([mu2], num_particles=5) + + # Sample with manual step assignment + idata = pm.sample(tune=20, draws=20, chains=1, step=[step1, step2], random_seed=3415) + + # Verify both variables were sampled + assert "mu1" in idata.posterior + assert "mu2" in idata.posterior + assert idata.posterior["mu1"].shape == (1, 20, 30) + assert idata.posterior["mu2"].shape == (1, 20, 30) From b7567de474cc315edff904103da4540f4ce08e15 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Fri, 18 Jul 2025 12:48:52 +0300 Subject: [PATCH 45/53] encode vi and update to work with multiple RVs (#235) * encode vi and update to work with multiple RVs * add missing tests --- env-dev.yml | 2 +- env.yml | 2 +- pymc_bart/pgbart.py | 5 +- pymc_bart/utils.py | 99 ++++++++++++++++++++++++++++++++------ requirements.txt | 2 +- tests/test_bart.py | 113 +++++++------------------------------------- tests/test_utils.py | 107 +++++++++++++++++++++++++++++++++++++++++ 7 files changed, 217 insertions(+), 113 deletions(-) create mode 100644 tests/test_utils.py diff --git a/env-dev.yml b/env-dev.yml index 014979c..375558b 100644 --- a/env-dev.yml +++ b/env-dev.yml @@ -3,7 +3,7 @@ channels: - conda-forge - defaults dependencies: - - pymc==5.24.0 + - pymc>=5.24.0 - numba - matplotlib - numpy diff --git a/env.yml b/env.yml index f5ebf01..3afdd9f 100644 --- a/env.yml +++ b/env.yml @@ -3,7 +3,7 @@ channels: - conda-forge - defaults dependencies: - - pymc==5.24.0 + - pymc>=5.24.0 - numba - matplotlib - numpy diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 92f0e21..87bd36a 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -37,6 +37,7 @@ get_idx_left_child, get_idx_right_child, ) +from pymc_bart.utils import _encode_vi class ParticleTree: @@ -118,7 +119,7 @@ class PGBART(ArrayStepShared): default_blocked = False generates_stats = True stats_dtypes_shapes: dict[str, tuple[type, list]] = { - "variable_inclusion": (object, []), + "variable_inclusion": (int, []), "tune": (bool, []), } @@ -335,6 +336,8 @@ def astep(self, _): if not self.tune: self.bart.all_trees.append(self.all_trees) + variable_inclusion = _encode_vi(variable_inclusion) + stats = {"variable_inclusion": variable_inclusion, "tune": self.tune} return self.sum_trees, [stats] diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index ab10467..78ce920 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt +import pymc as pm import pytensor.tensor as pt from arviz_base import rcParams from arviz_stats.base import array_stats @@ -674,22 +675,29 @@ def _smooth_mean( return x_data, y_data -def get_variable_inclusion(idata, X, labels=None, to_kulprit=False): +def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None, to_kulprit=False): """ Get the normalized variable inclusion from BART model. Parameters ---------- idata : InferenceData - InferenceData containing a collection of BART_trees in sample_stats group + InferenceData with a variable "variable_inclusion" in ``sample_stats`` group X : npt.NDArray The covariate matrix. + model : Optional[pm.Model] + The PyMC model that contains the BART variable. Only needed if the model contains multiple + BART variables. + bart_var_name : Optional[str] + The name of the BART variable in the model. Only needed if the model contains multiple + BART variables. labels : Optional[list[str]] List of the names of the covariates. If X is a DataFrame the names of the covariables will be taken from it and this argument will be ignored. to_kulprit : bool If True, the function will return a list of list with the variables names. This list can be passed as a path to Kulprit's project method. Defaults to False. + Returns ------- VI_norm : npt.NDArray @@ -697,7 +705,20 @@ def get_variable_inclusion(idata, X, labels=None, to_kulprit=False): labels : list[str] List of the names of the covariates. """ - VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values + n_vars = X.shape[1] + vi_xarray = idata["sample_stats"]["variable_inclusion"] + if "variable_inclusion_dim_0" in vi_xarray.coords: + if model is None or bart_var_name is None: + raise ValueError( + "The InfereceData was generated from a model with multiple BART variables, \n" + "please provide the model and also the name of the BART variable \n" + "for which you want to compute the variable inclusion." + ) + index = [var.name for var in model.free_RVs].index(bart_var_name) + vi_vals = vi_xarray.sel({"variable_inclusion_dim_0": index}).values.ravel() + else: + vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel() + VIs = np.array([_decode_vi(val, n_vars) for val in vi_vals]).sum(axis=0) VI_norm = VIs / VIs.sum() idxs = np.argsort(VI_norm) @@ -705,17 +726,15 @@ def get_variable_inclusion(idata, X, labels=None, to_kulprit=False): n_vars = len(indices) if hasattr(X, "columns") and hasattr(X, "to_numpy"): - labels = X.columns + labels = list(X.columns) if labels is None: - labels = np.arange(n_vars).astype(str) - - label_list = labels.to_list() + labels = [str(i) for i in range(n_vars)] if to_kulprit: - return [label_list[:idx] for idx in range(n_vars)] + return [labels[:idx] for idx in range(n_vars)] else: - return VI_norm[indices], label_list + return VI_norm[indices], labels def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=None, ax=None): @@ -781,10 +800,11 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 idata: Any, bartrv: Variable, X: npt.NDArray, + model: "pm.Model | None" = None, method: str = "VI", fixed: int = 0, samples: int = 50, - random_seed: Optional[int] = None, + random_seed: int | None = None, ) -> dict[str, object]: """ Estimates variable importance from the BART-posterior. @@ -792,11 +812,14 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 Parameters ---------- idata : InferenceData - InferenceData containing a collection of BART_trees in sample_stats group + InferenceData containing a "variable_inclusion" variable in the sample_stats group. bartrv : BART Random Variable BART variable once the model that include it has been fitted. X : npt.NDArray The covariate matrix. + model : Optional[pm.Model] + The PyMC model that contains the BART variable. Only needed if the model contains multiple + BART variables. method : str Method used to rank variables. Available options are "VI" (default), "backward" and "backward_VI". @@ -825,6 +848,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 rng = np.random.default_rng(random_seed) all_trees = bartrv.owner.op.all_trees + bart_var_name = bartrv.name if bartrv.ndim == 1: # type: ignore shape = 1 @@ -858,9 +882,20 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 ) if method in ["VI", "backward_VI"]: - idxs = np.argsort( - idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values - ) + vi_xarray = idata["sample_stats"]["variable_inclusion"] + if "variable_inclusion_dim_0" in vi_xarray.coords: + if model is None: + raise ValueError( + "The InfereceData was generated from a model with multiple BART variables, \n" + "please provide the model and also the name of the BART variable \n" + "for which you want to compute the variable inclusion." + ) + + index = [var.name for var in model.free_RVs].index(bart_var_name) + vi_vals = vi_xarray.sel({"variable_inclusion_dim_0": index}).values.ravel() + else: + vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel() + idxs = np.argsort(np.array([_decode_vi(val, n_vars) for val in vi_vals]).sum(axis=0)) subsets: list[list[int]] = [list(idxs[:-i]) for i in range(1, len(idxs))] subsets.append(None) # type: ignore @@ -1226,3 +1261,39 @@ def _plot_hdi(x, y, smooth, color, alpha, smooth_kwargs, ax): ax.fill_between(x_data, y_data[:, 0], y_data[:, 1], color=color, alpha=alpha) return ax + + +def _decode_vi(n: int, length: int) -> list[int]: + """ + Decode the variable inclusion from the BART model. + """ + bits = bin(n)[2:] + vi_list: list[int] = [] + i = 0 + while len(vi_list) < length: + # Count prefix ones + prefix_len = 0 + while bits[i] == "1": + prefix_len += 1 + i += 1 + i += 1 # skip the '0' + b = bits[i : i + prefix_len] + vi_list.append(int(b, 2)) + i += prefix_len + return vi_list + + +def _encode_vi(vec: npt.NDArray) -> int: + """ + Encode variable inclusion vector into a single integer. + + The encoding is done by converting each element of the vector into a binary string, + where each element contributes a prefix of '1's followed by a '0' and its binary representation. + The final result is the integer representation of the concatenated binary string. + """ + bits = "" + for x in vec: + b = bin(x)[2:] + prefix = "1" * len(b) + "0" + bits += prefix + b + return int(bits, 2) diff --git a/requirements.txt b/requirements.txt index 2a053a7..24d156b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pymc==5.24.0 +pymc>=5.24.0 arviz-stats[xarray]>=0.6.0 numba matplotlib diff --git a/tests/test_bart.py b/tests/test_bart.py index 8311c2a..f446cd4 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -1,11 +1,12 @@ import numpy as np import pymc as pm import pytest -from numpy.testing import assert_almost_equal, assert_array_equal +from numpy.testing import assert_almost_equal from pymc.initial_point import make_initial_point_fn from pymc.logprob.basic import transformed_conditional_logp import pymc_bart as pmb +from pymc_bart.utils import _decode_vi def assert_moment_is_expected(model, expected, check_finite_logp=True): @@ -52,14 +53,12 @@ def test_bart_vi(response): with pm.Model() as model: mu = pmb.BART("mu", X, Y, m=10, response=response) sigma = pm.HalfNormal("sigma", 1) - y = pm.Normal("y", mu, sigma, observed=Y) + pm.Normal("y", mu, sigma, observed=Y) idata = pm.sample(tune=200, draws=200, random_seed=3415) - var_imp = ( - idata.sample_stats["variable_inclusion"] - .stack(samples=("chain", "draw")) - .mean("samples") - ) - var_imp /= var_imp.sum() + vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel() + var_imp = np.array([_decode_vi(val, 3) for val in vi_vals]).sum(axis=0) + + var_imp = var_imp / var_imp.sum() assert var_imp[0] > var_imp[1:].sum() assert_almost_equal(var_imp.sum(), 1) @@ -123,92 +122,6 @@ def test_shape(response): assert idata.posterior.coords["w_dim_1"].data.size == 250 -class TestUtils: - X_norm = np.random.normal(0, 1, size=(50, 2)) - X_binom = np.random.binomial(1, 0.5, size=(50, 1)) - X = np.hstack([X_norm, X_binom]) - Y = np.random.normal(0, 1, size=50) - - with pm.Model() as model: - mu = pmb.BART("mu", X, Y, m=10) - sigma = pm.HalfNormal("sigma", 1) - y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(tune=200, draws=200, random_seed=3415) - - def test_sample_posterior(self): - all_trees = self.mu.owner.op.all_trees - rng = np.random.default_rng(3) - pred_all = pmb.utils._sample_posterior(all_trees, X=self.X, rng=rng, size=2) - rng = np.random.default_rng(3) - pred_first = pmb.utils._sample_posterior(all_trees, X=self.X[:10], rng=rng) - - assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4) - assert pred_all.shape == (2, 50, 1) - assert pred_first.shape == (1, 10, 1) - - @pytest.mark.parametrize( - "kwargs", - [ - {}, - { - "samples": 2, - "var_discrete": [3], - }, - {"instances": 2}, - {"var_idx": [0], "smooth": False, "color": "k"}, - {"grid": (1, 2), "sharey": "none", "alpha": 1}, - {"var_discrete": [0]}, - ], - ) - def test_ice(self, kwargs): - pmb.plot_ice(self.mu, X=self.X, Y=self.Y, **kwargs) - - @pytest.mark.parametrize( - "kwargs", - [ - {}, - { - "samples": 2, - "xs_interval": "quantiles", - "xs_values": [0.25, 0.5, 0.75], - "var_discrete": [3], - }, - {"var_idx": [0], "smooth": False, "color": "k"}, - {"grid": (1, 2), "sharey": "none", "alpha": 1}, - {"var_discrete": [0]}, - ], - ) - def test_pdp(self, kwargs): - pmb.plot_pdp(self.mu, X=self.X, Y=self.Y, **kwargs) - - @pytest.mark.parametrize( - "kwargs", - [ - {"samples": 50}, - {"labels": ["A", "B", "C"], "samples": 2, "figsize": (6, 6)}, - ], - ) - def test_vi(self, kwargs): - samples = kwargs.pop("samples") - vi_results = pmb.compute_variable_importance( - self.idata, bartrv=self.mu, X=self.X, samples=samples - ) - pmb.plot_variable_importance(vi_results, **kwargs) - pmb.plot_scatter_submodels(vi_results, **kwargs) - - def test_pdp_pandas_labels(self): - pd = pytest.importorskip("pandas") - - X_names = ["norm1", "norm2", "binom"] - X_pd = pd.DataFrame(self.X, columns=X_names) - Y_pd = pd.Series(self.Y, name="response") - axes = pmb.plot_pdp(self.mu, X=X_pd, Y=Y_pd) - - figure = axes[0].figure - assert figure.texts[0].get_text() == "Partial response" - assert_array_equal([ax.get_xlabel() for ax in axes], X_names) - - @pytest.mark.parametrize( "size, expected", [ @@ -275,7 +188,7 @@ def test_multiple_bart_variables(): # Combined model sigma = pm.HalfNormal("sigma", 1) - y = pm.Normal("y", mu1 + mu2, sigma, observed=Y) + pm.Normal("y", mu1 + mu2, sigma, observed=Y) # Sample with automatic assignment of BART samplers idata = pm.sample(tune=50, draws=50, chains=1, random_seed=3415) @@ -291,6 +204,16 @@ def test_multiple_bart_variables(): assert idata.posterior["mu1"].shape == (1, 50, 50) assert idata.posterior["mu2"].shape == (1, 50, 50) + vi_results = pmb.compute_variable_importance(idata, mu1, X1, model=model) + assert vi_results["labels"].shape == (2,) + assert vi_results["preds"].shape == (2, 50, 50) + assert vi_results["preds_all"].shape == (50, 50) + + vi_tuple = pmb.get_variable_inclusion(idata, X1, model=model, bart_var_name="mu1") + assert vi_tuple[0].shape == (2,) + assert len(vi_tuple[1]) == 2 + assert isinstance(vi_tuple[1][0], str) + def test_multiple_bart_variables_manual_step(): """Test that multiple BART variables work with manually assigned PGBART samplers.""" diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..dbf3aca --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,107 @@ +import numpy as np +import pymc as pm +import pytest +from numpy.testing import assert_almost_equal, assert_array_equal + +import pymc_bart as pmb + + +class TestUtils: + X_norm = np.random.normal(0, 1, size=(50, 2)) + X_binom = np.random.binomial(1, 0.5, size=(50, 1)) + X = np.hstack([X_norm, X_binom]) + Y = np.random.normal(0, 1, size=50) + + with pm.Model() as model: + mu = pmb.BART("mu", X, Y, m=10) + sigma = pm.HalfNormal("sigma", 1) + y = pm.Normal("y", mu, sigma, observed=Y) + idata = pm.sample(tune=200, draws=200, random_seed=3415) + + def test_sample_posterior(self): + all_trees = self.mu.owner.op.all_trees + rng = np.random.default_rng(3) + pred_all = pmb.utils._sample_posterior(all_trees, X=self.X, rng=rng, size=2) + rng = np.random.default_rng(3) + pred_first = pmb.utils._sample_posterior(all_trees, X=self.X[:10], rng=rng) + + assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4) + assert pred_all.shape == (2, 50, 1) + assert pred_first.shape == (1, 10, 1) + + @pytest.mark.parametrize( + "kwargs", + [ + {}, + { + "samples": 2, + "var_discrete": [3], + }, + {"instances": 2}, + {"var_idx": [0], "smooth": False, "color": "k"}, + {"grid": (1, 2), "sharey": "none", "alpha": 1}, + {"var_discrete": [0]}, + ], + ) + def test_ice(self, kwargs): + pmb.plot_ice(self.mu, X=self.X, Y=self.Y, **kwargs) + + @pytest.mark.parametrize( + "kwargs", + [ + {}, + { + "samples": 2, + "xs_interval": "quantiles", + "xs_values": [0.25, 0.5, 0.75], + "var_discrete": [3], + }, + {"var_idx": [0], "smooth": False, "color": "k"}, + {"grid": (1, 2), "sharey": "none", "alpha": 1}, + {"var_discrete": [0]}, + ], + ) + def test_pdp(self, kwargs): + pmb.plot_pdp(self.mu, X=self.X, Y=self.Y, **kwargs) + + @pytest.mark.parametrize( + "kwargs", + [ + {"samples": 50}, + {"labels": ["A", "B", "C"], "samples": 2, "figsize": (6, 6)}, + ], + ) + def test_vi(self, kwargs): + samples = kwargs.pop("samples") + vi_results = pmb.compute_variable_importance( + self.idata, bartrv=self.mu, X=self.X, samples=samples + ) + pmb.plot_variable_importance(vi_results, **kwargs) + pmb.plot_scatter_submodels(vi_results, **kwargs) + + def test_pdp_pandas_labels(self): + pd = pytest.importorskip("pandas") + + X_names = ["norm1", "norm2", "binom"] + X_pd = pd.DataFrame(self.X, columns=X_names) + Y_pd = pd.Series(self.Y, name="response") + axes = pmb.plot_pdp(self.mu, X=X_pd, Y=Y_pd) + + figure = axes[0].figure + assert figure.texts[0].get_text() == "Partial response" + assert_array_equal([ax.get_xlabel() for ax in axes], X_names) + + +def test_encoder_decoder(): + """Test that the encoder-decoder works correctly.""" + test_cases = [ + np.zeros(3, dtype=int), + np.ones(10, dtype=int), + np.array([4, 0, 1, 0, 2, 0, 3, 0, 0, 0]), + np.array([100, 50, 0, 1]), + np.array([1, 2, 4, 8, 16]), + ] + for case in test_cases: + encoded = pmb.utils._encode_vi(case) + decoded = pmb.utils._decode_vi(encoded, len(case)) + assert np.array_equal(decoded, case) From b584b23e32d7b71ebf659876f0318a83bf9db394 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Fri, 18 Jul 2025 19:05:42 +0300 Subject: [PATCH 46/53] bump release (#236) --- pymc_bart/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index cfa1648..b8cf0a6 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -42,7 +42,7 @@ "plot_variable_importance", "plot_variable_inclusion", ] -__version__ = "0.9.2" +__version__ = "0.10.0" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] From 941d258e679e2c57b0571dc953e6bea657706fdb Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 18 Jul 2025 16:07:09 +0000 Subject: [PATCH 47/53] update changelog for 0.10.0 changes This commit was created by changelog-from-release in 'Post-release' CI workflow --- CHANGELOG.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fc38b99..8dc2496 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,20 @@ + +# [0.10.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.10.0) - 2025-07-18 + +## What's Changed +* Use ArviZ-stats by [@aloctavodia](https://github.com/aloctavodia) in [#232](https://github.com/pymc-devs/pymc-bart/pull/232) +* Add support for multiple BART random variables per model. by [@derekpowell](https://github.com/derekpowell) in [#231](https://github.com/pymc-devs/pymc-bart/pull/231) +* encode vi and update to work with multiple RVs by [@aloctavodia](https://github.com/aloctavodia) in [#235](https://github.com/pymc-devs/pymc-bart/pull/235) + + +## New Contributors +* [@derekpowell](https://github.com/derekpowell) made their first contribution in [#231](https://github.com/pymc-devs/pymc-bart/pull/231) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.9.2...0.10.0 + +[Changes][0.10.0] + + # [0.9.2](https://github.com/pymc-devs/pymc-bart/releases/tag/0.9.2) - 2025-06-12 @@ -442,6 +459,7 @@ [Changes][0.0.3] +[0.10.0]: https://github.com/pymc-devs/pymc-bart/compare/0.9.2...0.10.0 [0.9.2]: https://github.com/pymc-devs/pymc-bart/compare/0.9.1...0.9.2 [0.9.1]: https://github.com/pymc-devs/pymc-bart/compare/0.9.0...0.9.1 [0.9.0]: https://github.com/pymc-devs/pymc-bart/compare/0.8.2...0.9.0 From 9710d7f4dcbf7c97fb7403dea1083b268fb85d51 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Wed, 27 Aug 2025 12:40:05 +0300 Subject: [PATCH 48/53] add function to export the vi results to kulprit (#237) --- pymc_bart/__init__.py | 2 ++ pymc_bart/utils.py | 18 ++++++++++++++++++ tests/test_utils.py | 4 ++++ 3 files changed, 24 insertions(+) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index b8cf0a6..cfc1afc 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -25,6 +25,7 @@ plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, + vi_to_kulprit, ) __all__ = [ @@ -41,6 +42,7 @@ "plot_scatter_submodels", "plot_variable_importance", "plot_variable_inclusion", + "vi_to_kulprit", ] __version__ = "0.10.0" diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 78ce920..cf804c5 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -1006,6 +1006,24 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 return vi_results +def vi_to_kulprit(vi_results: dict) -> list[list[str]]: + """ + Export variable importance results to Kulprit format. + + Parameters + ---------- + vi_results : dict + Dictionary computed with `compute_variable_importance` + + Returns + ------- + list[list[str]] + A list of lists containing variable names for each submodel. + """ + clean_labels = [label.strip("+ ") for label in vi_results["labels"]] + return [clean_labels[:idx] for idx in range(len(clean_labels))] + + def plot_variable_importance( vi_results: dict, submodels: Optional[Union[list[int], np.ndarray, tuple[int, ...]]] = None, diff --git a/tests/test_utils.py b/tests/test_utils.py index dbf3aca..ed85af7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -79,6 +79,10 @@ def test_vi(self, kwargs): pmb.plot_variable_importance(vi_results, **kwargs) pmb.plot_scatter_submodels(vi_results, **kwargs) + user_terms = pmb.vi_to_kulprit(vi_results) + assert len(user_terms) == 3 + assert all("+" not in term for terms in user_terms[1:] for term in terms) + def test_pdp_pandas_labels(self): pd = pytest.importorskip("pandas") From d87eb242fce910737d1084bbcd12f9ed342cec9c Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Wed, 27 Aug 2025 13:16:01 +0300 Subject: [PATCH 49/53] Fix formatting and update ruff linting rules --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4a2273d..2773123 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,8 +10,8 @@ line-length = 100 select = ["E", "F", "I", "PL", "UP", "W"] ignore = [ "PLR2004", # Checks for the use of unnamed numerical constants ("magic") values in comparisons. - "PLR0913", #Too many arguments in function definition - + "PLR0913", # Too many arguments in function definition + "PLC0415", # import should be at the top-level ] [tool.ruff.lint.pylint] From 5d0723cfed47c960841c546ee82cc89d154d8e72 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Wed, 27 Aug 2025 14:20:09 +0300 Subject: [PATCH 50/53] fix mypy issues (#238) s --- pymc_bart/utils.py | 93 +++++++++++++++++++++++----------------------- pyproject.toml | 1 - 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index cf804c5..dfb5eac 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -2,7 +2,8 @@ """Utility function for variable selection and bart interpretability.""" import warnings -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, TypeVar import matplotlib.pyplot as plt import numpy as np @@ -18,15 +19,15 @@ from .tree import Tree -TensorLike = Union[npt.NDArray, pt.TensorVariable] +TensorLike = TypeVar("TensorLike", npt.NDArray, pt.TensorVariable) def _sample_posterior( all_trees: list[list[Tree]], X: TensorLike, rng: np.random.Generator, - size: Optional[Union[int, tuple[int, ...]]] = None, - excluded: Optional[list[int]] = None, + size: int | tuple[int, ...] | None = None, + excluded: list[int] | None = None, shape: int = 1, ) -> npt.NDArray: """ @@ -51,7 +52,7 @@ def _sample_posterior( X = X.eval() if size is None: - size_iter: Union[list, tuple] = (1,) + size_iter: list | tuple = (1,) elif isinstance(size, int): size_iter = [size] else: @@ -78,9 +79,9 @@ def _sample_posterior( def plot_convergence( idata: Any, - var_name: Optional[str] = None, + var_name: str | None = None, kind: str = "ecdf", - figsize: Optional[tuple[float, float]] = None, + figsize: tuple[float, float] | None = None, ax=None, ) -> None: """ @@ -114,23 +115,23 @@ def plot_convergence( def plot_ice( bartrv: Variable, X: npt.NDArray, - Y: Optional[npt.NDArray] = None, - var_idx: Optional[list[int]] = None, - var_discrete: Optional[list[int]] = None, - func: Optional[Callable] = None, - centered: Optional[bool] = True, + Y: npt.NDArray | None = None, + var_idx: list[int] | None = None, + var_discrete: list[int] | None = None, + func: Callable | None = None, + centered: bool | None = True, samples: int = 100, instances: int = 30, - random_seed: Optional[int] = None, + random_seed: int | None = None, sharey: bool = True, smooth: bool = True, grid: str = "long", color="C0", color_mean: str = "C0", alpha: float = 0.1, - figsize: Optional[tuple[float, float]] = None, - smooth_kwargs: Optional[dict[str, Any]] = None, - ax: Optional[plt.Axes] = None, + figsize: tuple[float, float] | None = None, + smooth_kwargs: dict[str, Any] | None = None, + ax: plt.Axes | None = None, ) -> list[plt.Axes]: """ Individual conditional expectation plot. @@ -258,24 +259,24 @@ def identity(x): def plot_pdp( bartrv: Variable, X: npt.NDArray, - Y: Optional[npt.NDArray] = None, + Y: npt.NDArray | None = None, xs_interval: str = "quantiles", - xs_values: Optional[Union[int, list[float]]] = None, - var_idx: Optional[list[int]] = None, - var_discrete: Optional[list[int]] = None, - func: Optional[Callable] = None, + xs_values: int | list[float] | None = None, + var_idx: list[int] | None = None, + var_discrete: list[int] | None = None, + func: Callable | None = None, samples: int = 200, ref_line: bool = True, - random_seed: Optional[int] = None, + random_seed: int | None = None, sharey: bool = True, smooth: bool = True, grid: str = "long", color="C0", color_mean: str = "C0", alpha: float = 0.1, - figsize: Optional[tuple[float, float]] = None, - smooth_kwargs: Optional[dict[str, Any]] = None, - ax: Optional[plt.Axes] = None, + figsize: tuple[float, float] | None = None, + smooth_kwargs: dict[str, Any] | None = None, + ax: plt.Axes = None, ) -> list[plt.Axes]: """ Partial dependence plot. @@ -425,8 +426,8 @@ def _create_figure_axes( var_idx: list[int], grid: str = "long", sharey: bool = True, - figsize: Optional[tuple[float, float]] = None, - ax: Optional[plt.Axes] = None, + figsize: tuple[float, float] | None = None, + ax: plt.Axes | None = None, ) -> tuple[plt.Figure, list[plt.Axes], int]: """ Create and return the figure and axes objects for plotting the variables. @@ -506,11 +507,11 @@ def _get_axes(grid, n_plots, sharex, sharey, figsize): def _prepare_plot_data( X: npt.NDArray, - Y: Optional[npt.NDArray] = None, + Y: npt.NDArray | None = None, xs_interval: str = "quantiles", - xs_values: Optional[Union[int, list[float]]] = None, - var_idx: Optional[list[int]] = None, - var_discrete: Optional[list[int]] = None, + xs_values: int | list[float] | None = None, + var_idx: list[int] | None = None, + var_discrete: list[int] | None = None, ) -> tuple[ npt.NDArray, list[str], @@ -519,7 +520,7 @@ def _prepare_plot_data( list[int], list[int], str, - Union[int, None, list[float]], + int | None | list[float], ]: """ Prepare data for plotting. @@ -600,7 +601,7 @@ def _prepare_plot_data( def _create_pdp_data( X: npt.NDArray, xs_interval: str, - xs_values: Optional[Union[int, list[float]]] = None, + xs_values: int | list[float] | None = None, ) -> npt.NDArray: """ Create data for partial dependence plot. @@ -636,7 +637,7 @@ def _smooth_mean( new_x: npt.NDArray, p_di: npt.NDArray, kind: str = "neutral", - smooth_kwargs: Optional[dict[str, Any]] = None, + smooth_kwargs: dict[str, Any] | None = None, ) -> tuple[np.ndarray, np.ndarray]: """ Smooth the mean data for plotting. @@ -805,7 +806,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 fixed: int = 0, samples: int = 50, random_seed: int | None = None, -) -> dict[str, object]: +) -> dict[str, npt.NDArray]: """ Estimates variable importance from the BART-posterior. @@ -1026,11 +1027,11 @@ def vi_to_kulprit(vi_results: dict) -> list[list[str]]: def plot_variable_importance( vi_results: dict, - submodels: Optional[Union[list[int], np.ndarray, tuple[int, ...]]] = None, - labels: Optional[list[str]] = None, - figsize: Optional[tuple[float, float]] = None, - plot_kwargs: Optional[dict[str, Any]] = None, - ax: Optional[plt.Axes] = None, + submodels: list[int] | np.ndarray | tuple[int, ...] | None = None, + labels: list[str] | None = None, + figsize: tuple[float, float] | None = None, + plot_kwargs: dict[str, Any] | None = None, + ax: plt.Axes | None = None, ): """ Estimates variable importance from the BART-posterior. @@ -1128,13 +1129,13 @@ def plot_variable_importance( def plot_scatter_submodels( vi_results: dict, - func: Optional[Callable] = None, - submodels: Optional[Union[list[int], np.ndarray]] = None, + func: Callable | None = None, + submodels: list[int] | np.ndarray | None = None, grid: str = "long", - labels: Optional[list[str]] = None, - figsize: Optional[tuple[float, float]] = None, - plot_kwargs: Optional[dict[str, Any]] = None, - ax: Optional[plt.Axes] = None, + labels: list[str] | None = None, + figsize: tuple[float, float] | None = None, + plot_kwargs: dict[str, Any] | None = None, + ax: plt.Axes | None = None, ) -> list[plt.Axes]: """ Plot submodel's predictions against reference-model's predictions. diff --git a/pyproject.toml b/pyproject.toml index 2773123..2afa2c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,6 @@ pyupgrade = 1 [tool.mypy] files = "pymc_bart/*.py" -plugins = "numpy.typing.mypy_plugin" [tool.mypy-matplotlib] ignore_missing_imports = true From 6723677df6e4c7de236d2055bdd08f21198b21c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Aug 2025 15:01:03 +0300 Subject: [PATCH 51/53] [pre-commit.ci] pre-commit autoupdate (#229) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.11.11 → v0.12.10](https://github.com/astral-sh/ruff-pre-commit/compare/v0.11.11...v0.12.10) - [github.com/pre-commit/mirrors-mypy: v1.15.0 → v1.17.1](https://github.com/pre-commit/mirrors-mypy/compare/v1.15.0...v1.17.1) - [github.com/pre-commit/pre-commit-hooks: v5.0.0 → v6.0.0](https://github.com/pre-commit/pre-commit-hooks/compare/v5.0.0...v6.0.0) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7bd307d..71219b1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,21 +12,21 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.11 + rev: v0.12.10 hooks: - id: ruff args: ["--fix", "--output-format=full"] - id: ruff-format args: ["--line-length=100"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.15.0 + rev: v1.17.1 hooks: - id: mypy args: [--ignore-missing-imports] files: ^pymc_bart/ additional_dependencies: [numpy, pandas-stubs] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer From 5de4221767bc8c06d398edee78eb14cfdc0d440f Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Wed, 27 Aug 2025 15:09:03 +0300 Subject: [PATCH 52/53] Adjust precision in NormalSampler test assertions Updated decimal precision for mean and std assertions in NormalSampler tests. --- tests/test_pgbart.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_pgbart.py b/tests/test_pgbart.py index 4cf4188..5a1d35e 100644 --- a/tests/test_pgbart.py +++ b/tests/test_pgbart.py @@ -74,13 +74,13 @@ def test_discrete_uniform(): def test_normal_sampler(): normal = NormalSampler(2, shape=1) samples = np.array([normal.rvs() for i in range(100000)]) - np.testing.assert_almost_equal(samples.mean(), 0, decimal=2) - np.testing.assert_almost_equal(samples.std(), 2, decimal=2) + np.testing.assert_almost_equal(samples.mean(), 0, decimal=1) + np.testing.assert_almost_equal(samples.std(), 2, decimal=1) normal = NormalSampler(2, shape=2) samples = np.array([normal.rvs() for i in range(100000)]) - np.testing.assert_almost_equal(samples.mean(0), [0, 0], decimal=2) - np.testing.assert_almost_equal(samples.std(0), [2, 2], decimal=2) + np.testing.assert_almost_equal(samples.mean(0), [0, 0], decimal=1) + np.testing.assert_almost_equal(samples.std(0), [2, 2], decimal=1) def test_uniform_sampler(): From caa79a8ed5a32993feceb9cb172f831d49acaa5a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Sep 2025 21:04:28 +0300 Subject: [PATCH 53/53] [pre-commit.ci] pre-commit autoupdate (#239) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.12.10 → v0.12.12](https://github.com/astral-sh/ruff-pre-commit/compare/v0.12.10...v0.12.12) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 71219b1..cc5a6be 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.10 + rev: v0.12.12 hooks: - id: ruff args: ["--fix", "--output-format=full"]