diff --git a/doc/users/next_whats_new/2019-02-27-AL.rst b/doc/users/next_whats_new/2019-02-27-AL.rst new file mode 100644 index 000000000000..f6d1779f150a --- /dev/null +++ b/doc/users/next_whats_new/2019-02-27-AL.rst @@ -0,0 +1,5 @@ +Unit converters now handle instances of subclasses +`````````````````````````````````````````````````` + +Unit converters now also handle instances of subclasses of the class they have +been registered for. diff --git a/lib/matplotlib/tests/test_units.py b/lib/matplotlib/tests/test_units.py index 58b225aa3482..e9ce737385a3 100644 --- a/lib/matplotlib/tests/test_units.py +++ b/lib/matplotlib/tests/test_units.py @@ -1,10 +1,11 @@ +from datetime import datetime +import platform from unittest.mock import MagicMock import matplotlib.pyplot as plt -from matplotlib.testing.decorators import image_comparison +from matplotlib.testing.decorators import check_figures_equal, image_comparison import matplotlib.units as munits import numpy as np -import platform import pytest @@ -119,7 +120,6 @@ def test_empty_set_limits_with_units(quantity_converter): @image_comparison(['jpl_bar_units.png'], savefig_kwarg={'dpi': 120}, style='mpl20') def test_jpl_bar_units(): - from datetime import datetime import matplotlib.testing.jpl_units as units units.register() @@ -136,7 +136,6 @@ def test_jpl_bar_units(): @image_comparison(['jpl_barh_units.png'], savefig_kwarg={'dpi': 120}, style='mpl20') def test_jpl_barh_units(): - from datetime import datetime import matplotlib.testing.jpl_units as units units.register() @@ -164,3 +163,12 @@ def test_scatter_element0_masked(): fig, ax = plt.subplots() ax.scatter(times, y) fig.canvas.draw() + + +@check_figures_equal(extensions=["png"]) +def test_subclass(fig_test, fig_ref): + class subdate(datetime): + pass + + fig_test.subplots().plot(subdate(2000, 1, 1), 0, "o") + fig_ref.subplots().plot(datetime(2000, 1, 1), 0, "o") diff --git a/lib/matplotlib/units.py b/lib/matplotlib/units.py index b5c9f724b44e..a6186a4eb59f 100644 --- a/lib/matplotlib/units.py +++ b/lib/matplotlib/units.py @@ -205,18 +205,20 @@ def get_converter(self, x): # If there are no elements in x, infer the units from its dtype if not x.size: return self.get_converter(np.array([0], dtype=x.dtype)) - try: # Look up in the cache. - return self[type(x)] - except KeyError: - try: # If cache lookup fails, look up based on first element... - first = cbook.safe_first_element(x) - except (TypeError, StopIteration): + for cls in type(x).__mro__: # Look up in the cache. + try: + return self[cls] + except KeyError: pass - else: - # ... and avoid infinite recursion for pathological iterables - # where indexing returns instances of the same iterable class. - if type(first) is not type(x): - return self.get_converter(first) + try: # If cache lookup fails, look up based on first element... + first = cbook.safe_first_element(x) + except (TypeError, StopIteration): + pass + else: + # ... and avoid infinite recursion for pathological iterables for + # which indexing returns instances of the same iterable class. + if type(first) is not type(x): + return self.get_converter(first) return None