diff --git a/.github/workflows/post-release.yml b/.github/workflows/post-release.yml new file mode 100644 index 0000000..5526a27 --- /dev/null +++ b/.github/workflows/post-release.yml @@ -0,0 +1,19 @@ +name: Post-release +on: + release: + types: [published, released] + workflow_dispatch: + +jobs: + changelog: + name: Update changelog + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: main + - uses: rhysd/changelog-from-release/action@v3 + with: + file: CHANGELOG.md + github_token: ${{ secrets.GITHUB_TOKEN }} + commit_summary_template: 'update changelog for %s changes' diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8817d27..3fe0779 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10", "3.11"] + python-version: ["3.11", "3.12", "3.13"] name: Set up Python ${{ matrix.python-version }} steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1a21dcf..cc5a6be 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,21 +12,21 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.3 + rev: v0.12.12 hooks: - id: ruff args: ["--fix", "--output-format=full"] - id: ruff-format args: ["--line-length=100"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.17.1 hooks: - id: mypy args: [--ignore-missing-imports] files: ^pymc_bart/ additional_dependencies: [numpy, pandas-stubs] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v6.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 6e5cef0..691fce7 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,11 +1,14 @@ # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details version: 2 +sphinx: + # Path to your Sphinx configuration file. + configuration: docs/conf.py build: - os: ubuntu-20.04 + os: ubuntu-24.04 tools: - python: "3.10" + python: "3.12" python: install: @@ -13,3 +16,15 @@ python: - requirements: requirements.txt - method: pip path: . + +search: + ranking: + _sources/*: -10 + _modules/*: -5 + genindex.html: -9 + + ignore: + - 404.html + - search.html + - index.html + - 'examples/*' diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..8dc2496 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,496 @@ + +# [0.10.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.10.0) - 2025-07-18 + +## What's Changed +* Use ArviZ-stats by [@aloctavodia](https://github.com/aloctavodia) in [#232](https://github.com/pymc-devs/pymc-bart/pull/232) +* Add support for multiple BART random variables per model. by [@derekpowell](https://github.com/derekpowell) in [#231](https://github.com/pymc-devs/pymc-bart/pull/231) +* encode vi and update to work with multiple RVs by [@aloctavodia](https://github.com/aloctavodia) in [#235](https://github.com/pymc-devs/pymc-bart/pull/235) + + +## New Contributors +* [@derekpowell](https://github.com/derekpowell) made their first contribution in [#231](https://github.com/pymc-devs/pymc-bart/pull/231) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.9.2...0.10.0 + +[Changes][0.10.0] + + + +# [0.9.2](https://github.com/pymc-devs/pymc-bart/releases/tag/0.9.2) - 2025-06-12 + +## What's Changed +* Update requirements.txt by [@juanitorduz](https://github.com/juanitorduz) in [#230](https://github.com/pymc-devs/pymc-bart/pull/230) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.9.1...0.9.2 + +[Changes][0.9.2] + + + +# [0.9.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.9.1) - 2025-05-19 + +## What's Changed +* misc doc improvements and theme update by [@OriolAbril](https://github.com/OriolAbril) in [#225](https://github.com/pymc-devs/pymc-bart/pull/225) +* Use last pymc version by [@aloctavodia](https://github.com/aloctavodia) in [#227](https://github.com/pymc-devs/pymc-bart/pull/227) + +## New Contributors +* [@OriolAbril](https://github.com/OriolAbril) made their first contribution in [#225](https://github.com/pymc-devs/pymc-bart/pull/225) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.9.0...0.9.1 + +[Changes][0.9.1] + + + +# [0.9.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.9.0) - 2025-03-10 + +## What's Changed +* Update MyPy 14 by [@juanitorduz](https://github.com/juanitorduz) in [#210](https://github.com/pymc-devs/pymc-bart/pull/210) +* Automatic Changelog by [@aloctavodia](https://github.com/aloctavodia) in [#213](https://github.com/pymc-devs/pymc-bart/pull/213) +* Adds get_variable_inclusion function by [@aloctavodia](https://github.com/aloctavodia) in [#214](https://github.com/pymc-devs/pymc-bart/pull/214) +* Refactor rng_fn method by [@aloctavodia](https://github.com/aloctavodia) in [#212](https://github.com/pymc-devs/pymc-bart/pull/212) +* Fix docs by adding path of config by [@juanitorduz](https://github.com/juanitorduz) in [#217](https://github.com/pymc-devs/pymc-bart/pull/217) +* Enhance `plot_pdp` and fix `plot_scatter_submodels` by [@AlexAndorra](https://github.com/AlexAndorra) in [#218](https://github.com/pymc-devs/pymc-bart/pull/218) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.8.2...0.9.0 + +[Changes][0.9.0] + + + +# [0.8.2](https://github.com/pymc-devs/pymc-bart/releases/tag/0.8.2) - 2024-12-23 + +## What's Changed +* Compute_variable_importance: fix bug with non-default shapes by [@aloctavodia](https://github.com/aloctavodia) in [#208](https://github.com/pymc-devs/pymc-bart/pull/208) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.8.1...0.8.2 + +[Changes][0.8.2] + + + +# [0.8.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.8.1) - 2024-12-20 + +## What's Changed +* Patch for case when Y is a TensorVariable by [@AlexAndorra](https://github.com/AlexAndorra) in [#206](https://github.com/pymc-devs/pymc-bart/pull/206) +* Fix bug with labels in variable importance, add reference line, remove deprecation warning by [@aloctavodia](https://github.com/aloctavodia) in [#207](https://github.com/pymc-devs/pymc-bart/pull/207) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.8.0...0.8.1 + +[Changes][0.8.1] + + + +# [0.8.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.8.0) - 2024-12-17 + +## What's Changed + +* Add new vi plots by [@aloctavodia](https://github.com/aloctavodia) in [#196](https://github.com/pymc-devs/pymc-bart/pull/196) +* Allows plotting a subset of the variables once the variable's importance has been computed by [@aloctavodia](https://github.com/aloctavodia) in [#200](https://github.com/pymc-devs/pymc-bart/pull/200) +* Enable passing `Y` as a `SharedVariable` to `pm.Bart` by [@AlexAndorra](https://github.com/AlexAndorra) in [#202](https://github.com/pymc-devs/pymc-bart/pull/202) +* Improve docs, aesthetics and functionality by [@aloctavodia](https://github.com/aloctavodia) in [#198](https://github.com/pymc-devs/pymc-bart/pull/198) + + +## New Contributors +* [@AlexAndorra](https://github.com/AlexAndorra) made their first contribution in [#202](https://github.com/pymc-devs/pymc-bart/pull/202) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.7.1...0.8.0 + +[Changes][0.8.0] + + + +# [0.7.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.7.1) - 2024-11-07 + +## What's Changed +* Conform to recent changes in pymc by [@aloctavodia](https://github.com/aloctavodia) in [#194](https://github.com/pymc-devs/pymc-bart/pull/194) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.7.0...0.7.1 + +[Changes][0.7.1] + + + +# [0.7.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.7.0) - 2024-09-05 + +## What's Changed +* Allow Y to be a tensor by [@aloctavodia](https://github.com/aloctavodia) in [#180](https://github.com/pymc-devs/pymc-bart/pull/180) +* improve plot_variable_importance by [@aloctavodia](https://github.com/aloctavodia) in [#182](https://github.com/pymc-devs/pymc-bart/pull/182) +* move x_angle to plot_kwargs by [@aloctavodia](https://github.com/aloctavodia) in [#185](https://github.com/pymc-devs/pymc-bart/pull/185) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.6.0...0.7.0 + +[Changes][0.7.0] + + + +# [0.6.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.6.0) - 2024-08-16 + +## What's Changed +* Add categorical example by [@PabloGGaray](https://github.com/PabloGGaray) in [#167](https://github.com/pymc-devs/pymc-bart/pull/167) +* Fix np.float_ type by [@juanitorduz](https://github.com/juanitorduz) in [#171](https://github.com/pymc-devs/pymc-bart/pull/171) +* Support Polars by [@aloctavodia](https://github.com/aloctavodia) in [#179](https://github.com/pymc-devs/pymc-bart/pull/179) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.14...0.6.0 + +[Changes][0.6.0] + + + +# [0.5.14](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.14) - 2024-05-14 + +## What's Changed +* Less than equal PyMC Version by [@juanitorduz](https://github.com/juanitorduz) in [#164](https://github.com/pymc-devs/pymc-bart/pull/164) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.13...0.5.14 + +[Changes][0.5.14] + + + +# [0.5.13](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.13) - 2024-05-13 + +## What's Changed +* Update pymc version requirements.txt by [@juanitorduz](https://github.com/juanitorduz) in [#163](https://github.com/pymc-devs/pymc-bart/pull/163) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.12...0.5.13 + +[Changes][0.5.13] + + + +# [0.5.12](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.12) - 2024-04-18 + +## What's Changed +* Unpin numpy by [@maresb](https://github.com/maresb) in [#156](https://github.com/pymc-devs/pymc-bart/pull/156) +* Resolve deprecation warning for `pytensor`'s `Variable` by [@RyanAugust](https://github.com/RyanAugust) in [#159](https://github.com/pymc-devs/pymc-bart/pull/159) + +## New Contributors +* [@RyanAugust](https://github.com/RyanAugust) made their first contribution in [#159](https://github.com/pymc-devs/pymc-bart/pull/159) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.11...0.5.12 + +[Changes][0.5.12] + + + +# [0.5.11](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.11) - 2024-03-15 + +## What's Changed +* Add citation file by [@PabloGGaray](https://github.com/PabloGGaray) in [#151](https://github.com/pymc-devs/pymc-bart/pull/151) +* Rename moment to support_point by [@PabloGGaray](https://github.com/PabloGGaray) in [#154](https://github.com/pymc-devs/pymc-bart/pull/154) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.10...0.5.11 + +[Changes][0.5.11] + + + +# [0.5.10](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.10) - 2024-03-14 + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.8...0.5.10 + +[Changes][0.5.10] + + + +# [0.5.9](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.9) - 2024-03-14 + +## What's Changed +* Ruff linter + pre-commit integration by [@juanitorduz](https://github.com/juanitorduz) in [#140](https://github.com/pymc-devs/pymc-bart/pull/140) +* Improve CONTRIBUTING guidelines by [@juanitorduz](https://github.com/juanitorduz) in [#141](https://github.com/pymc-devs/pymc-bart/pull/141) +* Add Usage and Table of Contents, to the README file, enhance Installation section, and fix top header by [@NicholasLindner](https://github.com/NicholasLindner) in [#143](https://github.com/pymc-devs/pymc-bart/pull/143) + +## New Contributors +* [@NicholasLindner](https://github.com/NicholasLindner) made their first contribution in [#143](https://github.com/pymc-devs/pymc-bart/pull/143) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.7...0.5.9 + +[Changes][0.5.9] + + + +# [0.5.8](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.8) - 2024-03-14 + +## What's Changed +* Ruff linter + pre-commit integration by [@juanitorduz](https://github.com/juanitorduz) in [#140](https://github.com/pymc-devs/pymc-bart/pull/140) +* Improve CONTRIBUTING guidelines by [@juanitorduz](https://github.com/juanitorduz) in [#141](https://github.com/pymc-devs/pymc-bart/pull/141) +* Add Usage and Table of Contents, to the README file, enhance Installation section, and fix top header by [@NicholasLindner](https://github.com/NicholasLindner) in [#143](https://github.com/pymc-devs/pymc-bart/pull/143) + + +## New Contributors +* [@NicholasLindner](https://github.com/NicholasLindner) made their first contribution in [#143](https://github.com/pymc-devs/pymc-bart/pull/143) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.7...0.5.8 + +[Changes][0.5.8] + + + +# [0.5.7](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.7) - 2023-12-29 + +## What's Changed +* Properly handle nans when jittering by [@aloctavodia](https://github.com/aloctavodia) in [#136](https://github.com/pymc-devs/pymc-bart/pull/136) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.6...0.5.7 + +[Changes][0.5.7] + + + +# [0.5.6](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.6) - 2023-12-23 + +## What's Changed +* Fix bug in plot_ice, and clean docstring of plot_ice and plot_pdp by [@aloctavodia](https://github.com/aloctavodia) in [#135](https://github.com/pymc-devs/pymc-bart/pull/135) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.5...0.5.6 + +[Changes][0.5.6] + + + +# [0.5.5](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.5) - 2023-12-22 + +## What's Changed +* add jitter to duplicated values for continuous splitting rule by [@aloctavodia](https://github.com/aloctavodia) in [#129](https://github.com/pymc-devs/pymc-bart/pull/129) +* link GitHub icon to pymc-bart repo by [@aloctavodia](https://github.com/aloctavodia) in [#131](https://github.com/pymc-devs/pymc-bart/pull/131) +* VI remove unnecessary evaluations for the backward method by [@aloctavodia](https://github.com/aloctavodia) in [#132](https://github.com/pymc-devs/pymc-bart/pull/132) +* jitter only arrays of whole numbers by [@aloctavodia](https://github.com/aloctavodia) in [#133](https://github.com/pymc-devs/pymc-bart/pull/133) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.3...0.5.5 + +[Changes][0.5.5] + + + +# [0.5.4](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.4) - 2023-11-21 + +## What's Changed +* add jitter to duplicated values for continuous splitting rule by [@aloctavodia](https://github.com/aloctavodia) in [#129](https://github.com/pymc-devs/pymc-bart/pull/129) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.3...0.5.4 + +[Changes][0.5.4] + + + +# [0.5.3](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.3) - 2023-11-18 + +## What's Changed +* improve variable importance computation by adding backward method by [@aloctavodia](https://github.com/aloctavodia) in [#125](https://github.com/pymc-devs/pymc-bart/pull/125) +* set new paths to notebooks by [@aloctavodia](https://github.com/aloctavodia) in [#126](https://github.com/pymc-devs/pymc-bart/pull/126) +* fix case examples by [@aloctavodia](https://github.com/aloctavodia) in [#127](https://github.com/pymc-devs/pymc-bart/pull/127) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.2...0.5.3 + +[Changes][0.5.3] + + + +# [0.5.2](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.2) - 2023-10-27 + +## What's Changed +* Minor doctrings and types improvements by [@juanitorduz](https://github.com/juanitorduz) in [#108](https://github.com/pymc-devs/pymc-bart/pull/108) +* Fix ICE plot when there is a discrete variable by [@juanitorduz](https://github.com/juanitorduz) in [#107](https://github.com/pymc-devs/pymc-bart/pull/107) +* Add support python 3.11 by [@juanitorduz](https://github.com/juanitorduz) in [#109](https://github.com/pymc-devs/pymc-bart/pull/109) +* Add issue templates by [@PabloGGaray](https://github.com/PabloGGaray) in [#113](https://github.com/pymc-devs/pymc-bart/pull/113) +* Add conda option by [@PabloGGaray](https://github.com/PabloGGaray) in [#114](https://github.com/pymc-devs/pymc-bart/pull/114) +* fix split_prior bug by [@aloctavodia](https://github.com/aloctavodia) in [#115](https://github.com/pymc-devs/pymc-bart/pull/115) +* Add logo by [@aloctavodia](https://github.com/aloctavodia) in [#116](https://github.com/pymc-devs/pymc-bart/pull/116) +* clean logo by [@aloctavodia](https://github.com/aloctavodia) in [#117](https://github.com/pymc-devs/pymc-bart/pull/117) +* Add plot_ice to API description on the webpage by [@PabloGGaray](https://github.com/PabloGGaray) in [#119](https://github.com/pymc-devs/pymc-bart/pull/119) +* Better handling of discrete variables and other minor fixes by [@aloctavodia](https://github.com/aloctavodia) in [#121](https://github.com/pymc-devs/pymc-bart/pull/121) + + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.0...0.5.2 + +[Changes][0.5.2] + + + +# [O.5.1](https://github.com/pymc-devs/pymc-bart/releases/tag/O.5.1) - 2023-07-12 + +## What's Changed +* Minor doctrings and types improvements by [@juanitorduz](https://github.com/juanitorduz) in [#108](https://github.com/pymc-devs/pymc-bart/pull/108) +* Fix ICE plot when there is a discrete variable by [@juanitorduz](https://github.com/juanitorduz) in [#107](https://github.com/pymc-devs/pymc-bart/pull/107) +* Add support python 3.11 by [@juanitorduz](https://github.com/juanitorduz) in [#109](https://github.com/pymc-devs/pymc-bart/pull/109) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.0...O.5.1 + +[Changes][O.5.1] + + + +# [0.5.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.0) - 2023-07-10 + +## What's Changed +* Add pre-commit hooks by [@juanitorduz](https://github.com/juanitorduz) in [#75](https://github.com/pymc-devs/pymc-bart/pull/75) +* Add mypy init by [@juanitorduz](https://github.com/juanitorduz) in [#78](https://github.com/pymc-devs/pymc-bart/pull/78) +* Do not store index at each node. by [@howsiyu](https://github.com/howsiyu) in [#80](https://github.com/pymc-devs/pymc-bart/pull/80) +* Add linear response [@juanitorduz](https://github.com/juanitorduz) in [#79](https://github.com/pymc-devs/pymc-bart/pull/79) +* Do weighted mean when pruning by [@aloctavodia](https://github.com/aloctavodia) in [#83](https://github.com/pymc-devs/pymc-bart/pull/83) +* Implement fast version of pdp by [@aloctavodia](https://github.com/aloctavodia) in [#85](https://github.com/pymc-devs/pymc-bart/pull/85) +* Add error bars to variable importance by [@aloctavodia](https://github.com/aloctavodia) in [#90](https://github.com/pymc-devs/pymc-bart/pull/90) +* Compute running variance for leaf nodes by [@aloctavodia](https://github.com/aloctavodia) in [#91](https://github.com/pymc-devs/pymc-bart/pull/91) +* Improve doc style and add missing examples by [@aloctavodia](https://github.com/aloctavodia) in [#92](https://github.com/pymc-devs/pymc-bart/pull/92) +* Make the Repo more welcoming with a clear title by [@juanitorduz](https://github.com/juanitorduz) in [#94](https://github.com/pymc-devs/pymc-bart/pull/94) +* Improve docstrings new alpha and beta parameters by [@juanitorduz](https://github.com/juanitorduz) in [#95](https://github.com/pymc-devs/pymc-bart/pull/95) +* Allow different splitting rules by [@velochy](https://github.com/velochy) in [#96](https://github.com/pymc-devs/pymc-bart/pull/96) +* Allow training separate tree structures if training multiple trees by [@velochy](https://github.com/velochy) in [#98](https://github.com/pymc-devs/pymc-bart/pull/98) + +## New Contributors +* [@howsiyu](https://github.com/howsiyu) made their first contribution in [#80](https://github.com/pymc-devs/pymc-bart/pull/80) +* [@velochy](https://github.com/velochy) made their first contribution in [#96](https://github.com/pymc-devs/pymc-bart/pull/96) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.4.0...0.5.0 + +[Changes][0.5.0] + + + +# [0.4.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.4.0) - 2023-03-17 + +## What's Changed +* fig bug systematic resampling and add func argument by [@aloctavodia](https://github.com/aloctavodia) in [#61](https://github.com/pymc-devs/pymc-bart/pull/61) and [#66](https://github.com/pymc-devs/pymc-bart/pull/66) +* add tests for individual functions/methods in PGBART by [@aloctavodia](https://github.com/aloctavodia) in [#64](https://github.com/pymc-devs/pymc-bart/pull/64) +* Modify resampling schema and refactor by [@aloctavodia](https://github.com/aloctavodia) in [#65](https://github.com/pymc-devs/pymc-bart/pull/65) +* add plot_convergence by [@aloctavodia](https://github.com/aloctavodia) in [#67](https://github.com/pymc-devs/pymc-bart/pull/67) and [@aloctavodia](https://github.com/aloctavodia) in [#68](https://github.com/pymc-devs/pymc-bart/pull/68) +* Improve plot_dependence by [@PabloGGaray](https://github.com/PabloGGaray) in [#70](https://github.com/pymc-devs/pymc-bart/pull/70) and [@aloctavodia](https://github.com/aloctavodia) in [#71](https://github.com/pymc-devs/pymc-bart/pull/71) and in [#73](https://github.com/pymc-devs/pymc-bart/pull/73) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.3.2...0.4.0 + +[Changes][0.4.0] + + + +# [0.3.2](https://github.com/pymc-devs/pymc-bart/releases/tag/0.3.2) - 2023-02-03 + +## What's Changed +* Refactor and [@njit](https://github.com/njit) on methods by [@fjloyola](https://github.com/fjloyola) in [#54](https://github.com/pymc-devs/pymc-bart/pull/54) +* Fix shape error [@aloctavodia](https://github.com/aloctavodia) in [#57](https://github.com/pymc-devs/pymc-bart/pull/57) and [#59](https://github.com/pymc-devs/pymc-bart/pull/59) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.3.1...0.3.2 + +[Changes][0.3.2] + + + +# [0.3.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.3.1) - 2023-01-26 + +## What's Changed +* Fix Url pymc-bart on documentation by [@fjloyola](https://github.com/fjloyola) in [#34](https://github.com/pymc-devs/pymc-bart/pull/34) +* Fixing issue ThemeError for read the docs by [@fjloyola](https://github.com/fjloyola) in [#37](https://github.com/pymc-devs/pymc-bart/pull/37) +* Refactor to avoid inheritance in BaseNode by [@fjloyola](https://github.com/fjloyola) in [#35](https://github.com/pymc-devs/pymc-bart/pull/35) +* Add link to license by [@PabloGGaray](https://github.com/PabloGGaray) in [#39](https://github.com/pymc-devs/pymc-bart/pull/39) +* Improvements over Tree implementation by [@fjloyola](https://github.com/fjloyola) in [#40](https://github.com/pymc-devs/pymc-bart/pull/40) +* fix import error from pymc 5.0.2 by [@juanitorduz](https://github.com/juanitorduz) in [#43](https://github.com/pymc-devs/pymc-bart/pull/43) +* Update pymc minimum version by [@aloctavodia](https://github.com/aloctavodia) in [#45](https://github.com/pymc-devs/pymc-bart/pull/45) +* Avoid Deepcopy on Tree and ParticleTree by [@fjloyola](https://github.com/fjloyola) in [#47](https://github.com/pymc-devs/pymc-bart/pull/47) + +## New Contributors +* [@fjloyola](https://github.com/fjloyola) made their first contribution in [#34](https://github.com/pymc-devs/pymc-bart/pull/34) +* [@juanitorduz](https://github.com/juanitorduz) made their first contribution in [#43](https://github.com/pymc-devs/pymc-bart/pull/43) + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.3.0...0.3.1 + +[Changes][0.3.1] + + + +# [0.3.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.3.0) - 2022-12-22 + +## What's Changed +* Update README with conda installation by [@maresb](https://github.com/maresb) in [#26](https://github.com/pymc-devs/pymc-bart/pull/26) +* Fix broken URL by [@maresb](https://github.com/maresb) in [#27](https://github.com/pymc-devs/pymc-bart/pull/27) +* Update to PyMC 5 and PyTensor by [@aloctavodia](https://github.com/aloctavodia) in [#29](https://github.com/pymc-devs/pymc-bart/pull/29) + + +**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.2.1...0.3.0 + +[Changes][0.3.0] + + + +# [0.2.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.2.1) - 2022-11-07 + + + +[Changes][0.2.1] + + + +# [0.2.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.2.0) - 2022-11-03 + + + +[Changes][0.2.0] + + + +# [0.1.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.1.0) - 2022-10-26 + + + +[Changes][0.1.0] + + + +# [0.0.3](https://github.com/pymc-devs/pymc-bart/releases/tag/0.0.3) - 2022-09-13 + + + +[Changes][0.0.3] + + +[0.10.0]: https://github.com/pymc-devs/pymc-bart/compare/0.9.2...0.10.0 +[0.9.2]: https://github.com/pymc-devs/pymc-bart/compare/0.9.1...0.9.2 +[0.9.1]: https://github.com/pymc-devs/pymc-bart/compare/0.9.0...0.9.1 +[0.9.0]: https://github.com/pymc-devs/pymc-bart/compare/0.8.2...0.9.0 +[0.8.2]: https://github.com/pymc-devs/pymc-bart/compare/0.8.1...0.8.2 +[0.8.1]: https://github.com/pymc-devs/pymc-bart/compare/0.8.0...0.8.1 +[0.8.0]: https://github.com/pymc-devs/pymc-bart/compare/0.7.1...0.8.0 +[0.7.1]: https://github.com/pymc-devs/pymc-bart/compare/0.7.0...0.7.1 +[0.7.0]: https://github.com/pymc-devs/pymc-bart/compare/0.6.0...0.7.0 +[0.6.0]: https://github.com/pymc-devs/pymc-bart/compare/0.5.14...0.6.0 +[0.5.14]: https://github.com/pymc-devs/pymc-bart/compare/0.5.13...0.5.14 +[0.5.13]: https://github.com/pymc-devs/pymc-bart/compare/0.5.12...0.5.13 +[0.5.12]: https://github.com/pymc-devs/pymc-bart/compare/0.5.11...0.5.12 +[0.5.11]: https://github.com/pymc-devs/pymc-bart/compare/0.5.10...0.5.11 +[0.5.10]: https://github.com/pymc-devs/pymc-bart/compare/0.5.9...0.5.10 +[0.5.9]: https://github.com/pymc-devs/pymc-bart/compare/0.5.8...0.5.9 +[0.5.8]: https://github.com/pymc-devs/pymc-bart/compare/0.5.7...0.5.8 +[0.5.7]: https://github.com/pymc-devs/pymc-bart/compare/0.5.6...0.5.7 +[0.5.6]: https://github.com/pymc-devs/pymc-bart/compare/0.5.5...0.5.6 +[0.5.5]: https://github.com/pymc-devs/pymc-bart/compare/0.5.4...0.5.5 +[0.5.4]: https://github.com/pymc-devs/pymc-bart/compare/0.5.3...0.5.4 +[0.5.3]: https://github.com/pymc-devs/pymc-bart/compare/0.5.2...0.5.3 +[0.5.2]: https://github.com/pymc-devs/pymc-bart/compare/O.5.1...0.5.2 +[O.5.1]: https://github.com/pymc-devs/pymc-bart/compare/0.5.0...O.5.1 +[0.5.0]: https://github.com/pymc-devs/pymc-bart/compare/0.4.0...0.5.0 +[0.4.0]: https://github.com/pymc-devs/pymc-bart/compare/0.3.2...0.4.0 +[0.3.2]: https://github.com/pymc-devs/pymc-bart/compare/0.3.1...0.3.2 +[0.3.1]: https://github.com/pymc-devs/pymc-bart/compare/0.3.0...0.3.1 +[0.3.0]: https://github.com/pymc-devs/pymc-bart/compare/0.2.1...0.3.0 +[0.2.1]: https://github.com/pymc-devs/pymc-bart/compare/0.2.0...0.2.1 +[0.2.0]: https://github.com/pymc-devs/pymc-bart/compare/0.1.0...0.2.0 +[0.1.0]: https://github.com/pymc-devs/pymc-bart/compare/0.0.3...0.1.0 +[0.0.3]: https://github.com/pymc-devs/pymc-bart/tree/0.0.3 + + diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 93afde1..88b910c 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -13,4 +13,4 @@ methods in the current release of PyMC-BART. ============================= .. automodule:: pymc_bart - :members: BART, PGBART, plot_pdp, plot_ice, plot_variable_importance, plot_convergence, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule + :members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule diff --git a/docs/changelog.rst b/docs/changelog.rst new file mode 100644 index 0000000..f83d445 --- /dev/null +++ b/docs/changelog.rst @@ -0,0 +1,5 @@ +Changelog +********* + +.. include:: ../CHANGELOG.md + :parser: myst_parser.sphinx_ diff --git a/docs/conf.py b/docs/conf.py index ba89cb1..8945cef 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -21,7 +21,6 @@ "sphinx_design", "sphinxcontrib.bibtex", "sphinx_codeautolink", - "sphinx_remove_toctrees", ] # List of patterns, relative to source directory, that match files and @@ -73,6 +72,7 @@ html_theme = "pymc_sphinx_theme" html_theme_options = { "secondary_sidebar_items": ["page-toc", "edit-this-page", "sourcelink", "donate"], + "search_bar_text": "Search within PyMC-BART...", "navbar_start": ["navbar-logo"], "icon_links": [ { @@ -80,17 +80,6 @@ "icon": "fa-brands fa-github", "name": "GitHub", }, - { - "url": "https://twitter.com/pymc_devs/", - "icon": "fa-brands fa-twitter", - "name": "Twitter", - }, - { - "url": "https://www.youtube.com/c/PyMCDevelopers", - "icon": "fa-brands fa-youtube", - "name": "YouTube", - }, - {"url": "https://discourse.pymc.io", "icon": "fa-brands fa-discourse", "name": "Discourse"}, ], } @@ -144,23 +133,6 @@ nb_execution_mode = "off" -remove_from_toctrees = [ - "BART/*", - "case_studies/*", - "causal_inference/*", - "diagnostics_and_criticism/*", - "gaussian_processes/*", - "generalized_linear_models/*", - "mixture_models/*", - "ode_models/*", - "howto/*", - "samplers/*", - "splines/*", - "survival_analysis/*", - "time_series/*", - "variational_inference/*", -] - # bibtex config bibtex_bibfiles = ["references.bib"] bibtex_default_style = "unsrt" diff --git a/docs/index.rst b/docs/index.rst index 4b1dd0e..e390c3c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -29,7 +29,7 @@ interpretation of those models and perform variable selection. Installation ============ -PyMC-BART requires a working Python interpreter (3.8+). We recommend installing Python and key numerical libraries using the `Anaconda distribution `_, which has one-click installers available on all major platforms. +PyMC-BART requires a working Python interpreter (3.11+). We recommend installing Python and key numerical libraries using the `Anaconda distribution `_, which has one-click installers available on all major platforms. Assuming a standard Python environment is installed on your machine, PyMC-BART itself can be installed either using pip or conda-forge. @@ -93,10 +93,12 @@ Contents :maxdepth: 2 examples - api_reference -Indices -======= +References +========== + +.. toctree:: + :maxdepth: 1 -* :ref:`genindex` -* :ref:`modindex` + api_reference + changelog diff --git a/env-dev.yml b/env-dev.yml new file mode 100644 index 0000000..375558b --- /dev/null +++ b/env-dev.yml @@ -0,0 +1,23 @@ +name: pymc-bart-dev +channels: + - conda-forge + - defaults +dependencies: + - pymc>=5.24.0 + - numba + - matplotlib + - numpy + - pytensor + # Development dependencies + - pytest>=4.4.0 + - pytest-cov>=2.6.1 + - click==8.0.4 + - pylint==2.17.4 + - pre-commit + - black + - isort + - flake8 + - pip + - pip: + - arviz-stats[xarray]>=0.6.0 + - -e . diff --git a/env.yml b/env.yml new file mode 100644 index 0000000..3afdd9f --- /dev/null +++ b/env.yml @@ -0,0 +1,14 @@ +name: pymc-bart +channels: + - conda-forge + - defaults +dependencies: + - pymc>=5.24.0 + - numba + - matplotlib + - numpy + - pytensor + - pip + - pip: + - pymc-bart + - arviz-stats[xarray]>=0.6.0 diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 56088d7..0000000 --- a/mypy.ini +++ /dev/null @@ -1,15 +0,0 @@ -[mypy] -files = pymc_bart/*.py -plugins = numpy.typing.mypy_plugin - -[mypy-matplotlib.*] -ignore_missing_imports = True - -[mypy-numba.*] -ignore_missing_imports = True - -[mypy-pymc.*] -ignore_missing_imports = True - -[mypy-scipy.*] -ignore_missing_imports = True diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index c10b8f8..cfc1afc 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -17,11 +17,15 @@ from pymc_bart.pgbart import PGBART from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule from pymc_bart.utils import ( + compute_variable_importance, + get_variable_inclusion, plot_convergence, - plot_dependence, plot_ice, plot_pdp, + plot_scatter_submodels, plot_variable_importance, + plot_variable_inclusion, + vi_to_kulprit, ) __all__ = [ @@ -30,13 +34,17 @@ "ContinuousSplitRule", "OneHotSplitRule", "SubsetSplitRule", + "compute_variable_importance", + "get_variable_inclusion", "plot_convergence", - "plot_dependence", "plot_ice", "plot_pdp", + "plot_scatter_submodels", "plot_variable_importance", + "plot_variable_inclusion", + "vi_to_kulprit", ] -__version__ = "0.7.0" +__version__ = "0.10.0" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 969baf4..233d33e 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -16,7 +16,7 @@ import warnings from multiprocessing import Manager -from typing import List, Optional, Tuple +from typing import Optional import numpy as np import numpy.typing as npt @@ -25,9 +25,10 @@ from pymc.distributions.distribution import Distribution, _support_point from pymc.logprob.abstract import _logprob from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.sharedvar import TensorSharedVariable +from pytensor.tensor.variable import TensorVariable from .split_rules import SplitRule -from .tree import Tree from .utils import TensorLike, _sample_posterior __all__ = ["BART"] @@ -37,24 +38,31 @@ class BARTRV(RandomVariable): """Base class for BART.""" name: str = "BART" - ndim_supp = 1 - ndims_params: List[int] = [2, 1, 0, 0, 0, 1] + signature = "(m,n),(m),(),(),() -> (m)" dtype: str = "floatX" - _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}") - all_trees = List[List[List[Tree]]] + _print_name: tuple[str, str] = ("BART", "\\operatorname{BART}") def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed - return dist_params[0].shape[:1] + idx = dist_params[0].ndim - 2 + return [dist_params[0].shape[idx]] @classmethod def rng_fn( # pylint: disable=W0237 - cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None + cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, size=None ): - if not cls.all_trees: + if not size: + size = None + + if not hasattr(cls, "all_trees") or not cls.all_trees: + if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)): + Y = cls.Y.eval() + else: + Y = cls.Y + if size is not None: - return np.full((size[0], cls.Y.shape[0]), cls.Y.mean()) + return np.full((size[0], Y.shape[0]), Y.mean()) else: - return np.full(cls.Y.shape[0], cls.Y.mean()) + return np.full(Y.shape[0], Y.mean()) else: if size is not None: shape = size[0] @@ -89,16 +97,13 @@ class BART(Distribution): beta : float Controls the prior probability over the number of leaves of the trees. Should be positive. - split_prior : Optional[List[float]], default None. + split_prior : Optional[list[float]], default None. List of positive numbers, one per column in input data. Defaults to None, all covariates have the same prior probability to be selected. - split_rules : Optional[List[SplitRule]], default None + split_rules : Optional[list[SplitRule]], default None List of SplitRule objects, one per column in input data. Allows using different split rules for different columns. Default is ContinuousSplitRule. Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables. - shape: : Optional[Tuple], default None - Specify the output shape. If shape is different from (len(X)) (the default), train a - separate tree for each value in other dimensions. separate_trees : Optional[bool], default False When training multiple trees (by setting a shape parameter), the default behavior is to learn a joint tree structure and only have different leaf values for each. @@ -125,8 +130,8 @@ def __new__( alpha: float = 0.95, beta: float = 2.0, response: str = "constant", - split_prior: Optional[npt.NDArray[np.float64]] = None, - split_rules: Optional[List[SplitRule]] = None, + split_prior: Optional[npt.NDArray] = None, + split_rules: Optional[list[SplitRule]] = None, separate_trees: Optional[bool] = False, **kwargs, ): @@ -135,8 +140,9 @@ def __new__( "Options linear and mix are experimental and still not well tested\n" + "Use with caution." ) + # Create a unique manager list for each BART instance manager = Manager() - cls.all_trees = manager.list() + instance_all_trees = manager.list() X, Y = preprocess_xy(X, Y) @@ -147,7 +153,7 @@ def __new__( (BARTRV,), { "name": "BART", - "all_trees": cls.all_trees, + "all_trees": instance_all_trees, # Instance-specific tree storage "inplace": False, "initval": Y.mean(), "X": X, @@ -169,7 +175,7 @@ def get_moment(rv, size, *rv_inputs): return cls.get_moment(rv, size, *rv_inputs) cls.rv_op = bart_op - params = [X, Y, m, alpha, beta, split_prior] + params = [X, Y, m, alpha, beta] return super().__new__(cls, name, *params, **kwargs) @classmethod @@ -196,9 +202,7 @@ def get_moment(cls, rv, size, *rv_inputs): return mean -def preprocess_xy( - X: TensorLike, Y: TensorLike -) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: +def preprocess_xy(X: TensorLike, Y: TensorLike) -> tuple[npt.NDArray, npt.NDArray]: if isinstance(Y, (Series, DataFrame)): Y = Y.to_numpy() if isinstance(X, (Series, DataFrame)): diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 91a9beb..87bd36a 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import numpy as np import numpy.typing as npt +import pymc as pm +import pytensor.tensor as pt from numba import njit +from pymc.initial_point import PointType from pymc.model import Model, modelcontext from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements from pymc.step_methods.arraystep import ArrayStepShared @@ -34,6 +37,7 @@ get_idx_left_child, get_idx_right_child, ) +from pymc_bart.utils import _encode_vi class ParticleTree: @@ -43,7 +47,7 @@ class ParticleTree: def __init__(self, tree: Tree): self.tree: Tree = tree.copy() - self.expansion_nodes: List[int] = [0] + self.expansion_nodes: list[int] = [0] self.log_weight: float = 0 def copy(self) -> "ParticleTree": @@ -114,23 +118,51 @@ class PGBART(ArrayStepShared): name = "pgbart" default_blocked = False generates_stats = True - stats_dtypes = [{"variable_inclusion": object, "tune": bool}] + stats_dtypes_shapes: dict[str, tuple[type, list]] = { + "variable_inclusion": (int, []), + "tune": (bool, []), + } - def __init__( # noqa: PLR0915 + def __init__( # noqa: PLR0912, PLR0915 self, - vars=None, # pylint: disable=redefined-builtin + vars: list[pm.Distribution] | None = None, num_particles: int = 10, - batch: Tuple[float, float] = (0.1, 0.1), + batch: tuple[float, float] = (0.1, 0.1), model: Optional[Model] = None, - ): + initial_point: PointType | None = None, + compile_kwargs: dict | None = None, + **kwargs, # Accept additional kwargs for compound sampling + ) -> None: model = modelcontext(model) - initial_values = model.initial_point() + if initial_point is None: + initial_point = model.initial_point() if vars is None: vars = model.value_vars else: vars = [model.rvs_to_values.get(var, var) for var in vars] vars = inputvars(vars) - value_bart = vars[0] + + if vars is None: + raise ValueError("Unable to find variables to sample") + + # Filter to only BART variables + bart_vars = [] + for var in vars: + rv = model.values_to_rvs.get(var) + if rv is not None and isinstance(rv.owner.op, BARTRV): + bart_vars.append(var) + + if not bart_vars: + raise ValueError("No BART variables found in the provided variables") + + if len(bart_vars) > 1: + raise ValueError( + "PGBART can only handle one BART variable at a time. " + "For multiple BART variables, PyMC will automatically create " + "separate PGBART samplers for each variable." + ) + + value_bart = bart_vars[0] self.bart = model.values_to_rvs[value_bart].owner.op if isinstance(self.bart.X, Variable): @@ -147,7 +179,7 @@ def __init__( # noqa: PLR0915 self.m = self.bart.m self.response = self.bart.response - shape = initial_values[value_bart.name].shape + shape = initial_point[value_bart.name].shape self.shape = 1 if len(shape) == 1 else shape[0] @@ -214,20 +246,20 @@ def __init__( # noqa: PLR0915 self.num_particles = num_particles self.indices = list(range(1, num_particles)) - shared = make_shared_replacements(initial_values, vars, model) - self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared) + shared = make_shared_replacements(initial_point, [value_bart], model) + self.likelihood_logp = logp(initial_point, [model.datalogp], [value_bart], shared) self.all_particles = [ [ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape) ] self.all_trees = np.array([[p.tree for p in pl] for pl in self.all_particles]) self.lower = 0 self.iter = 0 - super().__init__(vars, shared) + super().__init__([value_bart], shared) def astep(self, _): variable_inclusion = np.zeros(self.num_variates, dtype="int") - upper = min(self.lower + self.batch[~self.tune], self.m) + upper = min(self.lower + self.batch[not self.tune], self.m) tree_ids = range(self.lower, upper) self.lower = upper if upper < self.m else 0 @@ -304,10 +336,12 @@ def astep(self, _): if not self.tune: self.bart.all_trees.append(self.all_trees) + variable_inclusion = _encode_vi(variable_inclusion) + stats = {"variable_inclusion": variable_inclusion, "tune": self.tune} return self.sum_trees, [stats] - def normalize(self, particles: List[ParticleTree]) -> float: + def normalize(self, particles: list[ParticleTree]) -> float: """ Use softmax to get normalized_weights. """ @@ -318,30 +352,30 @@ def normalize(self, particles: List[ParticleTree]) -> float: return wei / wei.sum() def resample( - self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64] - ) -> List[ParticleTree]: + self, particles: list[ParticleTree], normalized_weights: npt.NDArray + ) -> list[ParticleTree]: """ Use systematic resample for all but the first particle Ensure particles are copied only if needed. """ new_indices = self.systematic(normalized_weights) + 1 - seen: List[int] = [] - new_particles: List[ParticleTree] = [] + seen: list[int] = [] + new_particles: list[ParticleTree] = [] for idx in new_indices: if idx in seen: new_particles.append(particles[idx].copy()) else: new_particles.append(particles[idx]) - seen.append(idx) + seen.append(int(idx)) particles[1:] = new_particles return particles def get_particle_tree( - self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64] - ) -> Tuple[ParticleTree, Tree]: + self, particles: list[ParticleTree], normalized_weights: npt.NDArray + ) -> tuple[ParticleTree, Tree]: """ Sample a new particle and associated tree """ @@ -352,7 +386,7 @@ def get_particle_tree( return new_particle, new_particle.tree - def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]: + def systematic(self, normalized_weights: npt.NDArray) -> npt.NDArray[np.int_]: """ Systematic resampling. @@ -364,12 +398,12 @@ def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray single_uniform = (self.uniform.rvs() + np.arange(lnw)) / lnw return inverse_cdf(single_uniform, normalized_weights) - def init_particles(self, tree_id: int, odim: int) -> List[ParticleTree]: + def init_particles(self, tree_id: int, odim: int) -> list[ParticleTree]: """Initialize particles.""" p0: ParticleTree = self.all_particles[odim][tree_id] # The old tree does not grow so we update the weight only once self.update_weight(p0, odim) - particles: List[ParticleTree] = [p0] + particles: list[ParticleTree] = [p0] particles.extend(ParticleTree(self.a_tree) for _ in self.indices) return particles @@ -388,23 +422,30 @@ def update_weight(self, particle: ParticleTree, odim: int) -> None: particle.log_weight = new_likelihood @staticmethod - def competence(var, has_grad): + def competence(var: pm.Distribution, has_grad: bool) -> Competence: """PGBART is only suitable for BART distributions.""" dist = getattr(var.owner, "op", None) if isinstance(dist, BARTRV): return Competence.IDEAL return Competence.INCOMPATIBLE + @staticmethod + def _make_update_stats_functions(): + def update_stats(step_stats): + return {key: step_stats[key] for key in ("variable_inclusion", "tune")} + + return (update_stats,) + class RunningSd: """Welford's online algorithm for computing the variance/standard deviation""" - def __init__(self, shape: tuple) -> None: + def __init__(self, shape: tuple[int, ...]) -> None: self.count = 0 # number of data points self.mean = np.zeros(shape) # running mean self.m_2 = np.zeros(shape) # running second moment - def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]: + def update(self, new_value: npt.NDArray) -> Union[float, npt.NDArray]: self.count = self.count + 1 self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value) return fast_mean(std) @@ -413,10 +454,10 @@ def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray @njit def _update( count: int, - mean: npt.NDArray[np.float64], - m_2: npt.NDArray[np.float64], - new_value: npt.NDArray[np.float64], -) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]: + mean: npt.NDArray, + m_2: npt.NDArray, + new_value: npt.NDArray, +) -> tuple[npt.NDArray, npt.NDArray, Union[float, npt.NDArray]]: delta = new_value - mean mean += delta / count delta2 = new_value - mean @@ -427,7 +468,7 @@ def _update( class SampleSplittingVariable: - def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None: + def __init__(self, alpha_vec: npt.NDArray) -> None: """ Sample splitting variables proportional to `alpha_vec`. @@ -436,7 +477,7 @@ def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None: """ self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum()))) - def rvs(self) -> Union[int, Tuple[int, float]]: + def rvs(self) -> Union[int, tuple[int, float]]: rnd: float = np.random.random() for i, val in self.enu: if rnd <= val: @@ -444,7 +485,7 @@ def rvs(self) -> Union[int, Tuple[int, float]]: return self.enu[-1] -def compute_prior_probability(alpha: int, beta: int) -> List[float]: +def compute_prior_probability(alpha: int, beta: int) -> list[float]: """ Calculate the probability of the node being a leaf node (1 - p(being split node)). @@ -457,7 +498,7 @@ def compute_prior_probability(alpha: int, beta: int) -> List[float]: ------- list with probabilities for leaf nodes """ - prior_leaf_prob: List[float] = [0] + prior_leaf_prob: list[float] = [0] depth = 0 while prior_leaf_prob[-1] < 0.9999: prior_leaf_prob.append(1 - (alpha * ((1 + depth) ** (-beta)))) @@ -540,16 +581,16 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d def draw_leaf_value( - y_mu_pred: npt.NDArray[np.float64], - x_mu: npt.NDArray[np.float64], + y_mu_pred: npt.NDArray, + x_mu: npt.NDArray, m: int, - norm: npt.NDArray[np.float64], + norm: npt.NDArray, shape: int, response: str, -) -> Tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]: +) -> tuple[npt.NDArray, Optional[npt.NDArray]]: """Draw Gaussian distributed leaf values.""" linear_params = None - mu_mean = np.empty(shape) + mu_mean: npt.NDArray if y_mu_pred.size == 0: return np.zeros(shape), linear_params @@ -564,7 +605,7 @@ def draw_leaf_value( @njit -def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]: +def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]: """Use Numba to speed up the computation of the mean.""" if ari.ndim == 1: count = ari.shape[0] @@ -583,11 +624,11 @@ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float @njit def fast_linear_fit( - x: npt.NDArray[np.float64], - y: npt.NDArray[np.float64], + x: npt.NDArray, + y: npt.NDArray, m: int, - norm: npt.NDArray[np.float64], -) -> Tuple[npt.NDArray[np.float64], List[npt.NDArray[np.float64]]]: + norm: npt.NDArray, +) -> tuple[npt.NDArray, list[npt.NDArray]]: n = len(x) y = y / m + np.expand_dims(norm, axis=1) @@ -671,17 +712,17 @@ def update(self): @njit def inverse_cdf( - single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64] + single_uniform: npt.NDArray, normalized_weights: npt.NDArray ) -> npt.NDArray[np.int_]: """ Inverse CDF algorithm for a finite distribution. Parameters ---------- - single_uniform: npt.NDArray[np.float64] + single_uniform: npt.NDArray Ordered points in [0,1] - normalized_weights: npt.NDArray[np.float64]) + normalized_weights: npt.NDArray) Normalized weights Returns @@ -704,7 +745,7 @@ def inverse_cdf( @njit -def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]: +def jitter_duplicated(array: npt.NDArray, std: float) -> npt.NDArray: """ Jitter duplicated values. """ @@ -720,12 +761,17 @@ def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray @njit -def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_: +def are_whole_number(array: npt.NDArray) -> np.bool_: """Check if all values in array are whole numbers""" return np.all(np.mod(array[~np.isnan(array)], 1) == 0) -def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin +def logp( + point, + out_vars: list[pm.Distribution], + vars: list[pm.Distribution], + shared: list[pt.TensorVariable], +): """Compile PyTensor function of the model and the input and output variables. Parameters diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 0e0a35c..61e5050 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Generator from functools import lru_cache -from typing import Dict, Generator, List, Optional, Tuple, Union +from typing import Optional, Union import numpy as np import numpy.typing as npt @@ -27,21 +28,21 @@ class Node: Attributes ---------- - value : npt.NDArray[np.float64] + value : npt.NDArray idx_data_points : Optional[npt.NDArray[np.int_]] idx_split_variable : int - linear_params: Optional[List[float]] = None + linear_params: Optional[list[float]] = None """ __slots__ = "value", "nvalue", "idx_split_variable", "idx_data_points", "linear_params" def __init__( self, - value: npt.NDArray[np.float64] = np.array([-1.0]), + value: npt.NDArray = np.array([-1.0]), nvalue: int = 0, idx_data_points: Optional[npt.NDArray[np.int_]] = None, idx_split_variable: int = -1, - linear_params: Optional[List[npt.NDArray[np.float64]]] = None, + linear_params: Optional[list[npt.NDArray]] = None, ) -> None: self.value = value self.nvalue = nvalue @@ -52,11 +53,11 @@ def __init__( @classmethod def new_leaf_node( cls, - value: npt.NDArray[np.float64], + value: npt.NDArray, nvalue: int = 0, idx_data_points: Optional[npt.NDArray[np.int_]] = None, idx_split_variable: int = -1, - linear_params: Optional[List[npt.NDArray[np.float64]]] = None, + linear_params: Optional[list[npt.NDArray]] = None, ) -> "Node": return cls( value=value, @@ -94,19 +95,19 @@ class Tree: Attributes ---------- - tree_structure : Dict[int, Node] + tree_structure : dict[int, Node] A dictionary that represents the nodes stored in breadth-first order, based in the array method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays). The dictionary's keys are integers that represent the nodes position. The dictionary's values are objects of type Node that represent the split and leaf nodes of the tree itself. - output: Optional[npt.NDArray[np.float64]] + output: Optional[npt.NDArray] Array of shape number of observations, shape - split_rules : List[SplitRule] + split_rules : list[SplitRule] List of SplitRule objects, one per column in input data. Allows using different split rules for different columns. Default is ContinuousSplitRule. Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables. - idx_leaf_nodes : Optional[List[int]], by default None. + idx_leaf_nodes : Optional[list[int]], by default None. Array with the index of the leaf nodes of the tree. Parameters @@ -120,10 +121,10 @@ class Tree: def __init__( self, - tree_structure: Dict[int, Node], - output: npt.NDArray[np.float64], - split_rules: List[SplitRule], - idx_leaf_nodes: Optional[List[int]] = None, + tree_structure: dict[int, Node], + output: npt.NDArray, + split_rules: list[SplitRule], + idx_leaf_nodes: Optional[list[int]] = None, ) -> None: self.tree_structure = tree_structure self.idx_leaf_nodes = idx_leaf_nodes @@ -133,11 +134,11 @@ def __init__( @classmethod def new_tree( cls, - leaf_node_value: npt.NDArray[np.float64], + leaf_node_value: npt.NDArray, idx_data_points: Optional[npt.NDArray[np.int_]], num_observations: int, shape: int, - split_rules: List[SplitRule], + split_rules: list[SplitRule], ) -> "Tree": return cls( tree_structure={ @@ -159,7 +160,7 @@ def __setitem__(self, index, node) -> None: self.set_node(index, node) def copy(self) -> "Tree": - tree: Dict[int, Node] = { + tree: dict[int, Node] = { k: Node( value=v.value, nvalue=v.nvalue, @@ -189,7 +190,7 @@ def grow_leaf_node( self, current_node: Node, selected_predictor: int, - split_value: npt.NDArray[np.float64], + split_value: npt.NDArray, index_leaf_node: int, ) -> None: current_node.value = split_value @@ -199,7 +200,7 @@ def grow_leaf_node( self.idx_leaf_nodes.remove(index_leaf_node) def trim(self) -> "Tree": - tree: Dict[int, Node] = { + tree: dict[int, Node] = { k: Node( value=v.value, nvalue=v.nvalue, @@ -221,7 +222,7 @@ def get_split_variables(self) -> Generator[int, None, None]: if node.is_split_node(): yield node.idx_split_variable - def _predict(self) -> npt.NDArray[np.float64]: + def _predict(self) -> npt.NDArray: output = self.output if self.idx_leaf_nodes is not None: @@ -232,23 +233,23 @@ def _predict(self) -> npt.NDArray[np.float64]: def predict( self, - x: npt.NDArray[np.float64], - excluded: Optional[List[int]] = None, + x: npt.NDArray, + excluded: Optional[list[int]] = None, shape: int = 1, - ) -> npt.NDArray[np.float64]: + ) -> npt.NDArray: """ Predict output of tree for an (un)observed point x. Parameters ---------- - x : npt.NDArray[np.float64] + x : npt.NDArray Unobserved point - excluded: Optional[List[int]] + excluded: Optional[list[int]] Indexes of the variables to exclude when computing predictions Returns ------- - npt.NDArray[np.float64] + npt.NDArray Value of the leaf value where the unobserved point lies. """ if excluded is None: @@ -258,34 +259,36 @@ def predict( def _traverse_tree( self, - X: npt.NDArray[np.float64], - excluded: Optional[List[int]] = None, - shape: Union[int, Tuple[int, ...]] = 1, - ) -> npt.NDArray[np.float64]: + X: npt.NDArray, + excluded: Optional[list[int]] = None, + shape: Union[int, tuple[int, ...]] = 1, + ) -> npt.NDArray: """ Traverse the tree starting from the root node given an (un)observed point. Parameters ---------- - X : npt.NDArray[np.float64] + X : npt.NDArray (Un)observed point(s) node_index : int Index of the node to start the traversal from split_variable : int Index of the variable used to split the node - excluded: Optional[List[int]] + excluded: Optional[list[int]] Indexes of the variables to exclude when computing predictions Returns ------- - npt.NDArray[np.float64] + npt.NDArray Leaf node value or mean of leaf node values """ x_shape = (1,) if len(X.shape) == 1 else X.shape[:-1] nd_dims = (...,) + (None,) * len(x_shape) - stack = [(0, np.ones(x_shape), 0)] # (node_index, weight, idx_split_variable) initial state + stack: list[tuple[int, npt.NDArray, int]] = [ + (0, np.ones(x_shape), 0) + ] # (node_index, weight, idx_split_variable) initial state p_d = ( np.zeros(shape + x_shape) if isinstance(shape, tuple) else np.zeros((shape,) + x_shape) ) @@ -308,9 +311,19 @@ def _traverse_tree( ) if excluded is not None and idx_split_variable in excluded: prop_nvalue_left = self.get_node(left_node_index).nvalue / node.nvalue - stack.append((left_node_index, weights * prop_nvalue_left, idx_split_variable)) stack.append( - (right_node_index, weights * (1 - prop_nvalue_left), idx_split_variable) + ( + left_node_index, + weights * prop_nvalue_left, + idx_split_variable, + ) + ) + stack.append( + ( + right_node_index, + weights * (1 - prop_nvalue_left), + idx_split_variable, + ) ) else: to_left = ( @@ -327,14 +340,14 @@ def _traverse_tree( return p_d def _traverse_leaf_values( - self, leaf_values: List[npt.NDArray[np.float64]], leaf_n_values: List[int], node_index: int + self, leaf_values: list[npt.NDArray], leaf_n_values: list[int], node_index: int ) -> None: """ Traverse the tree appending leaf values starting from a particular node. Parameters ---------- - leaf_values : List[npt.NDArray[np.float64]] + leaf_values : list[npt.NDArray] node_index : int """ node = self.get_node(node_index) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index a50f2d9..dfb5eac 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -1,32 +1,35 @@ +# pylint: disable=too-many-branches """Utility function for variable selection and bart interpretability.""" import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from collections.abc import Callable +from typing import Any, TypeVar -import arviz as az import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt +import pymc as pm import pytensor.tensor as pt +from arviz_base import rcParams +from arviz_stats.base import array_stats from numba import jit from pytensor.tensor.variable import Variable from scipy.interpolate import griddata from scipy.signal import savgol_filter -from scipy.stats import norm from .tree import Tree -TensorLike = Union[npt.NDArray[np.float64], pt.TensorVariable] +TensorLike = TypeVar("TensorLike", npt.NDArray, pt.TensorVariable) def _sample_posterior( - all_trees: List[List[Tree]], + all_trees: list[list[Tree]], X: TensorLike, rng: np.random.Generator, - size: Optional[Union[int, Tuple[int, ...]]] = None, - excluded: Optional[List[int]] = None, + size: int | tuple[int, ...] | None = None, + excluded: list[int] | None = None, shape: int = 1, -) -> npt.NDArray[np.float64]: +) -> npt.NDArray: """ Generate samples from the BART-posterior. @@ -49,7 +52,7 @@ def _sample_posterior( X = X.eval() if size is None: - size_iter: Union[List, Tuple] = (1,) + size_iter: list | tuple = (1,) elif isinstance(size, int): size_iter = [size] else: @@ -75,12 +78,12 @@ def _sample_posterior( def plot_convergence( - idata: az.InferenceData, - var_name: Optional[str] = None, + idata: Any, + var_name: str | None = None, kind: str = "ecdf", - figsize: Optional[Tuple[float, float]] = None, + figsize: tuple[float, float] | None = None, ax=None, -) -> List[plt.Axes]: +) -> None: """ Plot convergence diagnostics. @@ -92,87 +95,44 @@ def plot_convergence( Name of the BART variable to plot. Defaults to None. kind : str Type of plot to display. Options are "ecdf" (default) and "kde". - figsize : Optional[Tuple[float, float]], by default None. + figsize : Optional[tuple[float, float]], by default None. Figure size. Defaults to None. ax : matplotlib axes Axes on which to plot. Defaults to None. Returns ------- - List[ax] : matplotlib axes + list[ax] : matplotlib axes """ - ess_threshold = idata["posterior"]["chain"].size * 100 - ess = np.atleast_2d(az.ess(idata, method="bulk", var_names=var_name)[var_name].values) - rhat = np.atleast_2d(az.rhat(idata, var_names=var_name)[var_name].values) - - if figsize is None: - figsize = (10, 3) - - if kind == "ecdf": - kind_func: Callable[..., Any] = az.plot_ecdf - sharey = True - elif kind == "kde": - kind_func = az.plot_kde - sharey = False - - if ax is None: - _, ax = plt.subplots(1, 2, figsize=figsize, sharex="col", sharey=sharey) - - for idx, (essi, rhati) in enumerate(zip(ess, rhat)): - kind_func(essi, ax=ax[0], plot_kwargs={"color": f"C{idx}"}) - kind_func(rhati, ax=ax[1], plot_kwargs={"color": f"C{idx}"}) - - ax[0].axvline(ess_threshold, color="0.7", ls="--") - # Assume Rhats are N(1, 0.005) iid. Then compute the 0.99 quantile - # scaled by the sample size and use it as a threshold. - ax[1].axvline(norm(1, 0.005).ppf(0.99 ** (1 / ess.size)), color="0.7", ls="--") - - ax[0].set_xlabel("ESS") - ax[1].set_xlabel("R-hat") - if kind == "kde": - ax[0].set_yticks([]) - ax[1].set_yticks([]) - - return ax - - -def plot_dependence(*args, kind="pdp", **kwargs): # pylint: disable=unused-argument - """ - Partial dependence or individual conditional expectation plot. - """ - if kind == "pdp": - warnings.warn( - "This function has been deprecated. Use plot_pdp instead.", - FutureWarning, - ) - elif kind == "ice": - warnings.warn( - "This function has been deprecated. Use plot_ice instead.", - FutureWarning, - ) + warnings.warn( + "This function has been deprecated" + "Use az.plot_convergence_dist() instead." + "https://arviz-plots.readthedocs.io/en/latest/api/generated/arviz_plots.plot_convergence_dist.html", + FutureWarning, + ) def plot_ice( bartrv: Variable, - X: npt.NDArray[np.float64], - Y: Optional[npt.NDArray[np.float64]] = None, - var_idx: Optional[List[int]] = None, - var_discrete: Optional[List[int]] = None, - func: Optional[Callable] = None, - centered: Optional[bool] = True, + X: npt.NDArray, + Y: npt.NDArray | None = None, + var_idx: list[int] | None = None, + var_discrete: list[int] | None = None, + func: Callable | None = None, + centered: bool | None = True, samples: int = 100, instances: int = 30, - random_seed: Optional[int] = None, + random_seed: int | None = None, sharey: bool = True, smooth: bool = True, grid: str = "long", color="C0", color_mean: str = "C0", alpha: float = 0.1, - figsize: Optional[Tuple[float, float]] = None, - smooth_kwargs: Optional[Dict[str, Any]] = None, - ax: Optional[plt.Axes] = None, -) -> List[plt.Axes]: + figsize: tuple[float, float] | None = None, + smooth_kwargs: dict[str, Any] | None = None, + ax: plt.Axes | None = None, +) -> list[plt.Axes]: """ Individual conditional expectation plot. @@ -180,13 +140,13 @@ def plot_ice( ---------- bartrv : BART Random Variable BART variable once the model that include it has been fitted. - X : npt.NDArray[np.float64] + X : npt.NDArray The covariate matrix. - Y : Optional[npt.NDArray[np.float64]], by default None. + Y : Optional[npt.NDArray], by default None. The response vector. - var_idx : Optional[List[int]], by default None. + var_idx : Optional[list[int]], by default None. List of the indices of the covariate for which to compute the pdp or ice. - var_discrete : Optional[List[int]], by default None. + var_discrete : Optional[list[int]], by default None. List of the indices of the covariate treated as discrete. func : Optional[Callable], by default None. Arbitrary function to apply to the predictions. Defaults to the identity function. @@ -248,7 +208,7 @@ def identity(x): _, ) = _prepare_plot_data(X, Y, "linear", None, var_idx, var_discrete) - fig, axes, shape = _get_axes(bartrv, var_idx, grid, sharey, figsize, ax) + fig, axes, shape = _create_figure_axes(bartrv, var_idx, grid, sharey, figsize, ax) instances_ary = rng.choice(range(X.shape[0]), replace=False, size=instances) idx_s = list(range(X.shape[0])) @@ -269,14 +229,13 @@ def identity(x): ) new_x = fake_X[:, var] - p_d = np.array(y_pred) - print(p_d.shape) + p_d = func(np.array(y_pred)) for s_i in range(shape): if centered: - p_di = func(p_d[:, :, s_i]) - func(p_d[:, :, s_i][:, 0][:, None]) + p_di = p_d[:, :, s_i] - p_d[:, :, s_i][:, 0][:, None] else: - p_di = func(p_d[:, :, s_i]) + p_di = p_d[:, :, s_i] if var in var_discrete: axes[count].plot(new_x, p_di.mean(0), "o", color=color_mean) axes[count].plot(new_x, p_di.T, ".", color=color, alpha=alpha) @@ -299,25 +258,26 @@ def identity(x): def plot_pdp( bartrv: Variable, - X: npt.NDArray[np.float64], - Y: Optional[npt.NDArray[np.float64]] = None, + X: npt.NDArray, + Y: npt.NDArray | None = None, xs_interval: str = "quantiles", - xs_values: Optional[Union[int, List[float]]] = None, - var_idx: Optional[List[int]] = None, - var_discrete: Optional[List[int]] = None, - func: Optional[Callable] = None, + xs_values: int | list[float] | None = None, + var_idx: list[int] | None = None, + var_discrete: list[int] | None = None, + func: Callable | None = None, samples: int = 200, - random_seed: Optional[int] = None, + ref_line: bool = True, + random_seed: int | None = None, sharey: bool = True, smooth: bool = True, grid: str = "long", color="C0", color_mean: str = "C0", alpha: float = 0.1, - figsize: Optional[Tuple[float, float]] = None, - smooth_kwargs: Optional[Dict[str, Any]] = None, - ax: Optional[plt.Axes] = None, -) -> List[plt.Axes]: + figsize: tuple[float, float] | None = None, + smooth_kwargs: dict[str, Any] | None = None, + ax: plt.Axes = None, +) -> list[plt.Axes]: """ Partial dependence plot. @@ -325,28 +285,30 @@ def plot_pdp( ---------- bartrv : BART Random Variable BART variable once the model that include it has been fitted. - X : npt.NDArray[np.float64] + X : npt.NDArray The covariate matrix. - Y : Optional[npt.NDArray[np.float64]], by default None. + Y : Optional[npt.NDArray], by default None. The response vector. xs_interval : str Method used to compute the values X used to evaluate the predicted function. "linear", evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified quantiles of X. "insample", the evaluation is done at the values of X. For discrete variables these options are ommited. - xs_values : Optional[Union[int, List[float]]], by default None. + xs_values : Optional[Union[int, list[float]]], by default None. Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of points in the evenly spaced grid. If ``xs_interval="quantiles"`` quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive. Ignored when ``xs_interval="insample"``. - var_idx : Optional[List[int]], by default None. + var_idx : Optional[list[int]], by default None. List of the indices of the covariate for which to compute the pdp or ice. - var_discrete : Optional[List[int]], by default None. + var_discrete : Optional[list[int]], by default None. List of the indices of the covariate treated as discrete. func : Optional[Callable], by default None. Arbitrary function to apply to the predictions. Defaults to the identity function. samples : int Number of posterior samples used in the predictions. Defaults to 200 + ref_line : bool + If True a reference line is plotted at the mean of the partial dependence. Defaults to True. random_seed : Optional[int], by default None. Seed used to sample from the posterior. Defaults to None. sharey : bool @@ -398,25 +360,30 @@ def identity(x): xs_values, ) = _prepare_plot_data(X, Y, xs_interval, xs_values, var_idx, var_discrete) - fig, axes, shape = _get_axes(bartrv, var_idx, grid, sharey, figsize, ax) + fig, axes, shape = _create_figure_axes(bartrv, var_idx, grid, sharey, figsize, ax) count = 0 fake_X = _create_pdp_data(X, xs_interval, xs_values) + null_pd = [] for var in range(len(var_idx)): excluded = indices[:] excluded.remove(var) - p_d = _sample_posterior( - all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape + p_d = func( + _sample_posterior( + all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape + ) ) + with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="hdi currently interprets 2d data") new_x = fake_X[:, var] for s_i in range(shape): - p_di = func(p_d[:, :, s_i]) + p_di = p_d[:, :, s_i] + null_pd.append(p_di.mean()) if var in var_discrete: _, idx_uni = np.unique(new_x, return_index=True) y_means = p_di.mean(0)[idx_uni] - hdi = az.hdi(p_di)[idx_uni] + hdi = array_stats.hdi(p_di, prob=rcParams["stats.ci_prob"], axis=0)[idx_uni] axes[count].errorbar( new_x[idx_uni], y_means, @@ -426,11 +393,13 @@ def identity(x): ) axes[count].set_xticks(new_x[idx_uni]) else: - az.plot_hdi( + _plot_hdi( new_x, p_di, smooth=smooth, - fill_kwargs={"alpha": alpha, "color": color}, + alpha=alpha, + color=color, + smooth_kwargs=smooth_kwargs, ax=axes[count], ) if smooth: @@ -442,19 +411,24 @@ def identity(x): count += 1 + if ref_line: + ref_val = sum(null_pd) / len(null_pd) + for ax_ in np.ravel(axes): + ax_.axhline(ref_val, color="0.7", linestyle="--") + fig.text(-0.05, 0.5, y_label, va="center", rotation="vertical", fontsize=15) return axes -def _get_axes( +def _create_figure_axes( bartrv: Variable, - var_idx: List[int], + var_idx: list[int], grid: str = "long", sharey: bool = True, - figsize: Optional[Tuple[float, float]] = None, - ax: Optional[plt.Axes] = None, -) -> Tuple[plt.Figure, List[plt.Axes], int]: + figsize: tuple[float, float] | None = None, + ax: plt.Axes | None = None, +) -> tuple[plt.Figure, list[plt.Axes], int]: """ Create and return the figure and axes objects for plotting the variables. @@ -464,9 +438,9 @@ def _get_axes( ---------- bartrv : BART Random Variable BART variable once the model that include it has been fitted. - var_idx : Optional[List[int]], by default None. + var_idx : Optional[list[int]], by default None. List of the indices of the covariate for which to compute the pdp or ice. - var_discrete : Optional[List[int]], by default None. + var_discrete : Optional[list[int]], by default None. grid : str or tuple How to arrange the subplots. Defaults to "long", one subplot below the other. Other options are "wide", one subplot next to each other or a tuple indicating the number of @@ -481,7 +455,7 @@ def _get_axes( Returns ------- - Tuple[plt.Figure, List[plt.Axes], int] + tuple[plt.Figure, list[plt.Axes], int] A tuple containing the figure object, list of axes objects, and the shape value. """ if bartrv.ndim == 1: # type: ignore @@ -492,29 +466,8 @@ def _get_axes( n_plots = len(var_idx) * shape if ax is None: - if grid == "long": - fig, axes = plt.subplots(n_plots, sharey=sharey, figsize=figsize) - if n_plots == 1: - axes = [axes] - elif grid == "wide": - fig, axes = plt.subplots(1, n_plots, sharey=sharey, figsize=figsize) - if n_plots == 1: - axes = [axes] - elif isinstance(grid, tuple): - grid_size = grid[0] * grid[1] - if n_plots > grid_size: - warnings.warn( - """The grid is smaller than the number of available variables to plot. - Automatically adjusting the grid size.""" - ) - grid = (n_plots // grid[1] + (n_plots % grid[1] > 0), grid[1]) - - fig, axes = plt.subplots(*grid, sharey=sharey, figsize=figsize) - axes = np.ravel(axes) + fig, axes = _get_axes(grid, n_plots, False, sharey, figsize) - for i in range(n_plots, len(axes)): - fig.delaxes(axes[i]) - axes = axes[:n_plots] elif isinstance(ax, np.ndarray): axes = ax fig = ax[0].get_figure() @@ -525,22 +478,49 @@ def _get_axes( return fig, axes, shape +def _get_axes(grid, n_plots, sharex, sharey, figsize): + if grid == "long": + fig, axes = plt.subplots(n_plots, sharex=sharex, sharey=sharey, figsize=figsize) + if n_plots == 1: + axes = [axes] + elif grid == "wide": + fig, axes = plt.subplots(1, n_plots, sharex=sharex, sharey=sharey, figsize=figsize) + if n_plots == 1: + axes = [axes] + elif isinstance(grid, tuple): + grid_size = grid[0] * grid[1] + if n_plots > grid_size: + warnings.warn( + """The grid is smaller than the number of available variables to plot. + Automatically adjusting the grid size.""" + ) + grid = (n_plots // grid[1] + (n_plots % grid[1] > 0), grid[1]) + + fig, axes = plt.subplots(*grid, sharey=sharey, figsize=figsize) + axes = np.ravel(axes) + + for i in range(n_plots, len(axes)): + fig.delaxes(axes[i]) + axes = axes[:n_plots] + return fig, axes + + def _prepare_plot_data( - X: npt.NDArray[np.float64], - Y: Optional[npt.NDArray[np.float64]] = None, + X: npt.NDArray, + Y: npt.NDArray | None = None, xs_interval: str = "quantiles", - xs_values: Optional[Union[int, List[float]]] = None, - var_idx: Optional[List[int]] = None, - var_discrete: Optional[List[int]] = None, -) -> Tuple[ - npt.NDArray[np.float64], - List[str], + xs_values: int | list[float] | None = None, + var_idx: list[int] | None = None, + var_discrete: list[int] | None = None, +) -> tuple[ + npt.NDArray, + list[str], str, - List[int], - List[int], - List[int], + list[int], + list[int], + list[int], str, - Union[int, None, List[float]], + int | None | list[float], ]: """ Prepare data for plotting. @@ -619,10 +599,10 @@ def _prepare_plot_data( def _create_pdp_data( - X: npt.NDArray[np.float64], + X: npt.NDArray, xs_interval: str, - xs_values: Optional[Union[int, List[float]]] = None, -) -> npt.NDArray[np.float64]: + xs_values: int | list[float] | None = None, +) -> npt.NDArray: """ Create data for partial dependence plot. @@ -637,7 +617,7 @@ def _create_pdp_data( Returns ------- - npt.NDArray[np.float64] + npt.NDArray A 2D array for the fake_X data. """ if xs_interval == "insample": @@ -654,11 +634,11 @@ def _create_pdp_data( def _smooth_mean( - new_x: npt.NDArray[np.float64], - p_di: npt.NDArray[np.float64], - kind: str = "pdp", - smooth_kwargs: Optional[Dict[str, Any]] = None, -) -> Tuple[np.ndarray, np.ndarray]: + new_x: npt.NDArray, + p_di: npt.NDArray, + kind: str = "neutral", + smooth_kwargs: dict[str, Any] | None = None, +) -> tuple[np.ndarray, np.ndarray]: """ Smooth the mean data for plotting. @@ -670,12 +650,12 @@ def _smooth_mean( The distribution of partial dependence from which to comptue the smoothed mean. kind : str, optional The type of plot. Possible values are "pdp" or "ice". - smooth_kwargs : Optional[Dict[str, Any]], optional + smooth_kwargs : Optional[dict[str, Any]], optional Additional keyword arguments for the smoothing function. Defaults to None. Returns ------- - Tuple[np.ndarray, np.ndarray] + tuple[np.ndarray, np.ndarray] A tuple containing a grid for the x-axis data and the corresponding smoothed y-axis data. """ @@ -685,7 +665,10 @@ def _smooth_mean( smooth_kwargs.setdefault("polyorder", 2) x_data = np.linspace(np.nanmin(new_x), np.nanmax(new_x), 200) x_data[0] = (x_data[0] + x_data[1]) / 2 - if kind == "pdp": + + if kind == "neutral": + interp = griddata(new_x, p_di, x_data) + elif kind == "pdp": interp = griddata(new_x, p_di.mean(0), x_data) else: interp = griddata(new_x, p_di.T, x_data) @@ -693,110 +676,235 @@ def _smooth_mean( return x_data, y_data -def plot_variable_importance( # noqa: PLR0915 - idata: az.InferenceData, +def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None, to_kulprit=False): + """ + Get the normalized variable inclusion from BART model. + + Parameters + ---------- + idata : InferenceData + InferenceData with a variable "variable_inclusion" in ``sample_stats`` group + X : npt.NDArray + The covariate matrix. + model : Optional[pm.Model] + The PyMC model that contains the BART variable. Only needed if the model contains multiple + BART variables. + bart_var_name : Optional[str] + The name of the BART variable in the model. Only needed if the model contains multiple + BART variables. + labels : Optional[list[str]] + List of the names of the covariates. If X is a DataFrame the names of the covariables will + be taken from it and this argument will be ignored. + to_kulprit : bool + If True, the function will return a list of list with the variables names. + This list can be passed as a path to Kulprit's project method. Defaults to False. + + Returns + ------- + VI_norm : npt.NDArray + Normalized variable inclusion. + labels : list[str] + List of the names of the covariates. + """ + n_vars = X.shape[1] + vi_xarray = idata["sample_stats"]["variable_inclusion"] + if "variable_inclusion_dim_0" in vi_xarray.coords: + if model is None or bart_var_name is None: + raise ValueError( + "The InfereceData was generated from a model with multiple BART variables, \n" + "please provide the model and also the name of the BART variable \n" + "for which you want to compute the variable inclusion." + ) + index = [var.name for var in model.free_RVs].index(bart_var_name) + vi_vals = vi_xarray.sel({"variable_inclusion_dim_0": index}).values.ravel() + else: + vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel() + VIs = np.array([_decode_vi(val, n_vars) for val in vi_vals]).sum(axis=0) + VI_norm = VIs / VIs.sum() + idxs = np.argsort(VI_norm) + + indices = idxs[::-1] + n_vars = len(indices) + + if hasattr(X, "columns") and hasattr(X, "to_numpy"): + labels = list(X.columns) + + if labels is None: + labels = [str(i) for i in range(n_vars)] + + if to_kulprit: + return [labels[:idx] for idx in range(n_vars)] + else: + return VI_norm[indices], labels + + +def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=None, ax=None): + """ + Plot normalized variable inclusion from BART model. + + Parameters + ---------- + idata : InferenceData + InferenceData containing a collection of BART_trees in sample_stats group + X : npt.NDArray + The covariate matrix. + labels : Optional[list[str]] + List of the names of the covariates. If X is a DataFrame the names of the covariables will + be taken from it and this argument will be ignored. + figsize : tuple + Figure size. If None it will be defined automatically. + plot_kwargs : dict + Additional keyword arguments for the plot. Defaults to None. + Valid keys are: + - color: matplotlib valid color for VI + - marker: matplotlib valid marker for VI + - ls: matplotlib valid linestyle for the VI line + - rotation: float, rotation of the x-axis labels + ax : axes + Matplotlib axes. + + Returns + ------- + axes: matplotlib axes + """ + if plot_kwargs is None: + plot_kwargs = {} + + VI_norm, labels = get_variable_inclusion(idata, X, labels) + n_vars = len(labels) + + new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)] + + ticks = np.arange(n_vars, dtype=int) + + if figsize is None: + figsize = (8, 3) + + if ax is None: + _, ax = plt.subplots(1, 1, figsize=figsize) + + ax.axhline(1 / n_vars, color="0.5", linestyle="--") + ax.plot( + VI_norm, + color=plot_kwargs.get("color", "k"), + marker=plot_kwargs.get("marker", "o"), + ls=plot_kwargs.get("ls", "-"), + ) + + ax.set_xticks(ticks, new_labels, rotation=plot_kwargs.get("rotation", 0)) + ax.set_ylim(0, 1) + + return ax + + +def compute_variable_importance( # noqa: PLR0915 PLR0912 + idata: Any, bartrv: Variable, - X: npt.NDArray[np.float64], - labels: Optional[List[str]] = None, + X: npt.NDArray, + model: "pm.Model | None" = None, method: str = "VI", - figsize: Optional[Tuple[float, float]] = None, + fixed: int = 0, samples: int = 50, - random_seed: Optional[int] = None, - plot_kwargs: Optional[Dict[str, Any]] = None, - ax: Optional[plt.Axes] = None, -) -> Tuple[List[int], Union[List[plt.Axes], Any]]: + random_seed: int | None = None, +) -> dict[str, npt.NDArray]: """ Estimates variable importance from the BART-posterior. Parameters ---------- - idata: InferenceData - InferenceData containing a collection of BART_trees in sample_stats group + idata : InferenceData + InferenceData containing a "variable_inclusion" variable in the sample_stats group. bartrv : BART Random Variable BART variable once the model that include it has been fitted. - X : npt.NDArray[np.float64] + X : npt.NDArray The covariate matrix. - labels : Optional[List[str]] - List of the names of the covariates. If X is a DataFrame the names of the covariables will - be taken from it and this argument will be ignored. + model : Optional[pm.Model] + The PyMC model that contains the BART variable. Only needed if the model contains multiple + BART variables. method : str - Method used to rank variables. Available options are "VI" (default) and "backward". + Method used to rank variables. Available options are "VI" (default), "backward" + and "backward_VI". The R squared will be computed following this ranking. "VI" counts how many times each variable is included in the posterior distribution of trees. "backward" uses a backward search based on the R squared. - VI requieres less computation time. - figsize : tuple - Figure size. If None it will be defined automatically. + "backward_VI" combines both methods with the backward search excluding + the ``fixed`` number of variables with the lowest variable inclusion. + "VI" is the fastest method, while "backward" is the slowest. + fixed : Optional[int] + Number of variables to fix in the backward search. Defaults to None. + Must be greater than 0 and less than the number of variables. + Ignored if method is "VI" or "backward". samples : int - Number of predictions used to compute correlation for subsets of variables. Defaults to 100 + Number of predictions used to compute correlation for subsets of variables. Defaults to 50 random_seed : Optional[int] random_seed used to sample from the posterior. Defaults to None. - plot_kwargs : dict - Additional keyword arguments for the plot. Defaults to None. - Valid keys are: - - color_r2: matplotlib valid color for error bars - - marker_r2: matplotlib valid marker for the mean R squared - - marker_fc_r2: matplotlib valid marker face color for the mean R squared - - ls_ref: matplotlib valid linestyle for the reference line - - color_ref: matplotlib valid color for the reference line - ax : axes - Matplotlib axes. Returns ------- - idxs: indexes of the covariates from higher to lower relative importance - axes: matplotlib axes + vi_results: dictionary """ + if method not in ["VI", "backward", "backward_VI"]: + raise ValueError("method must be 'VI', 'backward' or 'backward_VI'") + rng = np.random.default_rng(random_seed) all_trees = bartrv.owner.op.all_trees - - if plot_kwargs is None: - plot_kwargs = {} + bart_var_name = bartrv.name if bartrv.ndim == 1: # type: ignore shape = 1 else: shape = bartrv.eval().shape[0] + n_vars = X.shape[1] + if hasattr(X, "columns") and hasattr(X, "to_numpy"): labels = X.columns X = X.to_numpy() - - n_vars = X.shape[1] - - if figsize is None: - figsize = (8, 3) - - if ax is None: - _, ax = plt.subplots(1, 1, figsize=figsize) - - if labels is None: - labels_ary = np.arange(n_vars).astype(str) else: - labels_ary = np.array(labels) - - ticks = np.arange(n_vars, dtype=int) + labels = np.arange(n_vars).astype(str) + + r2_mean: npt.NDArray = np.zeros(n_vars) + r2_hdi: npt.NDArray = np.zeros((n_vars, 2)) + preds: npt.NDArray = np.zeros((n_vars, samples, *bartrv.eval().T.shape)) + + if method == "backward_VI": + if fixed >= n_vars: + raise ValueError("fixed must be less than the number of variables") + elif fixed < 1: + raise ValueError("fixed must be greater than 0") + init = fixed + 1 + else: + fixed = 0 + init = 0 predicted_all = _sample_posterior( all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape ) - r_2_ref = np.array( - [pearsonr2(predicted_all[j], predicted_all[j + 1]) for j in range(samples - 1)] - ) + if method in ["VI", "backward_VI"]: + vi_xarray = idata["sample_stats"]["variable_inclusion"] + if "variable_inclusion_dim_0" in vi_xarray.coords: + if model is None: + raise ValueError( + "The InfereceData was generated from a model with multiple BART variables, \n" + "please provide the model and also the name of the BART variable \n" + "for which you want to compute the variable inclusion." + ) - if method == "VI": - idxs = np.argsort( - idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values - ) - subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))] + index = [var.name for var in model.free_RVs].index(bart_var_name) + vi_vals = vi_xarray.sel({"variable_inclusion_dim_0": index}).values.ravel() + else: + vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel() + idxs = np.argsort(np.array([_decode_vi(val, n_vars) for val in vi_vals]).sum(axis=0)) + subsets: list[list[int]] = [list(idxs[:-i]) for i in range(1, len(idxs))] subsets.append(None) # type: ignore - indices: List[int] = list(idxs[::-1]) + if method == "backward_VI": + subsets = subsets[-init:] + + indices: list[int] = list(idxs[::-1]) - r2_mean = np.zeros(n_vars) - r2_hdi = np.zeros((n_vars, 2)) for idx, subset in enumerate(subsets): predicted_subset = _sample_posterior( all_trees=all_trees, @@ -810,20 +918,25 @@ def plot_variable_importance( # noqa: PLR0915 [pearsonr2(predicted_all[j], predicted_subset[j]) for j in range(samples)] ) r2_mean[idx] = np.mean(r_2) - r2_hdi[idx] = az.hdi(r_2) - - elif method == "backward": - r2_mean = np.zeros(n_vars) - r2_hdi = np.zeros((n_vars, 2)) - - variables = set(range(n_vars)) - least_important_vars: List[int] = [] - indices = [] + r2_hdi[idx] = array_stats.hdi(r_2, prob=rcParams["stats.ci_prob"]) + preds[idx] = predicted_subset.squeeze() + + if method in ["backward", "backward_VI"]: + if method == "backward_VI": + least_important_vars: list[int] = indices[-fixed:] + r2_mean_vi = r2_mean[:init] + r2_hdi_vi = r2_hdi[:init] + preds_vi = preds[:init] + r2_mean = np.zeros(n_vars - fixed - 1) + r2_hdi = np.zeros((n_vars - fixed - 1, 2)) + preds = np.zeros((n_vars - fixed - 1, samples, bartrv.eval().shape[0])) + else: + least_important_vars = [] # Iterate over each variable to determine its contribution # least_important_vars tracks the variable with the lowest contribution - # at the current stage. One new varible is added at each iteration. - for i_var in range(n_vars): + # at the current stage. One new variable is added at each iteration. + for i_var in range(init, n_vars): # Generate all possible subsets by adding one variable at a time to # least_important_vars subsets = generate_sequences(n_vars, i_var, least_important_vars) @@ -851,30 +964,138 @@ def plot_variable_importance( # noqa: PLR0915 max_r_2 = mean_r_2 least_important_subset = subset r_2_without_least_important_vars = r_2 + least_important_samples = predicted_subset # Save values for plotting later - r2_mean[i_var] = max_r_2 - r2_hdi[i_var] = az.hdi(r_2_without_least_important_vars) + r2_mean[i_var - init] = max_r_2 + r2_hdi[i_var - init] = array_stats.hdi(r_2_without_least_important_vars) + preds[i_var - init] = least_important_samples.squeeze() # extend current list of least important variable - least_important_vars += least_important_subset + for var_i in least_important_subset: + if var_i not in least_important_vars: + least_important_vars.append(var_i) + + # Add the remaining variables to the list of least important variables + for var_i in range(n_vars): + if var_i not in least_important_vars: + least_important_vars.append(var_i) + + if method == "backward_VI": + r2_mean = np.concatenate((r2_mean[::-1], r2_mean_vi)) + r2_hdi = np.concatenate((r2_hdi[::-1], r2_hdi_vi)) + preds = np.concatenate((preds[::-1], preds_vi)) + else: + r2_mean = r2_mean[::-1] + r2_hdi = r2_hdi[::-1] + preds = preds[::-1] + + indices = least_important_vars[::-1] + + labels = np.array( + ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + ) - # add index of removed variable - indices += list(set(least_important_subset) - set(indices)) + vi_results = { + "indices": np.asarray(indices), + "labels": labels, + "r2_mean": r2_mean, + "r2_hdi": r2_hdi, + "preds": preds, + "preds_all": predicted_all.squeeze(), + } + return vi_results - # add remaining index - indices += list(set(variables) - set(least_important_vars)) - indices = indices[::-1] - r2_mean = r2_mean[::-1] - r2_hdi = r2_hdi[::-1] +def vi_to_kulprit(vi_results: dict) -> list[list[str]]: + """ + Export variable importance results to Kulprit format. - new_labels = [ - "+ " + ele if index != 0 else ele for index, ele in enumerate(labels_ary[indices]) - ] + Parameters + ---------- + vi_results : dict + Dictionary computed with `compute_variable_importance` + + Returns + ------- + list[list[str]] + A list of lists containing variable names for each submodel. + """ + clean_labels = [label.strip("+ ") for label in vi_results["labels"]] + return [clean_labels[:idx] for idx in range(len(clean_labels))] + + +def plot_variable_importance( + vi_results: dict, + submodels: list[int] | np.ndarray | tuple[int, ...] | None = None, + labels: list[str] | None = None, + figsize: tuple[float, float] | None = None, + plot_kwargs: dict[str, Any] | None = None, + ax: plt.Axes | None = None, +): + """ + Estimates variable importance from the BART-posterior. + + Parameters + ---------- + vi_results: Dictionary + Dictionary computed with `compute_variable_importance` + submodels : Optional[Union[list[int], np.ndarray]] + List of the indices of the submodels to plot. Defaults to None, all variables are ploted. + The indices correspond to order computed by `compute_variable_importance`. + For example `submodels=[0,1]` will plot the two most important variables. + `submodels=[1,0]` is equivalent as values are sorted before use. + labels : Optional[list[str]] + List of the names of the covariates. If X is a DataFrame the names of the covariables will + be taken from it and this argument will be ignored. + plot_kwargs : dict + Additional keyword arguments for the plot. Defaults to None. + Valid keys are: + - color_r2: matplotlib valid color for error bars + - marker_r2: matplotlib valid marker for the mean R squared + - marker_fc_r2: matplotlib valid marker face color for the mean R squared + - ls_ref: matplotlib valid linestyle for the reference line + - color_ref: matplotlib valid color for the reference line + - rotation: float, rotation angle of the x-axis labels. Defaults to 0. + ax : axes + Matplotlib axes. + + Returns + ------- + axes: matplotlib axes + """ + if submodels is None: + submodels = np.sort(vi_results["indices"]) + else: + submodels = np.sort(submodels) + + indices = vi_results["indices"][submodels] + r2_mean = vi_results["r2_mean"][submodels] + r2_hdi = vi_results["r2_hdi"][submodels] + preds = vi_results["preds"][submodels] + preds_all = vi_results["preds_all"] + samples = preds.shape[1] + + n_vars = len(indices) + ticks = np.arange(n_vars, dtype=int) + + if plot_kwargs is None: + plot_kwargs = {} + + if figsize is None: + figsize = (8, 3) + + if ax is None: + _, ax = plt.subplots(1, 1, figsize=figsize) + + if labels is None: + labels = vi_results["labels"][submodels] + + r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)]) r2_yerr_min = np.clip(r2_mean - r2_hdi[:, 0], 0, None) r2_yerr_max = np.clip(r2_hdi[:, 1] - r2_mean, 0, None) + ax.errorbar( ticks, r2_mean, @@ -890,20 +1111,137 @@ def plot_variable_importance( # noqa: PLR0915 ) ax.fill_between( [-0.5, n_vars - 0.5], - *az.hdi(r_2_ref), + *array_stats.hdi(r_2_ref, prob=rcParams["stats.ci_prob"]), alpha=0.1, color=plot_kwargs.get("color_ref", "grey"), ) ax.set_xticks( ticks, - new_labels, + labels, rotation=plot_kwargs.get("rotation", 0), ) ax.set_ylabel("R²", rotation=0, labelpad=12) ax.set_ylim(0, 1) ax.set_xlim(-0.5, n_vars - 0.5) - return indices, ax + return ax + + +def plot_scatter_submodels( + vi_results: dict, + func: Callable | None = None, + submodels: list[int] | np.ndarray | None = None, + grid: str = "long", + labels: list[str] | None = None, + figsize: tuple[float, float] | None = None, + plot_kwargs: dict[str, Any] | None = None, + ax: plt.Axes | None = None, +) -> list[plt.Axes]: + """ + Plot submodel's predictions against reference-model's predictions. + + Parameters + ---------- + vi_results : Dictionary + Dictionary computed with `compute_variable_importance` + func : Optional[Callable], by default None. + Arbitrary function to apply to the predictions. Defaults to the identity function. + submodels : Optional[Union[list[int], np.ndarray]] + List of the indices of the submodels to plot. Defaults to None, all variables are ploted. + The indices correspond to order computed by `compute_variable_importance`. + For example `submodels=[0,1]` will plot the two most important variables. + `submodels=[1,0]` is equivalent as values are sorted before use. + grid : str or tuple + How to arrange the subplots. Defaults to "long", one subplot below the other. + Other options are "wide", one subplot next to each other or a tuple indicating the number + of rows and columns. + labels : Optional[list[str]] + List of the names of the covariates. + plot_kwargs : dict + Additional keyword arguments for the plot. Defaults to None. + Valid keys are: + - marker_scatter: matplotlib valid marker for the scatter plot + - color_scatter: matplotlib valid color for the scatter plot + - alpha_scatter: matplotlib valid alpha for the scatter plot + - color_ref: matplotlib valid color for the 45 degree line + - ls_ref: matplotlib valid linestyle for the reference line + axes : axes + Matplotlib axes. + + Returns + ------- + axes: matplotlib axes + """ + if submodels is None: + submodels = np.sort(vi_results["indices"]) + else: + submodels = np.sort(submodels) + + indices = vi_results["indices"][submodels] + preds_sub = vi_results["preds"][submodels] + preds_all = vi_results["preds_all"] + + if labels is None: + labels = vi_results["labels"][submodels] + + # handle categorical regression case: + n_cats = None + if preds_all.ndim > 2: + n_cats = preds_all.shape[-1] + indices = np.tile(indices, n_cats) + + if ax is None: + _, ax = _get_axes(grid, len(indices), True, True, figsize) + + if plot_kwargs is None: + plot_kwargs = {} + + if func is not None: + preds_sub = func(preds_sub) + preds_all = func(preds_all) + + min_ = min(np.min(preds_sub), np.min(preds_all)) + max_ = max(np.max(preds_sub), np.max(preds_all)) + + # handle categorical regression case: + if n_cats is not None: + i = 0 + for cat in range(n_cats): + for pred_sub, x_label in zip(preds_sub, labels): + ax[i].plot( + pred_sub[..., cat], + preds_all[..., cat], + marker=plot_kwargs.get("marker_scatter", "."), + ls="", + color=plot_kwargs.get("color_scatter", f"C{cat}"), + alpha=plot_kwargs.get("alpha_scatter", 0.1), + ) + ax[i].set(xlabel=x_label, ylabel="ref model", title=f"Category {cat}") + ax[i].axline( + [min_, min_], + [max_, max_], + color=plot_kwargs.get("color_ref", "0.5"), + ls=plot_kwargs.get("ls_ref", "--"), + ) + i += 1 + else: + for pred_sub, x_label, axi in zip(preds_sub, labels, ax.ravel()): + axi.plot( + pred_sub, + preds_all, + marker=plot_kwargs.get("marker_scatter", "."), + ls="", + color=plot_kwargs.get("color_scatter", "C0"), + alpha=plot_kwargs.get("alpha_scatter", 0.1), + ) + axi.set(xlabel=x_label, ylabel="ref model") + axi.axline( + [min_, min_], + [max_, max_], + color=plot_kwargs.get("color_ref", "0.5"), + ls=plot_kwargs.get("ls_ref", "--"), + ) + return ax def generate_sequences(n_vars, i_var, include): @@ -923,3 +1261,58 @@ def pearsonr2(A, B): am = A - np.mean(A) bm = B - np.mean(B) return (am @ bm) ** 2 / (np.sum(am**2) * np.sum(bm**2)) + + +def _plot_hdi(x, y, smooth, color, alpha, smooth_kwargs, ax): + x = np.asarray(x) + y = np.asarray(y) + hdi_prob = rcParams["stats.ci_prob"] + hdi_data = array_stats.hdi(y, hdi_prob, axis=0) + if smooth: + if isinstance(x[0], np.datetime64): + raise TypeError("Cannot deal with x as type datetime. Recommend setting smooth=False.") + + x_data, y_data = _smooth_mean(x, hdi_data, smooth_kwargs=smooth_kwargs) + else: + idx = np.argsort(x) + x_data = x[idx] + y_data = hdi_data[idx] + + ax.fill_between(x_data, y_data[:, 0], y_data[:, 1], color=color, alpha=alpha) + return ax + + +def _decode_vi(n: int, length: int) -> list[int]: + """ + Decode the variable inclusion from the BART model. + """ + bits = bin(n)[2:] + vi_list: list[int] = [] + i = 0 + while len(vi_list) < length: + # Count prefix ones + prefix_len = 0 + while bits[i] == "1": + prefix_len += 1 + i += 1 + i += 1 # skip the '0' + b = bits[i : i + prefix_len] + vi_list.append(int(b, 2)) + i += prefix_len + return vi_list + + +def _encode_vi(vec: npt.NDArray) -> int: + """ + Encode variable inclusion vector into a single integer. + + The encoding is done by converting each element of the vector into a binary string, + where each element contributes a prefix of '1's followed by a '0' and its binary representation. + The final result is the integer representation of the concatenated binary string. + """ + bits = "" + for x in vec: + b = bin(x)[2:] + prefix = "1" * len(b) + "0" + bits += prefix + b + return int(bits, 2) diff --git a/pyproject.toml b/pyproject.toml index 165ed67..2afa2c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,16 +8,17 @@ line-length = 100 [tool.ruff.lint] select = ["E", "F", "I", "PL", "UP", "W"] -ignore-init-module-imports = true ignore = [ "PLR2004", # Checks for the use of unnamed numerical constants ("magic") values in comparisons. + "PLR0913", # Too many arguments in function definition + "PLC0415", # import should be at the top-level ] [tool.ruff.lint.pylint] max-args = 19 max-branches = 15 -[tool.ruff.extend-per-file-ignores] +[tool.ruff.lint.extend-per-file-ignores] "docs/conf.py" = ["E501", "F541"] "tests/test_*.py" = ["F841"] @@ -32,3 +33,19 @@ exclude_lines = [ isort = 1 black = 1 pyupgrade = 1 + + +[tool.mypy] +files = "pymc_bart/*.py" + +[tool.mypy-matplotlib] +ignore_missing_imports = true + +[tool.mypy-numba] +ignore_missing_imports = true + +[tool.mypy-pymc] +ignore_missing_imports = true + +[tool.mypy-scipy] +ignore_missing_imports = true diff --git a/requirements-docs.txt b/requirements-docs.txt index 5074a06..214c399 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -1,8 +1,6 @@ myst-nb -sphinx==5.0.2 # see https://github.com/pymc-devs/pymc-examples/issues/409 -git+https://github.com/pymc-devs/pymc-sphinx-theme +sphinx +pymc-sphinx-theme>=0.16 sphinxcontrib-bibtex -nbsphinx sphinx_design sphinx_codeautolink -sphinx_remove_toctrees diff --git a/requirements.txt b/requirements.txt index e741cef..24d156b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -pymc<=5.16.2 -arviz>=0.18.0 +pymc>=5.24.0 +arviz-stats[xarray]>=0.6.0 numba matplotlib -numpy +numpy>=2.0 diff --git a/setup.py b/setup.py index e934ae2..0ae76b2 100644 --- a/setup.py +++ b/setup.py @@ -29,9 +29,9 @@ "Development Status :: 5 - Production/Stable", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "License :: OSI Approved :: Apache Software License", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering", diff --git a/tests/test_bart.py b/tests/test_bart.py index dfbd86f..f446cd4 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -1,18 +1,19 @@ import numpy as np import pymc as pm import pytest -from numpy.testing import assert_almost_equal, assert_array_equal +from numpy.testing import assert_almost_equal from pymc.initial_point import make_initial_point_fn -from pymc.logprob.basic import joint_logp +from pymc.logprob.basic import transformed_conditional_logp import pymc_bart as pmb +from pymc_bart.utils import _decode_vi def assert_moment_is_expected(model, expected, check_finite_logp=True): fn = make_initial_point_fn( model=model, return_transformed=False, - default_strategy="moment", + default_strategy="support_point", ) moment = fn(0)["x"] expected = np.asarray(expected) @@ -27,7 +28,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): if check_finite_logp: logp_moment = ( - joint_logp( + transformed_conditional_logp( (model["x"],), rvs_to_values={model["x"]: pm.math.constant(moment)}, rvs_to_transforms={}, @@ -52,14 +53,12 @@ def test_bart_vi(response): with pm.Model() as model: mu = pmb.BART("mu", X, Y, m=10, response=response) sigma = pm.HalfNormal("sigma", 1) - y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(random_seed=3415) - var_imp = ( - idata.sample_stats["variable_inclusion"] - .stack(samples=("chain", "draw")) - .mean("samples") - ) - var_imp /= var_imp.sum() + pm.Normal("y", mu, sigma, observed=Y) + idata = pm.sample(tune=200, draws=200, random_seed=3415) + vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel() + var_imp = np.array([_decode_vi(val, 3) for val in vi_vals]).sum(axis=0) + + var_imp = var_imp / var_imp.sum() assert var_imp[0] > var_imp[1:].sum() assert_almost_equal(var_imp.sum(), 1) @@ -77,8 +76,8 @@ def test_missing_data(response): with pm.Model() as model: mu = pmb.BART("mu", X, Y, m=10, response=response) sigma = pm.HalfNormal("sigma", 1) - y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(tune=100, draws=100, chains=1, random_seed=3415) + pm.Normal("y", mu, sigma, observed=Y) + pm.sample(tune=100, draws=100, chains=1, random_seed=3415) @pytest.mark.parametrize( @@ -91,7 +90,7 @@ def test_shared_variable(response): Y = np.random.normal(0, 1, size=50) with pm.Model() as model: - data_X = pm.MutableData("data_X", X) + data_X = pm.Data("data_X", X) mu = pmb.BART("mu", data_X, Y, m=2, response=response) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y, shape=mu.shape) @@ -116,94 +115,13 @@ def test_shape(response): with pm.Model() as model: w = pmb.BART("w", X, Y, m=2, response=response, shape=(2, 250)) y = pm.Normal("y", w[0], pm.math.abs(w[1]), observed=Y) - idata = pm.sample(random_seed=3415) + idata = pm.sample(tune=50, draws=10, random_seed=3415) assert model.initial_point()["w"].shape == (2, 250) assert idata.posterior.coords["w_dim_0"].data.size == 2 assert idata.posterior.coords["w_dim_1"].data.size == 250 -class TestUtils: - X_norm = np.random.normal(0, 1, size=(50, 2)) - X_binom = np.random.binomial(1, 0.5, size=(50, 1)) - X = np.hstack([X_norm, X_binom]) - Y = np.random.normal(0, 1, size=50) - - with pm.Model() as model: - mu = pmb.BART("mu", X, Y, m=10) - sigma = pm.HalfNormal("sigma", 1) - y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(random_seed=3415) - - def test_sample_posterior(self): - all_trees = self.mu.owner.op.all_trees - rng = np.random.default_rng(3) - pred_all = pmb.utils._sample_posterior(all_trees, X=self.X, rng=rng, size=2) - rng = np.random.default_rng(3) - pred_first = pmb.utils._sample_posterior(all_trees, X=self.X[:10], rng=rng) - - assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4) - assert pred_all.shape == (2, 50, 1) - assert pred_first.shape == (1, 10, 1) - - @pytest.mark.parametrize( - "kwargs", - [ - {}, - { - "samples": 2, - "var_discrete": [3], - }, - {"instances": 2}, - {"var_idx": [0], "smooth": False, "color": "k"}, - {"grid": (1, 2), "sharey": "none", "alpha": 1}, - {"var_discrete": [0]}, - ], - ) - def test_ice(self, kwargs): - pmb.plot_ice(self.mu, X=self.X, Y=self.Y, **kwargs) - - @pytest.mark.parametrize( - "kwargs", - [ - {}, - { - "samples": 2, - "xs_interval": "quantiles", - "xs_values": [0.25, 0.5, 0.75], - "var_discrete": [3], - }, - {"var_idx": [0], "smooth": False, "color": "k"}, - {"grid": (1, 2), "sharey": "none", "alpha": 1}, - {"var_discrete": [0]}, - ], - ) - def test_pdp(self, kwargs): - pmb.plot_pdp(self.mu, X=self.X, Y=self.Y, **kwargs) - - @pytest.mark.parametrize( - "kwargs", - [ - {}, - {"labels": ["A", "B", "C"], "samples": 2, "figsize": (6, 6)}, - ], - ) - def test_vi(self, kwargs): - pmb.plot_variable_importance(self.idata, X=self.X, bartrv=self.mu, **kwargs) - - def test_pdp_pandas_labels(self): - pd = pytest.importorskip("pandas") - - X_names = ["norm1", "norm2", "binom"] - X_pd = pd.DataFrame(self.X, columns=X_names) - Y_pd = pd.Series(self.Y, name="response") - axes = pmb.plot_pdp(self.mu, X=X_pd, Y=Y_pd) - - figure = axes[0].figure - assert figure.texts[0].get_text() == "Partial response" - assert_array_equal([ax.get_xlabel() for ax in axes], X_names) - - @pytest.mark.parametrize( "size, expected", [ @@ -243,8 +161,88 @@ def test_categorical_model(separate_trees, split_rule): separate_trees=separate_trees, ) y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y) - idata = pm.sample(random_seed=3415, tune=300, draws=300) - idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True) + idata = pm.sample(tune=300, draws=300, random_seed=3415) + idata = pm.sample_posterior_predictive( + idata, predictions=True, extend_inferencedata=True, random_seed=3415 + ) # Fit should be good enough so right category is selected over 50% of time assert (idata.predictions.y.median(["chain", "draw"]) == Y).all() + assert pmb.compute_variable_importance(idata, bartrv=lo, X=X)["preds"].shape == (5, 50, 9, 3) + + +def test_multiple_bart_variables(): + """Test that multiple BART variables can coexist in a single model.""" + X1 = np.random.normal(0, 1, size=(50, 2)) + X2 = np.random.normal(0, 1, size=(50, 3)) + Y = np.random.normal(0, 1, size=50) + + # Create correlated responses + Y1 = X1[:, 0] + np.random.normal(0, 0.1, size=50) + Y2 = X2[:, 0] + X2[:, 1] + np.random.normal(0, 0.1, size=50) + + with pm.Model() as model: + # Two separate BART variables with different covariates + mu1 = pmb.BART("mu1", X1, Y1, m=5) + mu2 = pmb.BART("mu2", X2, Y2, m=5) + + # Combined model + sigma = pm.HalfNormal("sigma", 1) + pm.Normal("y", mu1 + mu2, sigma, observed=Y) + + # Sample with automatic assignment of BART samplers + idata = pm.sample(tune=50, draws=50, chains=1, random_seed=3415) + + # Verify both BART variables have their own tree collections + assert hasattr(mu1.owner.op, "all_trees") + assert hasattr(mu2.owner.op, "all_trees") + + # Verify trees are stored separately (different object references) + assert mu1.owner.op.all_trees is not mu2.owner.op.all_trees + + # Verify sampling worked + assert idata.posterior["mu1"].shape == (1, 50, 50) + assert idata.posterior["mu2"].shape == (1, 50, 50) + + vi_results = pmb.compute_variable_importance(idata, mu1, X1, model=model) + assert vi_results["labels"].shape == (2,) + assert vi_results["preds"].shape == (2, 50, 50) + assert vi_results["preds_all"].shape == (50, 50) + + vi_tuple = pmb.get_variable_inclusion(idata, X1, model=model, bart_var_name="mu1") + assert vi_tuple[0].shape == (2,) + assert len(vi_tuple[1]) == 2 + assert isinstance(vi_tuple[1][0], str) + + +def test_multiple_bart_variables_manual_step(): + """Test that multiple BART variables work with manually assigned PGBART samplers.""" + X1 = np.random.normal(0, 1, size=(30, 2)) + X2 = np.random.normal(0, 1, size=(30, 2)) + Y = np.random.normal(0, 1, size=30) + + # Create simple responses + Y1 = X1[:, 0] + np.random.normal(0, 0.1, size=30) + Y2 = X2[:, 1] + np.random.normal(0, 0.1, size=30) + + with pm.Model() as model: + # Two separate BART variables + mu1 = pmb.BART("mu1", X1, Y1, m=3) + mu2 = pmb.BART("mu2", X2, Y2, m=3) + + # Non-BART variable + sigma = pm.HalfNormal("sigma", 1) + y = pm.Normal("y", mu1 + mu2, sigma, observed=Y) + + # Manually create PGBART samplers for each BART variable + step1 = pmb.PGBART([mu1], num_particles=5) + step2 = pmb.PGBART([mu2], num_particles=5) + + # Sample with manual step assignment + idata = pm.sample(tune=20, draws=20, chains=1, step=[step1, step2], random_seed=3415) + + # Verify both variables were sampled + assert "mu1" in idata.posterior + assert "mu2" in idata.posterior + assert idata.posterior["mu1"].shape == (1, 20, 30) + assert idata.posterior["mu2"].shape == (1, 20, 30) diff --git a/tests/test_pgbart.py b/tests/test_pgbart.py index 4cf4188..5a1d35e 100644 --- a/tests/test_pgbart.py +++ b/tests/test_pgbart.py @@ -74,13 +74,13 @@ def test_discrete_uniform(): def test_normal_sampler(): normal = NormalSampler(2, shape=1) samples = np.array([normal.rvs() for i in range(100000)]) - np.testing.assert_almost_equal(samples.mean(), 0, decimal=2) - np.testing.assert_almost_equal(samples.std(), 2, decimal=2) + np.testing.assert_almost_equal(samples.mean(), 0, decimal=1) + np.testing.assert_almost_equal(samples.std(), 2, decimal=1) normal = NormalSampler(2, shape=2) samples = np.array([normal.rvs() for i in range(100000)]) - np.testing.assert_almost_equal(samples.mean(0), [0, 0], decimal=2) - np.testing.assert_almost_equal(samples.std(0), [2, 2], decimal=2) + np.testing.assert_almost_equal(samples.mean(0), [0, 0], decimal=1) + np.testing.assert_almost_equal(samples.std(0), [2, 2], decimal=1) def test_uniform_sampler(): diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..ed85af7 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,111 @@ +import numpy as np +import pymc as pm +import pytest +from numpy.testing import assert_almost_equal, assert_array_equal + +import pymc_bart as pmb + + +class TestUtils: + X_norm = np.random.normal(0, 1, size=(50, 2)) + X_binom = np.random.binomial(1, 0.5, size=(50, 1)) + X = np.hstack([X_norm, X_binom]) + Y = np.random.normal(0, 1, size=50) + + with pm.Model() as model: + mu = pmb.BART("mu", X, Y, m=10) + sigma = pm.HalfNormal("sigma", 1) + y = pm.Normal("y", mu, sigma, observed=Y) + idata = pm.sample(tune=200, draws=200, random_seed=3415) + + def test_sample_posterior(self): + all_trees = self.mu.owner.op.all_trees + rng = np.random.default_rng(3) + pred_all = pmb.utils._sample_posterior(all_trees, X=self.X, rng=rng, size=2) + rng = np.random.default_rng(3) + pred_first = pmb.utils._sample_posterior(all_trees, X=self.X[:10], rng=rng) + + assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4) + assert pred_all.shape == (2, 50, 1) + assert pred_first.shape == (1, 10, 1) + + @pytest.mark.parametrize( + "kwargs", + [ + {}, + { + "samples": 2, + "var_discrete": [3], + }, + {"instances": 2}, + {"var_idx": [0], "smooth": False, "color": "k"}, + {"grid": (1, 2), "sharey": "none", "alpha": 1}, + {"var_discrete": [0]}, + ], + ) + def test_ice(self, kwargs): + pmb.plot_ice(self.mu, X=self.X, Y=self.Y, **kwargs) + + @pytest.mark.parametrize( + "kwargs", + [ + {}, + { + "samples": 2, + "xs_interval": "quantiles", + "xs_values": [0.25, 0.5, 0.75], + "var_discrete": [3], + }, + {"var_idx": [0], "smooth": False, "color": "k"}, + {"grid": (1, 2), "sharey": "none", "alpha": 1}, + {"var_discrete": [0]}, + ], + ) + def test_pdp(self, kwargs): + pmb.plot_pdp(self.mu, X=self.X, Y=self.Y, **kwargs) + + @pytest.mark.parametrize( + "kwargs", + [ + {"samples": 50}, + {"labels": ["A", "B", "C"], "samples": 2, "figsize": (6, 6)}, + ], + ) + def test_vi(self, kwargs): + samples = kwargs.pop("samples") + vi_results = pmb.compute_variable_importance( + self.idata, bartrv=self.mu, X=self.X, samples=samples + ) + pmb.plot_variable_importance(vi_results, **kwargs) + pmb.plot_scatter_submodels(vi_results, **kwargs) + + user_terms = pmb.vi_to_kulprit(vi_results) + assert len(user_terms) == 3 + assert all("+" not in term for terms in user_terms[1:] for term in terms) + + def test_pdp_pandas_labels(self): + pd = pytest.importorskip("pandas") + + X_names = ["norm1", "norm2", "binom"] + X_pd = pd.DataFrame(self.X, columns=X_names) + Y_pd = pd.Series(self.Y, name="response") + axes = pmb.plot_pdp(self.mu, X=X_pd, Y=Y_pd) + + figure = axes[0].figure + assert figure.texts[0].get_text() == "Partial response" + assert_array_equal([ax.get_xlabel() for ax in axes], X_names) + + +def test_encoder_decoder(): + """Test that the encoder-decoder works correctly.""" + test_cases = [ + np.zeros(3, dtype=int), + np.ones(10, dtype=int), + np.array([4, 0, 1, 0, 2, 0, 3, 0, 0, 0]), + np.array([100, 50, 0, 1]), + np.array([1, 2, 4, 8, 16]), + ] + for case in test_cases: + encoded = pmb.utils._encode_vi(case) + decoded = pmb.utils._decode_vi(encoded, len(case)) + assert np.array_equal(decoded, case)