From 74591beb2aa2cc7ae5a1af47b6cdb92e0fc4cf37 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 5 Sep 2024 12:56:18 -0300 Subject: [PATCH] move x_angle to plot_kwargs --- pymc_bart/utils.py | 77 +++++++++++++++++++++++++++++++++------------- 1 file changed, 55 insertions(+), 22 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index e506581..a50f2d9 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -8,10 +8,11 @@ import numpy as np import numpy.typing as npt import pytensor.tensor as pt +from numba import jit from pytensor.tensor.variable import Variable from scipy.interpolate import griddata from scipy.signal import savgol_filter -from scipy.stats import norm, pearsonr +from scipy.stats import norm from .tree import Tree @@ -546,7 +547,7 @@ def _prepare_plot_data( Parameters ---------- - X : PyTensor Variable, Pandas DataFrame or Numpy array + X : PyTensor Variable, Pandas DataFrame, Polars DataFrame or Numpy array Input data. Y : array-like Target data. @@ -585,9 +586,9 @@ def _prepare_plot_data( if isinstance(X, Variable): X = X.eval() - if hasattr(X, "columns") and hasattr(X, "values"): + if hasattr(X, "columns") and hasattr(X, "to_numpy"): x_names = list(X.columns) - X = X.values + X = X.to_numpy() else: x_names = [] @@ -699,9 +700,9 @@ def plot_variable_importance( # noqa: PLR0915 labels: Optional[List[str]] = None, method: str = "VI", figsize: Optional[Tuple[float, float]] = None, - xlabel_angle: float = 0, - samples: int = 100, + samples: int = 50, random_seed: Optional[int] = None, + plot_kwargs: Optional[Dict[str, Any]] = None, ax: Optional[plt.Axes] = None, ) -> Tuple[List[int], Union[List[plt.Axes], Any]]: """ @@ -726,13 +727,18 @@ def plot_variable_importance( # noqa: PLR0915 VI requieres less computation time. figsize : tuple Figure size. If None it will be defined automatically. - xlabel_angle : float - rotation angle of the x-axis labels. Defaults to 0. Use values like 45 for - long labels and/or many variables. samples : int Number of predictions used to compute correlation for subsets of variables. Defaults to 100 random_seed : Optional[int] random_seed used to sample from the posterior. Defaults to None. + plot_kwargs : dict + Additional keyword arguments for the plot. Defaults to None. + Valid keys are: + - color_r2: matplotlib valid color for error bars + - marker_r2: matplotlib valid marker for the mean R squared + - marker_fc_r2: matplotlib valid marker face color for the mean R squared + - ls_ref: matplotlib valid linestyle for the reference line + - color_ref: matplotlib valid color for the reference line ax : axes Matplotlib axes. @@ -745,14 +751,17 @@ def plot_variable_importance( # noqa: PLR0915 all_trees = bartrv.owner.op.all_trees + if plot_kwargs is None: + plot_kwargs = {} + if bartrv.ndim == 1: # type: ignore shape = 1 else: shape = bartrv.eval().shape[0] - if hasattr(X, "columns") and hasattr(X, "values"): + if hasattr(X, "columns") and hasattr(X, "to_numpy"): labels = X.columns - X = X.values + X = X.to_numpy() n_vars = X.shape[1] @@ -773,6 +782,10 @@ def plot_variable_importance( # noqa: PLR0915 all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape ) + r_2_ref = np.array( + [pearsonr2(predicted_all[j], predicted_all[j + 1]) for j in range(samples - 1)] + ) + if method == "VI": idxs = np.argsort( idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values @@ -794,10 +807,7 @@ def plot_variable_importance( # noqa: PLR0915 shape=shape, ) r_2 = np.array( - [ - pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0] ** 2 - for j in range(samples) - ] + [pearsonr2(predicted_all[j], predicted_subset[j]) for j in range(samples)] ) r2_mean[idx] = np.mean(r_2) r2_hdi[idx] = az.hdi(r_2) @@ -833,10 +843,7 @@ def plot_variable_importance( # noqa: PLR0915 # Calculate Pearson correlation for each sample and find the mean r_2 = np.zeros(samples) for j in range(samples): - r_2[j] = ( - (pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]) - ** 2 - ) + r_2[j] = pearsonr2(predicted_all[j], predicted_subset[j]) mean_r_2 = np.mean(r_2, dtype=float) # Identify the least important combination of variables # based on the maximum mean squared Pearson correlation @@ -872,10 +879,26 @@ def plot_variable_importance( # noqa: PLR0915 ticks, r2_mean, np.array((r2_yerr_min, r2_yerr_max)), - color="C0", + color=plot_kwargs.get("color_r2", "k"), + fmt=plot_kwargs.get("marker_r2", "o"), + mfc=plot_kwargs.get("marker_fc_r2", "white"), + ) + ax.axhline( + np.mean(r_2_ref), + ls=plot_kwargs.get("ls_ref", "--"), + color=plot_kwargs.get("color_ref", "grey"), + ) + ax.fill_between( + [-0.5, n_vars - 0.5], + *az.hdi(r_2_ref), + alpha=0.1, + color=plot_kwargs.get("color_ref", "grey"), + ) + ax.set_xticks( + ticks, + new_labels, + rotation=plot_kwargs.get("rotation", 0), ) - ax.axhline(r2_mean[-1], ls="--", color="0.5") - ax.set_xticks(ticks, new_labels, rotation=xlabel_angle) ax.set_ylabel("R²", rotation=0, labelpad=12) ax.set_ylim(0, 1) ax.set_xlim(-0.5, n_vars - 0.5) @@ -890,3 +913,13 @@ def generate_sequences(n_vars, i_var, include): else: sequences = [()] return sequences + + +@jit(nopython=True) +def pearsonr2(A, B): + """Compute the squared Pearson correlation coefficient""" + A = A.flatten() + B = B.flatten() + am = A - np.mean(A) + bm = B - np.mean(B) + return (am @ bm) ** 2 / (np.sum(am**2) * np.sum(bm**2))