Skip to content

Commit 35b875c

Browse files
authored
Merge pull request #13738 from sasoripathos/fix10788
Fix TypeError when plotting stacked bar chart with decimal
2 parents a4d82fe + c56c175 commit 35b875c

File tree

3 files changed

+141
-2
lines changed

3 files changed

+141
-2
lines changed

lib/matplotlib/axis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,8 +1565,8 @@ def have_units(self):
15651565
return self.converter is not None or self.units is not None
15661566

15671567
def convert_units(self, x):
1568-
# If x is already a number, doesn't need converting
1569-
if munits.ConversionInterface.is_numlike(x):
1568+
# If x is natively supported by Matplotlib, doesn't need converting
1569+
if munits.ConversionInterface.is_natively_supported(x):
15701570
return x
15711571

15721572
if self.converter is None:

lib/matplotlib/tests/test_axes.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212
from numpy import ma
1313
from cycler import cycler
14+
from decimal import Decimal
1415
import pytest
1516

1617
import warnings
@@ -1477,6 +1478,62 @@ def test_bar_tick_label_multiple_old_alignment():
14771478
align='center')
14781479

14791480

1481+
@check_figures_equal(extensions=["png"])
1482+
def test_bar_decimal_center(fig_test, fig_ref):
1483+
ax = fig_test.subplots()
1484+
x0 = [1.5, 8.4, 5.3, 4.2]
1485+
y0 = [1.1, 2.2, 3.3, 4.4]
1486+
x = [Decimal(x) for x in x0]
1487+
y = [Decimal(y) for y in y0]
1488+
# Test image - vertical, align-center bar chart with Decimal() input
1489+
ax.bar(x, y, align='center')
1490+
# Reference image
1491+
ax = fig_ref.subplots()
1492+
ax.bar(x0, y0, align='center')
1493+
1494+
1495+
@check_figures_equal(extensions=["png"])
1496+
def test_barh_decimal_center(fig_test, fig_ref):
1497+
ax = fig_test.subplots()
1498+
x0 = [1.5, 8.4, 5.3, 4.2]
1499+
y0 = [1.1, 2.2, 3.3, 4.4]
1500+
x = [Decimal(x) for x in x0]
1501+
y = [Decimal(y) for y in y0]
1502+
# Test image - horizontal, align-center bar chart with Decimal() input
1503+
ax.barh(x, y, height=[0.5, 0.5, 1, 1], align='center')
1504+
# Reference image
1505+
ax = fig_ref.subplots()
1506+
ax.barh(x0, y0, height=[0.5, 0.5, 1, 1], align='center')
1507+
1508+
1509+
@check_figures_equal(extensions=["png"])
1510+
def test_bar_decimal_width(fig_test, fig_ref):
1511+
x = [1.5, 8.4, 5.3, 4.2]
1512+
y = [1.1, 2.2, 3.3, 4.4]
1513+
w0 = [0.7, 1.45, 1, 2]
1514+
w = [Decimal(i) for i in w0]
1515+
# Test image - vertical bar chart with Decimal() width
1516+
ax = fig_test.subplots()
1517+
ax.bar(x, y, width=w, align='center')
1518+
# Reference image
1519+
ax = fig_ref.subplots()
1520+
ax.bar(x, y, width=w0, align='center')
1521+
1522+
1523+
@check_figures_equal(extensions=["png"])
1524+
def test_barh_decimal_height(fig_test, fig_ref):
1525+
x = [1.5, 8.4, 5.3, 4.2]
1526+
y = [1.1, 2.2, 3.3, 4.4]
1527+
h0 = [0.7, 1.45, 1, 2]
1528+
h = [Decimal(i) for i in h0]
1529+
# Test image - horizontal bar chart with Decimal() height
1530+
ax = fig_test.subplots()
1531+
ax.barh(x, y, height=h, align='center')
1532+
# Reference image
1533+
ax = fig_ref.subplots()
1534+
ax.barh(x, y, height=h0, align='center')
1535+
1536+
14801537
def test_bar_color_none_alpha():
14811538
ax = plt.gca()
14821539
rects = ax.bar([1, 2], [2, 4], alpha=0.3, color='none', edgecolor='r')
@@ -1819,6 +1876,21 @@ def test_scatter_2D(self):
18191876
fig, ax = plt.subplots()
18201877
ax.scatter(x, y, c=z, s=200, edgecolors='face')
18211878

