diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index a1ca001fe587..8b64c7d6cb98 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -4,15 +4,35 @@ """ from __future__ import (absolute_import, division, print_function, unicode_literals) - import six import numpy as np -import matplotlib.cbook as cbook import matplotlib.units as units import matplotlib.ticker as ticker +# np 1.6/1.7 support +from distutils.version import LooseVersion +import collections + + +if LooseVersion(np.__version__) >= LooseVersion('1.8.0'): + def shim_array(data): + return np.array(data, dtype=np.unicode) +else: + def shim_array(data): + if (isinstance(data, six.string_types) or + not isinstance(data, collections.Iterable)): + data = [data] + try: + data = [str(d) for d in data] + except UnicodeEncodeError: + # this yields gibberish but unicode text doesn't + # render under numpy1.6 anyway + data = [d.encode('utf-8', 'ignore').decode('utf-8') + for d in data] + return np.array(data, dtype=np.unicode) + class StrCategoryConverter(units.ConversionInterface): @staticmethod @@ -25,7 +45,8 @@ def convert(value, unit, axis): if isinstance(value, six.string_types): return vmap[value] - vals = np.array(value, dtype=np.unicode) + vals = shim_array(value) + for lab, loc in vmap.items(): vals[vals == lab] = loc @@ -81,8 +102,7 @@ def update(self, new_data): self._set_seq_locs(new_data, value) def _set_seq_locs(self, data, value): - strdata = np.array(data, dtype=np.unicode) - # np.unique makes dateframes work + strdata = shim_array(data) new_s = [d for d in np.unique(strdata) if d not in self.seq] for ns in new_s: self.seq.append(ns) diff --git a/lib/matplotlib/tests/test_category.py b/lib/matplotlib/tests/test_category.py index 83847e3150bb..6e5c43d76fb9 100644 --- a/lib/matplotlib/tests/test_category.py +++ b/lib/matplotlib/tests/test_category.py @@ -3,8 +3,6 @@ from __future__ import (absolute_import, division, print_function, unicode_literals) -from distutils.version import LooseVersion - import pytest import numpy as np @@ -14,11 +12,6 @@ import unittest -needs_new_numpy = pytest.mark.xfail( - LooseVersion(np.__version__) < LooseVersion('1.8.0'), - reason='NumPy < 1.8.0 is broken.') - - class TestUnitData(object): testdata = [("hello world", ["hello world"], [0]), ("Здравствуйте мир", ["Здравствуйте мир"], [0]), @@ -28,14 +21,12 @@ class TestUnitData(object): ids = ["single", "unicode", "mixed"] - @needs_new_numpy @pytest.mark.parametrize("data, seq, locs", testdata, ids=ids) def test_unit(self, data, seq, locs): act = cat.UnitData(data) assert act.seq == seq assert act.locs == locs - @needs_new_numpy def test_update_map(self): data = ['a', 'd'] oseq = ['a', 'd'] @@ -87,7 +78,6 @@ class TestStrCategoryConverter(object): def mock_axis(self, request): self.cc = cat.StrCategoryConverter() - @needs_new_numpy @pytest.mark.parametrize("data, unitmap, exp", testdata, ids=ids) def test_convert(self, data, unitmap, exp): MUD = MockUnitData(unitmap)