From 8231ac7ef911a0cbda7d587df7de9a578c7206bc Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Sun, 13 Aug 2017 22:42:32 -0400 Subject: [PATCH 1/2] FIX: the new _AxesStack with np.array as input This adds a layer checking in the case where a numpy array is passed as a value (either in args or kwargs). --- lib/matplotlib/figure.py | 37 ++++++++++++++++++++++++++++- lib/matplotlib/tests/test_figure.py | 8 +++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index db45e04410e6..d5a8d08ec8ee 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -163,6 +163,38 @@ class _AxesStack(object): """Lightweight stack that tracks Axes in a Figure. """ + @staticmethod + def __key_compare(k1, k2): + k1_args, k1_kwargs = k1 + k2_args, k2_kwargs = k2 + + def np_safe_eq(left, right): + out = (left == right) + try: + out = bool(out) + except ValueError: + out = out.all() + try: + out &= len(left) == len(right) + except TypeError: + out = False + return out + + for a, b in zip(k1_args, k2_args): + test = np_safe_eq(a, b) + if not test: + return False + if set(k1_kwargs) != set(k2_kwargs): + return False + + for k in k1_kwargs: + a = k1_kwargs[k] + b = k2_kwargs[k] + test = np_safe_eq(a, b) + if not test: + return False + return True + def __init__(self): # We maintain a list of (creation_index, key, axes) tuples. # We do not use an OrderedDict because 1. the keys may not be hashable @@ -179,7 +211,10 @@ def as_list(self): def get(self, key): """Find the axes corresponding to a key; defaults to `None`. """ - return next((ax for _, k, ax in self._items if k == key), None) + return next((ax + for _, k, ax in self._items + if self.__key_compare(k, key)), + None) def current_key_axes(self): """Return the topmost `(key, axes)` pair, or `(None, None)` if empty. diff --git a/lib/matplotlib/tests/test_figure.py b/lib/matplotlib/tests/test_figure.py index 5f2d3b1b2a82..a46ac5366049 100644 --- a/lib/matplotlib/tests/test_figure.py +++ b/lib/matplotlib/tests/test_figure.py @@ -319,3 +319,11 @@ def test_subplots_shareax_loglabels(): for ax in ax_arr[:, 0]: assert 0 < len(ax.yaxis.get_ticklabels(which='both')) + + +def test_axes_add_np_behavior(): + ax1 = plt.axes(plt.axes(np.array([.1, .1, .8, .8]))) + ax2 = plt.axes(plt.axes(np.array([.1, .1, .8, .8]))) + # in the future this test will need to be changed to not assert + # that the axes are equal, but still check that this does not blowup. + assert ax1 is ax2 From 912063b3e53c5dd44b45121e65744a3ed09f7c60 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Mon, 14 Aug 2017 15:21:57 -0400 Subject: [PATCH 2/2] TST: simplify test --- lib/matplotlib/tests/test_figure.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/matplotlib/tests/test_figure.py b/lib/matplotlib/tests/test_figure.py index a46ac5366049..d6df6fc78dd7 100644 --- a/lib/matplotlib/tests/test_figure.py +++ b/lib/matplotlib/tests/test_figure.py @@ -322,8 +322,8 @@ def test_subplots_shareax_loglabels(): def test_axes_add_np_behavior(): - ax1 = plt.axes(plt.axes(np.array([.1, .1, .8, .8]))) - ax2 = plt.axes(plt.axes(np.array([.1, .1, .8, .8]))) + ax1 = plt.axes(np.array([.1, .1, .8, .8])) + ax2 = plt.axes(np.array([.1, .1, .8, .8])) # in the future this test will need to be changed to not assert # that the axes are equal, but still check that this does not blowup. assert ax1 is ax2