1879+
@check_figures_equal(extensions=["png"])
1880+
def test_scatter_decimal(self, fig_test, fig_ref):
1881+
x0 = np.array([1.5, 8.4, 5.3, 4.2])
1882+
y0 = np.array([1.1, 2.2, 3.3, 4.4])
1883+
x = np.array([Decimal(i) for i in x0])
1884+
y = np.array([Decimal(i) for i in y0])
1885+
c = ['r', 'y', 'b', 'lime']
1886+
s = [24, 15, 19, 29]
1887+
# Test image - scatter plot with Decimal() input
1888+
ax = fig_test.subplots()
1889+
ax.scatter(x, y, c=c, s=s)
1890+
# Reference image
1891+
ax = fig_ref.subplots()
1892+
ax.scatter(x0, y0, c=c, s=s)
1893+
18221894
def test_scatter_color(self):
18231895
# Try to catch cases where 'c' kwarg should have been used.
18241896
with pytest.raises(ValueError):
@@ -5965,6 +6037,18 @@ def test_plot_columns_cycle_deprecation():
59656037
plt.plot(np.zeros((2, 2)), np.zeros((2, 3)))
59666038

59676039

6040+
@check_figures_equal(extensions=["png"])
6041+
def test_plot_decimal(fig_test, fig_ref):
6042+
x0 = np.arange(-10, 10, 0.3)
6043+
y0 = [5.2 * x ** 3 - 2.1 * x ** 2 + 7.34 * x + 4.5 for x in x0]
6044+
x = [Decimal(i) for i in x0]
6045+
y = [Decimal(i) for i in y0]
6046+
# Test image - line plot with Decimal input
6047+
fig_test.subplots().plot(x, y)
6048+
# Reference image
6049+
fig_ref.subplots().plot(x0, y0)
6050+
6051+
59686052
# pdf and svg tests fail using travis' old versions of gs and inkscape.
59696053
@check_figures_equal(extensions=["png"])
59706054
def test_markerfacecolor_none_alpha(fig_test, fig_ref):

lib/matplotlib/units.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def default_units(x, axis):
4545
from numbers import Number
4646

4747
import numpy as np
48+
from numpy import ma
49+
from decimal import Decimal
4850

4951
from matplotlib import cbook
5052

@@ -132,6 +134,58 @@ def is_numlike(x):
132134
else:
133135
return isinstance(x, Number)
134136

137+
@staticmethod
138+
def is_natively_supported(x):
139+
"""
140+
Return whether *x* is of a type that Matplotlib natively supports or
141+
*x* is array of objects of such types.
142+
"""
143+
# Matplotlib natively supports all number types except Decimal
144+
if np.iterable(x):
145+
# Assume lists are homogeneous as other functions in unit system
146+
for thisx in x:
147+
return (isinstance(thisx, Number) and
148+
not isinstance(thisx, Decimal))
149+
else:
150+
return isinstance(x, Number) and not isinstance(x, Decimal)
151+
152+
153+
class DecimalConverter(ConversionInterface):
154+
"""
155+
Converter for decimal.Decimal data to float.
156+
"""
157+
@staticmethod
158+
def convert(value, unit, axis):
159+
"""
160+
Convert Decimals to floats.
161+
162+
The *unit* and *axis* arguments are not used.
163+
164+
Parameters
165+
----------
166+
value : decimal.Decimal or iterable
167+
Decimal or list of Decimal need to be converted
168+
"""
169+
# If value is a Decimal
170+
if isinstance(value, Decimal):
171+
return np.float(value)
172+
else:
173+
# assume x is a list of Decimal
174+
converter = np.asarray
175+
if isinstance(value, ma.MaskedArray):
176+
converter = ma.asarray
177+
return converter(value, dtype=np.float)
178+
179+
@staticmethod
180+
def axisinfo(unit, axis):
181+
# Since Decimal is a kind of Number, don't need specific axisinfo.
182+
return AxisInfo()
183+
184+
@staticmethod
185+
def default_units(x, axis):
186+
# Return None since Decimal is a kind of Number.
187+
return None
188+
135189

136190
class Registry(dict):
137191
"""Register types with conversion interface."""
@@ -164,3 +218,4 @@ def get_converter(self, x):
164218

165219

166220
registry = Registry()
221+
registry[Decimal] = DecimalConverter()

0 commit comments

Comments
 (0)