Skip to content
26 changes: 11 additions & 15 deletions examples/miscellaneous/plot_partial_dependence_visualization_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,16 @@
ax1.legend()
ax2.legend()

# %%
# `tree_disp.axes_` is a numpy array container the axes used to draw the
# partial dependence plots. This can be passed to `mlp_disp` to have the same
# affect of drawing the plots on top of each other. Furthermore, the
# `mlp_disp.figure_` stores the figure, which allows for resizing the figure
# after calling `plot`. In this case `tree_disp.axes_` has two dimensions, thus
##############################################################################
# A single axes object can can be pass to `mlp_disp.plot` and `tree_disp.plot`
# to have the same affect of drawing the plots on the same axes.
# In this case `tree_disp.axes_` has two dimensions, thus
# `plot` will only show the y label and y ticks on the left most plot.

tree_disp.plot(line_kw={"label": "Decision Tree"})
fig, ax = plt.subplots(figsize=(10, 6))
tree_disp.plot(line_kw={"label": "Decision Tree"}, ax=ax)
mlp_disp.plot(line_kw={"label": "Multi-layer Perceptron", "color": "red"},
ax=tree_disp.axes_)
tree_disp.figure_.set_size_inches(10, 6)
ax=ax)
tree_disp.axes_[0, 0].legend()
tree_disp.axes_[0, 1].legend()
plt.show()
Expand All @@ -126,10 +124,8 @@
# Plotting partial dependence for one feature
# ===========================================
#
# Here, we plot the partial dependence curves for a single feature, "age", on
# the same axes. In this case, `tree_disp.axes_` is passed into the second
# plot function.
tree_disp = plot_partial_dependence(tree, X, ["age"])
mlp_disp = plot_partial_dependence(mlp, X, ["age"],
ax=tree_disp.axes_,
# Here, we plot the partial dependence curves for a single feature, "LSTAT", on
# the same axes.
tree_disp = plot_partial_dependence(tree, X, ["age"], ax=ax)
mlp_disp = plot_partial_dependence(mlp, X, ["age"], ax=ax,
line_kw={"color": "red"})
23 changes: 14 additions & 9 deletions sklearn/inspection/_plot/partial_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from ...utils import check_matplotlib_support # noqa
from ...utils import _safe_indexing
from ...utils.validation import _deprecate_positional_args
from ...utils._plot import _check_axes_has_been_used
from ...utils._plot import _check_axes_has_display_object


@_deprecate_positional_args
Expand Down Expand Up @@ -588,14 +590,17 @@ def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None):
n_sampled = self._get_sample_count(n_instances) + 1

if isinstance(ax, plt.Axes):
# If ax was set off, it has most likely been set to off
# by a previous call to plot.
if not ax.axison:
raise ValueError("The ax was already used in another plot "
"function, please set ax=display.axes_ "
"instead")

ax.set_axis_off()
disp_obj = _check_axes_has_display_object(self, ax)

# axes was used by another PartialDependenceDisplay for plotting
if disp_obj != self:
self.bounding_ax_ = disp_obj.bounding_ax_
# ax is set to be an array-like
ax = disp_obj.axes_

if isinstance(ax, plt.Axes):
_check_axes_has_been_used(ax)

self.bounding_ax_ = ax
self.figure_ = ax.figure

Expand Down Expand Up @@ -628,7 +633,7 @@ def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None):
else:
n_cols = None

self.bounding_ax_ = None
self.bounding_ax_ = getattr(self, "bounding_ax_", None)
self.figure_ = ax.ravel()[0].figure
self.axes_ = ax
if self.kind == 'average':
Expand Down
92 changes: 74 additions & 18 deletions sklearn/inspection/_plot/tests/test_plot_partial_dependence.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import weakref
from scipy.stats.mstats import mquantiles

import pytest
from numpy.testing import assert_allclose
from numpy.testing import assert_array_equal

from sklearn.datasets import load_diabetes
from sklearn.datasets import load_iris
Expand All @@ -11,6 +13,7 @@
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LinearRegression
from sklearn.utils._testing import _convert_container
from sklearn.utils._plot import _SKLEARN_AX_DISP_OBJ_REF_KEY

