Skip to content

Commit 586e585

Browse files
committed
Bugfix for loc legend validation
Co-authored-by: John Paul Jepko <jpjepko@users.noreply.github.com>
1 parent 70e0ba8 commit 586e585

File tree

2 files changed

+94
-0
lines changed

2 files changed

+94
-0
lines changed

lib/matplotlib/legend.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import itertools
2525
import logging
26+
import numbers
2627
import time
2728

2829
import numpy as np
@@ -517,8 +518,12 @@ def val_or_rc(val, rc_name):
517518
if not self.isaxes and loc in [0, 'best']:
518519
loc = 'upper right'
519520

521+
type_err_message = ("loc must be string, coordinate tuple, or"
522+
f" an integer 0-10, not {loc!r}")
523+
520524
# handle outside legends:
521525
self._outside_loc = None
526+
522527
if isinstance(loc, str):
523528
if loc.split()[0] == 'outside':
524529
# strip outside:
@@ -535,6 +540,19 @@ def val_or_rc(val, rc_name):
535540
loc = locs[0] + ' ' + locs[1]
536541
# check that loc is in acceptable strings
537542
loc = _api.check_getitem(self.codes, loc=loc)
543+
elif np.iterable(loc):
544+
# coerce iterable into tuple
545+
loc = tuple(loc)
546+
# validate the tuple represents Real coordinates
547+
if len(loc) != 2 or not all(isinstance(e, numbers.Real) for e in loc):
548+
raise ValueError(type_err_message)
549+
elif isinstance(loc, int):
550+
# validate the integer represents a string numeric value
551+
if loc < 0 or loc > 10:
552+
raise ValueError(type_err_message)
553+
else:
554+
# all other cases are invalid values of loc
555+
raise ValueError(type_err_message)
538556

539557
if self.isaxes and self._outside_loc:
540558
raise ValueError(

lib/matplotlib/tests/test_legend.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,3 +1219,79 @@ def test_ncol_ncols(fig_test, fig_ref):
12191219
ncols = 3
12201220
fig_test.legend(strings, ncol=ncols)
12211221
fig_ref.legend(strings, ncols=ncols)
1222+
1223+
1224+
def test_loc_invalid_tuple_exception():
1225+
# check that exception is raised if the loc arg
1226+
# of legend is not a 2-tuple of numbers
1227+
fig, ax = plt.subplots()
1228+
with pytest.raises(ValueError, match=('loc must be string, coordinate '
1229+
'tuple, or an integer 0-10, not \\(1.1,\\)')):
1230+
ax.legend(loc=(1.1, ))
1231+
1232+
with pytest.raises(ValueError, match=('loc must be string, coordinate '
1233+
'tuple, or an integer 0-10, not \\(0.481, 0.4227, 0.4523\\)')):
1234+
ax.legend(loc=(0.481, 0.4227, 0.4523))
1235+
1236+
with pytest.raises(ValueError, match=('loc must be string, coordinate '
1237+
'tuple, or an integer 0-10, not \\(0.481, \'go blue\'\\)')):
1238+
ax.legend(loc=(0.481, "go blue"))
1239+
1240+
1241+
def test_loc_valid_tuple():
1242+
fig, ax = plt.subplots()
1243+
ax.legend(loc=(0.481, 0.442))
1244+
ax.legend(loc=(1, 2))
1245+
1246+
1247+
def test_loc_valid_list():
1248+
fig, ax = plt.subplots()
1249+
ax.legend(loc=[0.481, 0.442])
1250+
ax.legend(loc=[1, 2])
1251+
1252+
1253+
def test_loc_invalid_list_exception():
1254+
fig, ax = plt.subplots()
1255+
with pytest.raises(ValueError, match=('loc must be string, coordinate '
1256+
'tuple, or an integer 0-10, not \\[1.1, 2.2, 3.3\\]')):
1257+
ax.legend(loc=[1.1, 2.2, 3.3])
1258+
1259+
1260+
def test_loc_invalid_type():
1261+
fig, ax = plt.subplots()
1262+
with pytest.raises(ValueError, match=("loc must be string, coordinate "
1263+
"tuple, or an integer 0-10, not {'not': True}")):
1264+
ax.legend(loc={'not': True})
1265+
1266+
1267+
def test_loc_validation_numeric_value():
1268+
fig, ax = plt.subplots()
1269+
ax.legend(loc=0)
1270+
ax.legend(loc=1)
1271+
ax.legend(loc=5)
1272+
ax.legend(loc=10)
1273+
with pytest.raises(ValueError, match=('loc must be string, coordinate '
1274+
'tuple, or an integer 0-10, not 11')):
1275+
ax.legend(loc=11)
1276+
1277+
with pytest.raises(ValueError, match=('loc must be string, coordinate '
1278+
'tuple, or an integer 0-10, not -1')):
1279+
ax.legend(loc=-1)
1280+
1281+
1282+
def test_loc_validation_string_value():
1283+
fig, ax = plt.subplots()
1284+
ax.legend(loc='best')
1285+
ax.legend(loc='upper right')
1286+
ax.legend(loc='best')
1287+
ax.legend(loc='upper right')
1288+
ax.legend(loc='upper left')
1289+
ax.legend(loc='lower left')
1290+
ax.legend(loc='lower right')
1291+
ax.legend(loc='right')
1292+
ax.legend(loc='center left')
1293+
ax.legend(loc='center right')
1294+
ax.legend(loc='lower center')
1295+
ax.legend(loc='upper center')
1296+
with pytest.raises(ValueError, match="'wrong' is not a valid value for"):
1297+
ax.legend(loc='wrong')

0 commit comments

Comments
 (0)