diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index 47645f94c1fa..a4541503d9a0 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -509,6 +509,7 @@ def legend(self, *args, **kwargs): raise TypeError('Invalid arguments to legend.') self.legend_ = mlegend.Legend(self, handles, labels, **kwargs) + self.legend_._remove_method = lambda h: setattr(self, 'legend_', None) return self.legend_ def text(self, x, y, s, fontdict=None, diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 34f529ecb044..c6730642b3c5 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -1166,6 +1166,7 @@ def legend(self, handles, labels, *args, **kwargs): """ l = Legend(self, handles, labels, *args, **kwargs) self.legends.append(l) + l._remove_method = lambda h: self.legends.remove(h) return l @docstring.dedent_interpd diff --git a/lib/matplotlib/tests/test_legend.py b/lib/matplotlib/tests/test_legend.py index 15dadeb6c75f..d4eeae2e34f1 100644 --- a/lib/matplotlib/tests/test_legend.py +++ b/lib/matplotlib/tests/test_legend.py @@ -80,21 +80,21 @@ def test_framealpha(): plt.legend(framealpha=0.5) -@image_comparison(baseline_images=['scatter_rc3','scatter_rc1'], remove_text=True) +@image_comparison(baseline_images=['scatter_rc3', 'scatter_rc1'], remove_text=True) def test_rc(): # using subplot triggers some offsetbox functionality untested elsewhere fig = plt.figure() - ax = plt.subplot(121) + ax = plt.subplot(121) ax.scatter(list(xrange(10)), list(xrange(10, 0, -1)), label='three') ax.legend(loc="center left", bbox_to_anchor=[1.0, 0.5], - title="My legend") + title="My legend") mpl.rcParams['legend.scatterpoints'] = 1 fig = plt.figure() - ax = plt.subplot(121) + ax = plt.subplot(121) ax.scatter(list(xrange(10)), list(xrange(10, 0, -1)), label='one') ax.legend(loc="center left", bbox_to_anchor=[1.0, 0.5], - title="My legend") + title="My legend") @image_comparison(baseline_images=['legend_expand'], remove_text=True) @@ -113,6 +113,19 @@ def test_legend_expand(): ax.legend(loc=3, mode=mode, ncol=2) +@cleanup +def test_legend_remove(): + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + lines = ax.plot(range(10)) + leg = fig.legend(lines, "test") + leg.remove() + assert_equal(fig.legends, []) + leg = ax.legend("test") + leg.remove() + assert ax.get_legend() is None + + class TestLegendFunction(object): # Tests the legend function on the Axes and pyplot. @@ -154,8 +167,8 @@ def __call__(self, legend, orig_handle, fontsize, handlebox): handler_map={None: AnyObjectHandler()}) warn.assert_called_with(u'Legend handers must now implement a ' - '"legend_artist" method rather than ' - 'being a callable.', + '"legend_artist" method rather than ' + 'being a callable.', MatplotlibDeprecationWarning, stacklevel=1)