diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index db45e04410e6..d5a8d08ec8ee 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -163,6 +163,38 @@ class _AxesStack(object): """Lightweight stack that tracks Axes in a Figure. """ + @staticmethod + def __key_compare(k1, k2): + k1_args, k1_kwargs = k1 + k2_args, k2_kwargs = k2 + + def np_safe_eq(left, right): + out = (left == right) + try: + out = bool(out) + except ValueError: + out = out.all() + try: + out &= len(left) == len(right) + except TypeError: + out = False + return out + + for a, b in zip(k1_args, k2_args): + test = np_safe_eq(a, b) + if not test: + return False + if set(k1_kwargs) != set(k2_kwargs): + return False + + for k in k1_kwargs: + a = k1_kwargs[k] + b = k2_kwargs[k] + test = np_safe_eq(a, b) + if not test: + return False + return True + def __init__(self): # We maintain a list of (creation_index, key, axes) tuples. # We do not use an OrderedDict because 1. the keys may not be hashable @@ -179,7 +211,10 @@ def as_list(self): def get(self, key): """Find the axes corresponding to a key; defaults to `None`. """ - return next((ax for _, k, ax in self._items if k == key), None) + return next((ax + for _, k, ax in self._items + if self.__key_compare(k, key)), + None) def current_key_axes(self): """Return the topmost `(key, axes)` pair, or `(None, None)` if empty. diff --git a/lib/matplotlib/tests/test_figure.py b/lib/matplotlib/tests/test_figure.py index 5f2d3b1b2a82..d6df6fc78dd7 100644 --- a/lib/matplotlib/tests/test_figure.py +++ b/lib/matplotlib/tests/test_figure.py @@ -319,3 +319,11 @@ def test_subplots_shareax_loglabels(): for ax in ax_arr[:, 0]: assert 0 < len(ax.yaxis.get_ticklabels(which='both')) + + +def test_axes_add_np_behavior(): + ax1 = plt.axes(np.array([.1, .1, .8, .8])) + ax2 = plt.axes(np.array([.1, .1, .8, .8])) + # in the future this test will need to be changed to not assert + # that the axes are equal, but still check that this does not blowup. + assert ax1 is ax2