From 0b7fcba17274d73464425d8fc80ec339959c0c63 Mon Sep 17 00:00:00 2001 From: John Paul Jepko Date: Tue, 6 Dec 2022 16:59:30 -0500 Subject: [PATCH] #24605 add check to validate loc tuples --- lib/matplotlib/legend.py | 9 +++++++++ lib/matplotlib/tests/test_legend.py | 24 ++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/lib/matplotlib/legend.py b/lib/matplotlib/legend.py index d0590824ad84..a64d806ae6e0 100644 --- a/lib/matplotlib/legend.py +++ b/lib/matplotlib/legend.py @@ -23,6 +23,7 @@ import itertools import logging +import numbers import time import numpy as np @@ -472,6 +473,14 @@ def val_or_rc(val, rc_name): loc = 'upper right' if isinstance(loc, str): loc = _api.check_getitem(self.codes, loc=loc) + elif isinstance(loc, tuple): + if len(loc) != 2 or not (all(isinstance(e, numbers.Real) + for e in loc)): + raise ValueError( + f"loc must be string or pair of numbers, not {loc!r}") + else: + raise ValueError( + f"loc must be string or pair of numbers, not {loc!r}") if not self.isaxes and loc == 0: raise ValueError( "Automatic legend placement (loc='best') not implemented for " diff --git a/lib/matplotlib/tests/test_legend.py b/lib/matplotlib/tests/test_legend.py index 6660b91ecdd9..4fd34e4321a8 100644 --- a/lib/matplotlib/tests/test_legend.py +++ b/lib/matplotlib/tests/test_legend.py @@ -1090,3 +1090,27 @@ def test_ncol_ncols(fig_test, fig_ref): ncols = 3 fig_test.legend(strings, ncol=ncols) fig_ref.legend(strings, ncols=ncols) + + +def test_loc_invalid_tuple_exception(): + # check that exception is raised if the loc arg + # of legend is not a 2-tuple of numbers + fig, ax = plt.subplots() + with pytest.raises(ValueError, + match="loc must be string or pair of numbers, not " + "\\(1.1\\,\\)"): # regex escape special chars + ax.legend(loc=(1.1, )) + with pytest.raises(ValueError, + match="loc must be string or pair of numbers, not " + "\\(0.481\\, 0.4227\\, 0.4523\\)"): + ax.legend(loc=(0.481, 0.4227, 0.4523)) + with pytest.raises(ValueError, + match="loc must be string or pair of numbers, not " + "\\(0.481\\, \\'go blue\\'\\)"): + ax.legend(loc=(0.481, "go blue")) + + +def test_loc_valid_tuple(): + fig, ax = plt.subplots() + ax.legend(loc=(0.481, 0.442)) + ax.legend(loc=(1, 2))