from sklearn.inspection import plot_partial_dependence

Expand Down Expand Up @@ -298,34 +301,87 @@ def test_plot_partial_dependence_incorrent_num_axes(pyplot, clf_diabetes,
disp.plot(ax=ax_format)


def test_plot_partial_dependence_with_used_axes(pyplot, clf_diabetes,
diabetes):
# When the axes was drawn on plot_partial_dependence should fail
grid_resolution = 5
fig, ax = pyplot.subplots()
ax.plot([0, 1, 2], [1, 2, 3])

msg = "The ax was already used in a matplotlib plot function"
with pytest.raises(ValueError, match=msg):
plot_partial_dependence(clf_diabetes, diabetes.data, ['age', 'bmi'],
grid_resolution=grid_resolution,
feature_names=diabetes.feature_names, ax=ax)


def test_plot_partial_dependence_used_by_another_display_obj(
pyplot, clf_diabetes, diabetes):
fig, ax = pyplot.subplots()

class DisplayMock:
pass

obj = DisplayMock()
setattr(ax, _SKLEARN_AX_DISP_OBJ_REF_KEY, weakref.ref(obj))

msg = ("The ax was already used by another display object which is not "
"an instance of PartialDependenceDisplay")
with pytest.raises(ValueError, match=msg):
plot_partial_dependence(clf_diabetes, diabetes.data, ['age', 'bmi'],
grid_resolution=5,
feature_names=diabetes.feature_names, ax=ax)


@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
def test_plot_partial_dependence_with_same_axes(pyplot, clf_diabetes,
diabetes):
# The first call to plot_partial_dependence will create two new axes to
# place in the space of the passed in axes, which results in a total of
# three axes in the figure.
# Currently the API does not allow for the second call to
# plot_partial_dependence to use the same axes again, because it will
# create two new axes in the space resulting in five axes. To get the
# expected behavior one needs to pass the generated axes into the second
# call:
# disp1 = plot_partial_dependence(...)
# disp2 = plot_partial_dependence(..., ax=disp1.axes_)
# The second call will plot on the axes created by the first call to
# plot_partial_dependence

grid_resolution = 25
grid_resolution = 5
fig, ax = pyplot.subplots()
plot_partial_dependence(clf_diabetes, diabetes.data, ['age', 'bmi'],
grid_resolution=grid_resolution,
feature_names=diabetes.feature_names, ax=ax)
disp1 = plot_partial_dependence(clf_diabetes, diabetes.data,
['age', 'bmi'],
grid_resolution=grid_resolution,
feature_names=diabetes.feature_names,
ax=ax)

msg = ("The ax was already used in another plot function, please set "
"ax=display.axes_ instead")
axs = fig.get_axes()
assert len(axs) == 3

with pytest.raises(ValueError, match=msg):
plot_partial_dependence(clf_diabetes, diabetes.data,
['age', 'bmi'],
grid_resolution=grid_resolution,
feature_names=diabetes.feature_names, ax=ax)
disp2 = plot_partial_dependence(clf_diabetes, diabetes.data,
['age', 'bmi'],
grid_resolution=grid_resolution,
feature_names=diabetes.feature_names,
ax=ax)

axs = fig.get_axes()
assert len(axs) == 3

assert_array_equal(disp1.axes_, disp2.axes_)
assert disp1.figure_ == disp2.figure_
assert disp1.bounding_ax_ == disp2.bounding_ax_


def test_plot_partial_dependence_with_weak_ref(pyplot,
clf_diabetes, diabetes):
# When original display object is deleted, the weakref from the axes
# is set to None
grid_resolution = 5
fig, ax = pyplot.subplots()
disp = plot_partial_dependence(clf_diabetes, diabetes.data,
['age', 'bmi'],
grid_resolution=grid_resolution,
feature_names=diabetes.feature_names, ax=ax)

display_ref = getattr(ax, _SKLEARN_AX_DISP_OBJ_REF_KEY)
assert isinstance(display_ref(), disp.__class__)
del disp
assert display_ref() is None


@pytest.mark.filterwarnings("ignore:A Bunch will be returned")
Expand Down
43 changes: 43 additions & 0 deletions sklearn/utils/_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import weakref


def _check_axes_has_been_used(ax):
"""Return true if the axes has been used"""
used_attrs = ['lines', 'patches', 'texts', 'tables', 'artists',
'tables', 'images']
msg = "The ax was already used in a matplotlib plot function"
if any(getattr(ax, attr, None) for attr in used_attrs):
raise ValueError(msg)


_SKLEARN_AX_DISP_OBJ_REF_KEY = "_sklearn_display_object_ref"


def _check_axes_has_display_object(display_obj, ax):
"""Check if axes has a weak ref to a display_obj

If the weak ref does not exist or points to None, the axes will be assigned
a reference to the passed in display_obj.

Used when the display object needs to use the axes to define a space to
create multiple plots on the axes.
"""
if not hasattr(ax, _SKLEARN_AX_DISP_OBJ_REF_KEY):
setattr(ax, _SKLEARN_AX_DISP_OBJ_REF_KEY, weakref.ref(display_obj))
ax.set_axis_off()
return display_obj

ax_display_obj = getattr(ax, _SKLEARN_AX_DISP_OBJ_REF_KEY)()

if ax_display_obj is None: # display obj was deleted
setattr(ax, _SKLEARN_AX_DISP_OBJ_REF_KEY, weakref.ref(display_obj))
ax.set_axis_off()
return display_obj
elif not isinstance(ax_display_obj, display_obj.__class__):
raise ValueError("The ax was already used by another "
"display object which is not an "
"instance of {}".format(
display_obj.__class__.__name__))

# ax._sklearn_display_object is an instance of display_obj.__class__
return ax_display_obj
65 changes: 65 additions & 0 deletions sklearn/utils/tests/test_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import weakref
import gc

import pytest

from sklearn.utils._plot import _check_axes_has_been_used
from sklearn.utils._plot import _check_axes_has_display_object
from sklearn.utils._plot import _SKLEARN_AX_DISP_OBJ_REF_KEY


def test_axes_has_been_used(pyplot):
fig, ax = pyplot.subplots()

msg = "The ax was already used in a matplotlib plot function"
_check_axes_has_been_used(ax) # no error

ax.plot([0, 1, 2], [1, 2, 3])
with pytest.raises(ValueError, match=msg):
_check_axes_has_been_used(ax)

ax.clear()
_check_axes_has_been_used(ax) # no error

ax.imshow([[0, 1, 2]])
with pytest.raises(ValueError, match=msg):
_check_axes_has_been_used(ax)


def test_check_axes_has_display_object(pyplot):
fig, ax = pyplot.subplots()

class MockDisplay1:
pass

display_1 = MockDisplay1()

# first call sets display weakref in ax
assert _check_axes_has_display_object(display_1, ax) == display_1
assert hasattr(ax, _SKLEARN_AX_DISP_OBJ_REF_KEY)
display_1_ref = getattr(ax, _SKLEARN_AX_DISP_OBJ_REF_KEY)
assert isinstance(display_1_ref, weakref.ref)
assert display_1_ref() == display_1

class MockDisplay2:
pass

display_2 = MockDisplay2()

# errors because ax already has a reference to a display object
msg = ("The ax was already used by another display object which is not "
"an instance of MockDisplay2")

with pytest.raises(ValueError, match=msg):
_check_axes_has_display_object(display_2, ax)

# deleting the display_1 object
del display_1
gc.collect() # python 3.5 needs to gc before the weak reference is removed
assert display_1_ref() is None

# With the display_1 deleted, a new display object can be added to ax
assert _check_axes_has_display_object(display_2, ax) == display_2
display_2_ref = getattr(ax, _SKLEARN_AX_DISP_OBJ_REF_KEY)
assert isinstance(display_2_ref, weakref.ref)
assert display_2_ref() == display_2