From f263f7d54cca15f8fa3942e3482d432a9bad51ab Mon Sep 17 00:00:00 2001 From: Rob Righter Date: Wed, 22 Feb 2023 00:05:01 -0500 Subject: [PATCH] Bugfix for loc legend validation Co-authored-by: John Paul Jepko --- lib/matplotlib/legend.py | 17 +++++++ lib/matplotlib/tests/test_legend.py | 76 +++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/lib/matplotlib/legend.py b/lib/matplotlib/legend.py index ff6abdb95844..b17a239d8df5 100644 --- a/lib/matplotlib/legend.py +++ b/lib/matplotlib/legend.py @@ -23,6 +23,7 @@ import itertools import logging +import numbers import time import numpy as np @@ -517,6 +518,9 @@ def val_or_rc(val, rc_name): if not self.isaxes and loc in [0, 'best']: loc = 'upper right' + type_err_message = ("loc must be string, coordinate tuple, or" + f" an integer 0-10, not {loc!r}") + # handle outside legends: self._outside_loc = None if isinstance(loc, str): @@ -535,6 +539,19 @@ def val_or_rc(val, rc_name): loc = locs[0] + ' ' + locs[1] # check that loc is in acceptable strings loc = _api.check_getitem(self.codes, loc=loc) + elif np.iterable(loc): + # coerce iterable into tuple + loc = tuple(loc) + # validate the tuple represents Real coordinates + if len(loc) != 2 or not all(isinstance(e, numbers.Real) for e in loc): + raise ValueError(type_err_message) + elif isinstance(loc, int): + # validate the integer represents a string numeric value + if loc < 0 or loc > 10: + raise ValueError(type_err_message) + else: + # all other cases are invalid values of loc + raise ValueError(type_err_message) if self.isaxes and self._outside_loc: raise ValueError( diff --git a/lib/matplotlib/tests/test_legend.py b/lib/matplotlib/tests/test_legend.py index a8d7fd107d8b..74d054af6b05 100644 --- a/lib/matplotlib/tests/test_legend.py +++ b/lib/matplotlib/tests/test_legend.py @@ -1219,3 +1219,79 @@ def test_ncol_ncols(fig_test, fig_ref): ncols = 3 fig_test.legend(strings, ncol=ncols) fig_ref.legend(strings, ncols=ncols) + + +def test_loc_invalid_tuple_exception(): + # check that exception is raised if the loc arg + # of legend is not a 2-tuple of numbers + fig, ax = plt.subplots() + with pytest.raises(ValueError, match=('loc must be string, coordinate ' + 'tuple, or an integer 0-10, not \\(1.1,\\)')): + ax.legend(loc=(1.1, )) + + with pytest.raises(ValueError, match=('loc must be string, coordinate ' + 'tuple, or an integer 0-10, not \\(0.481, 0.4227, 0.4523\\)')): + ax.legend(loc=(0.481, 0.4227, 0.4523)) + + with pytest.raises(ValueError, match=('loc must be string, coordinate ' + 'tuple, or an integer 0-10, not \\(0.481, \'go blue\'\\)')): + ax.legend(loc=(0.481, "go blue")) + + +def test_loc_valid_tuple(): + fig, ax = plt.subplots() + ax.legend(loc=(0.481, 0.442)) + ax.legend(loc=(1, 2)) + + +def test_loc_valid_list(): + fig, ax = plt.subplots() + ax.legend(loc=[0.481, 0.442]) + ax.legend(loc=[1, 2]) + + +def test_loc_invalid_list_exception(): + fig, ax = plt.subplots() + with pytest.raises(ValueError, match=('loc must be string, coordinate ' + 'tuple, or an integer 0-10, not \\[1.1, 2.2, 3.3\\]')): + ax.legend(loc=[1.1, 2.2, 3.3]) + + +def test_loc_invalid_type(): + fig, ax = plt.subplots() + with pytest.raises(ValueError, match=("loc must be string, coordinate " + "tuple, or an integer 0-10, not {'not': True}")): + ax.legend(loc={'not': True}) + + +def test_loc_validation_numeric_value(): + fig, ax = plt.subplots() + ax.legend(loc=0) + ax.legend(loc=1) + ax.legend(loc=5) + ax.legend(loc=10) + with pytest.raises(ValueError, match=('loc must be string, coordinate ' + 'tuple, or an integer 0-10, not 11')): + ax.legend(loc=11) + + with pytest.raises(ValueError, match=('loc must be string, coordinate ' + 'tuple, or an integer 0-10, not -1')): + ax.legend(loc=-1) + + +def test_loc_validation_string_value(): + fig, ax = plt.subplots() + ax.legend(loc='best') + ax.legend(loc='upper right') + ax.legend(loc='best') + ax.legend(loc='upper right') + ax.legend(loc='upper left') + ax.legend(loc='lower left') + ax.legend(loc='lower right') + ax.legend(loc='right') + ax.legend(loc='center left') + ax.legend(loc='center right') + ax.legend(loc='lower center') + ax.legend(loc='upper center') + with pytest.raises(ValueError, match="'wrong' is not a valid value for"): + ax.legend(loc='wrong')