Skip to content

Commit 385387b

Browse files
committed
RR - bugfix for loc legend validation
Co-authored-by: John Paul Jepko <jpjepko@users.noreply.github.com>
1 parent 70e0ba8 commit 385387b

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

lib/matplotlib/legend.py

+19
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import itertools
2525
import logging
26+
import numbers
2627
import time
2728

2829
import numpy as np
@@ -535,6 +536,24 @@ def val_or_rc(val, rc_name):
535536
loc = locs[0] + ' ' + locs[1]
536537
# check that loc is in acceptable strings
537538
loc = _api.check_getitem(self.codes, loc=loc)
539+
elif isinstance(loc, tuple):
540+
# validate the tuple represents Real coordinates
541+
if len(loc) != 2 or not (all(isinstance(e, numbers.Real)
542+
for e in loc)):
543+
raise ValueError(
544+
"loc must be string, coordinate tuple or," +
545+
f" an integer 0-10, not {loc!r}")
546+
elif isinstance(loc, int):
547+
# validate the integer represents a string numeric value
548+
if (loc < 0) or (loc > 10):
549+
raise ValueError(
550+
"loc must be string, coordinate tuple or," +
551+
f" an integer 0-10, not {loc!r}")
552+
else:
553+
# all other cases are invalid values of loc
554+
raise ValueError(
555+
"loc must be string, coordinate tuple or," +
556+
f" an integer 0-10, not {loc!r}")
538557

539558
if self.isaxes and self._outside_loc:
540559
raise ValueError(

lib/matplotlib/tests/test_legend.py

+68
Original file line numberDiff line numberDiff line change
@@ -1219,3 +1219,71 @@ def test_ncol_ncols(fig_test, fig_ref):
12191219
ncols = 3
12201220
fig_test.legend(strings, ncol=ncols)
12211221
fig_ref.legend(strings, ncols=ncols)
1222+
1223+
1224+
def test_loc_invalid_tuple_exception():
1225+
# check that exception is raised if the loc arg
1226+
# of legend is not a 2-tuple of numbers
1227+
fig, ax = plt.subplots()
1228+
with pytest.raises(ValueError) as errorinfo:
1229+
ax.legend(loc=(1.1, ))
1230+
s = ("loc must be string, coordinate tuple or, an integer 0-10, not "
1231+
"(1.1,)")
1232+
assert s == str(errorinfo.value)
1233+
1234+
with pytest.raises(ValueError) as errorinfo:
1235+
ax.legend(loc=(0.481, 0.4227, 0.4523))
1236+
s = ("loc must be string, coordinate tuple or, an integer 0-10, not "
1237+
"(0.481, 0.4227, 0.4523)")
1238+
assert s == str(errorinfo.value)
1239+
1240+
with pytest.raises(ValueError) as errorinfo:
1241+
ax.legend(loc=(0.481, "go blue"))
1242+
s = ("loc must be string, coordinate tuple or, an integer 0-10, not "
1243+
"(0.481, 'go blue')")
1244+
assert s == str(errorinfo.value)
1245+
1246+
1247+
def test_loc_valid_tuple():
1248+
fig, ax = plt.subplots()
1249+
ax.legend(loc=(0.481, 0.442))
1250+
ax.legend(loc=(1, 2))
1251+
1252+
1253+
def test_loc_validation_string_numeric_value():
1254+
fig, ax = plt.subplots()
1255+
ax.legend(loc=0)
1256+
ax.legend(loc=1)
1257+
ax.legend(loc=5)
1258+
ax.legend(loc=10)
1259+
with pytest.raises(ValueError) as errorinfo:
1260+
ax.legend(loc=11)
1261+
s = "loc must be string, coordinate tuple or, an integer 0-10, not 11"
1262+
assert s == str(errorinfo.value)
1263+
1264+
with pytest.raises(ValueError) as errorinfo:
1265+
ax.legend(loc=-1)
1266+
s = "loc must be string, coordinate tuple or, an integer 0-10, not -1"
1267+
assert s == str(errorinfo.value)
1268+
1269+
1270+
def test_loc_validation_string_value():
1271+
fig, ax = plt.subplots()
1272+
ax.legend(loc='best')
1273+
ax.legend(loc='upper right')
1274+
ax.legend(loc='best')
1275+
ax.legend(loc='upper right')
1276+
ax.legend(loc='upper left')
1277+
ax.legend(loc='lower left')
1278+
ax.legend(loc='lower right')
1279+
ax.legend(loc='right')
1280+
ax.legend(loc='center left')
1281+
ax.legend(loc='center right')
1282+
ax.legend(loc='lower center')
1283+
ax.legend(loc='upper center')
1284+
with pytest.raises(ValueError) as errorinfo:
1285+
ax.legend(loc='wrong')
1286+
s = ("'wrong' is not a valid value for loc; supported values are 'best',"
1287+
" 'upper right', 'upper left', 'lower left', 'lower right', 'right',"
1288+
" 'center left', 'center right', 'lower center', 'upper center', 'center'")
1289+
assert s == str(errorinfo.value)

0 commit comments

Comments
 (0)