From 321df7a610160b4a20aa9a31f44dab7585cb6a61 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 5 Sep 2024 12:42:37 -0300 Subject: [PATCH] move xlabel_angle to plot_kwargs --- pymc_bart/utils.py | 76 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 58 insertions(+), 18 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index e506581..890216c 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -8,10 +8,11 @@ import numpy as np import numpy.typing as npt import pytensor.tensor as pt +from numba import jit from pytensor.tensor.variable import Variable from scipy.interpolate import griddata from scipy.signal import savgol_filter -from scipy.stats import norm, pearsonr +from scipy.stats import norm from .tree import Tree @@ -699,9 +700,9 @@ def plot_variable_importance( # noqa: PLR0915 labels: Optional[List[str]] = None, method: str = "VI", figsize: Optional[Tuple[float, float]] = None, - xlabel_angle: float = 0, - samples: int = 100, + samples: int = 50, random_seed: Optional[int] = None, + plot_kwargs: Optional[Dict[str, Any]] = None, ax: Optional[plt.Axes] = None, ) -> Tuple[List[int], Union[List[plt.Axes], Any]]: """ @@ -733,6 +734,15 @@ def plot_variable_importance( # noqa: PLR0915 Number of predictions used to compute correlation for subsets of variables. Defaults to 100 random_seed : Optional[int] random_seed used to sample from the posterior. Defaults to None. + plot_kwargs : dict + Additional keyword arguments for the plot. Defaults to None. + Valid keys are: + - color_r2: matplotlib valid color for error bars + - marker_r2: matplotlib valid marker for the mean R squared + - marker_fc_r2: matplotlib valid marker face color for the mean R squared + - ls_ref: matplotlib valid linestyle for the reference line + - color_ref: matplotlib valid color for the reference line + - rotation: float, rotation angle of the x-axis labels. Defaults to 0. ax : axes Matplotlib axes. @@ -745,14 +755,17 @@ def plot_variable_importance( # noqa: PLR0915 all_trees = bartrv.owner.op.all_trees + if plot_kwargs is None: + plot_kwargs = {} + if bartrv.ndim == 1: # type: ignore shape = 1 else: shape = bartrv.eval().shape[0] - if hasattr(X, "columns") and hasattr(X, "values"): + if hasattr(X, "columns") and hasattr(X, "to_numpy"): labels = X.columns - X = X.values + X = X.to_numpy() n_vars = X.shape[1] @@ -773,6 +786,10 @@ def plot_variable_importance( # noqa: PLR0915 all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape ) + r_2_ref = np.array( + [pearsonr2(predicted_all[j], predicted_all[j + 1]) for j in range(samples - 1)] + ) + if method == "VI": idxs = np.argsort( idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values @@ -794,10 +811,7 @@ def plot_variable_importance( # noqa: PLR0915 shape=shape, ) r_2 = np.array( - [ - pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0] ** 2 - for j in range(samples) - ] + [pearsonr2(predicted_all[j], predicted_subset[j]) for j in range(samples)] ) r2_mean[idx] = np.mean(r_2) r2_hdi[idx] = az.hdi(r_2) @@ -812,7 +826,7 @@ def plot_variable_importance( # noqa: PLR0915 # Iterate over each variable to determine its contribution # least_important_vars tracks the variable with the lowest contribution - # at the current stage. One new varible is added at each iteration. + # at the current stage. One new variable is added at each iteration. for i_var in range(n_vars): # Generate all possible subsets by adding one variable at a time to # least_important_vars @@ -833,10 +847,7 @@ def plot_variable_importance( # noqa: PLR0915 # Calculate Pearson correlation for each sample and find the mean r_2 = np.zeros(samples) for j in range(samples): - r_2[j] = ( - (pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]) - ** 2 - ) + r_2[j] = pearsonr2(predicted_all[j], predicted_subset[j]) mean_r_2 = np.mean(r_2, dtype=float) # Identify the least important combination of variables # based on the maximum mean squared Pearson correlation @@ -872,17 +883,31 @@ def plot_variable_importance( # noqa: PLR0915 ticks, r2_mean, np.array((r2_yerr_min, r2_yerr_max)), - color="C0", + color=plot_kwargs.get("color_r2", "k"), + fmt=plot_kwargs.get("marker_r2", "o"), + mfc=plot_kwargs.get("marker_fc_r2", "white"), ) - ax.axhline(r2_mean[-1], ls="--", color="0.5") - ax.set_xticks(ticks, new_labels, rotation=xlabel_angle) + ax.axhline( + np.mean(r_2_ref), + ls=plot_kwargs.get("ls_ref", "--"), + color=plot_kwargs.get("color_ref", "grey"), + ) + ax.fill_between( + [-0.5, n_vars - 0.5], + *az.hdi(r_2_ref), + alpha=0.1, + color=plot_kwargs.get("color_ref", "grey"), + ) + ax.set_xticks(ticks, + new_labels, + rotation=plot_kwargs.get("rotation", 0), + ) ax.set_ylabel("R²", rotation=0, labelpad=12) ax.set_ylim(0, 1) ax.set_xlim(-0.5, n_vars - 0.5) return indices, ax - def generate_sequences(n_vars, i_var, include): """Generate combinations of variables""" if i_var: @@ -890,3 +915,18 @@ def generate_sequences(n_vars, i_var, include): else: sequences = [()] return sequences + + +@jit(nopython=True) +def pearsonr2(A, B): + """Compute the squared Pearson correlation coefficient""" + A = A.flatten() + B = B.flatten() + am = A - np.mean(A) + bm = B - np.mean(B) + return (am @ bm) ** 2 / (np.sum(am**2) * np.sum(bm**2)) + +@jit(nopython=True) +def wasserstein_distance(A, B): + """Compute the Conditional Wasserstein-2 distance between two arrays""" + return np.sqrt(np.mean((A - B) ** 2)) \ No newline at end of file