From cb2b850d4589445d3d7319da824e7b5f4ab13c75 Mon Sep 17 00:00:00 2001 From: Thomas M Kehrenberg Date: Thu, 2 May 2024 22:10:45 +0200 Subject: [PATCH 1/7] Allow the use of unions as match patterns --- Lib/test/test_patma.py | 34 ++++++++++++++++++++++++++++++++++ Python/ceval.c | 15 +++++++++++++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index 5d0857b059ea23..bf849c673d4ab5 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -4,6 +4,7 @@ import dis import enum import inspect +from re import I import sys import unittest from test import support @@ -2888,6 +2889,14 @@ class B(A): ... h = 1 self.assertEqual(h, 1) + def test_patma_union_type(self): + IntOrStr = int | str + x = 0 + match x: + case IntOrStr(): + x = 1 + self.assertEqual(x, 1) + class TestSyntaxErrors(unittest.TestCase): @@ -3370,6 +3379,31 @@ class A: w = 0 self.assertIsNone(w) + def test_union_type_postional_subpattern(self): + IntOrStr = int | str + x = 1 + w = None + with self.assertRaises(TypeError): + match x: + case IntOrStr(x): + w = 0 + self.assertEqual(x, 1) + self.assertIsNone(w) + + def test_union_type_keyword_subpattern(self): + @dataclasses.dataclass + class Point2: + x: int + y: int + EitherPoint = Point | Point2 + x = Point(x=1, y=2) + w = None + with self.assertRaises(TypeError): + match x: + case EitherPoint(x=1, y=2): + w = 0 + self.assertIsNone(w) + class TestValueErrors(unittest.TestCase): diff --git a/Python/ceval.c b/Python/ceval.c index 291e753dec0ce5..24df578f158d30 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -39,6 +39,7 @@ #include "pycore_template.h" // _PyTemplate_Build() #include "pycore_traceback.h" // _PyTraceBack_FromFrame #include "pycore_tuple.h" // _PyTuple_ITEMS() +#include "pycore_unionobject.h" // _PyUnion_Check() #include "pycore_uop_ids.h" // Uops #include "dictobject.h" @@ -725,8 +726,8 @@ PyObject* _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject 
*type, Py_ssize_t nargs, PyObject *kwargs) { - if (!PyType_Check(type)) { - const char *e = "called match pattern must be a class"; + if (!PyType_Check(type) && !_PyUnion_Check(type)) { + const char *e = "called match pattern must be a class or a union"; _PyErr_Format(tstate, PyExc_TypeError, e); return NULL; } @@ -735,6 +736,16 @@ _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type, if (PyObject_IsInstance(subject, type) <= 0) { return NULL; } + // Subpatterns are not supported for union types: + if (_PyUnion_Check(type)) { + // Return error if any positional or keyword arguments are given: + if (nargs || PyTuple_GET_SIZE(kwargs)) { + const char *e = "union types do not support sub-patterns"; + _PyErr_Format(tstate, PyExc_TypeError, e); + return NULL; + } + return PyTuple_New(0); + } // So far so good: PyObject *seen = PySet_New(NULL); if (seen == NULL) { From 114f250a8adf73b7f8b382ef32ebf582c8536a22 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 6 May 2024 14:26:57 +0200 Subject: [PATCH 2/7] add match-case support for unions --- Lib/test/test_patma.py | 249 ++++++++++++++++++++++++++++++++++++----- Python/ceval.c | 34 +++--- 2 files changed, 240 insertions(+), 43 deletions(-) diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index bf849c673d4ab5..c16f09a583fbba 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -4,7 +4,6 @@ import dis import enum import inspect -from re import I import sys import unittest from test import support @@ -16,6 +15,13 @@ class Point: y: int +@dataclasses.dataclass +class Point3D: + x: int + y: int + z: int + + class TestCompiler(unittest.TestCase): def test_refleaks(self): @@ -2891,11 +2897,81 @@ class B(A): ... 
def test_patma_union_type(self): IntOrStr = int | str - x = 0 - match x: + w = None + match 0: case IntOrStr(): - x = 1 - self.assertEqual(x, 1) + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_no_match(self): + StrOrBytes = str | bytes + w = None + match 0: + case StrOrBytes(): + w = 0 + self.assertIsNone(w) + + def test_union_type_positional_subpattern(self): + IntOrStr = int | str + w = None + match 0: + case IntOrStr(y): + w = y + self.assertEqual(w, 0) + + def test_union_type_keyword_subpattern(self): + EitherPoint = Point | Point3D + p = Point(x=1, y=2) + w = None + match p: + case EitherPoint(x=1, y=2): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_arg(self): + p = Point(x=1, y=2) + IntOrStr = int | str + w = None + match p: + case Point(IntOrStr(), IntOrStr()): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_kwarg(self): + p = Point(x=1, y=2) + IntOrStr = int | str + w = None + match p: + case Point(x=IntOrStr(), y=IntOrStr()): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_arg_no_match(self): + p = Point(x=1, y=2) + StrOrBytes = str | bytes + w = None + match p: + case Point(StrOrBytes(), StrOrBytes()): + w = 0 + self.assertIsNone(w) + + def test_patma_union_kwarg_no_match(self): + p = Point(x=1, y=2) + StrOrBytes = str | bytes + w = None + match p: + case Point(x=StrOrBytes(), y=StrOrBytes()): + w = 0 + self.assertIsNone(w) + + def test_union_type_match_second_member(self): + EitherPoint = Point | Point3D + p = Point3D(x=1, y=2, z=3) + w = None + match p: + case EitherPoint(x=1, y=2, z=3): + w = 0 + self.assertEqual(w, 0) class TestSyntaxErrors(unittest.TestCase): @@ -3239,8 +3315,28 @@ def test_mapping_pattern_duplicate_key_edge_case3(self): pass """) + class TestTypeErrors(unittest.TestCase): + def test_generic_type(self): + t = list[str] + w = None + with self.assertRaises(TypeError): + match ["s"]: + case t(): + w = 0 + self.assertIsNone(w) + + def test_legacy_generic_type(self): + from typing import List + t 
= List[str] + w = None + with self.assertRaises(TypeError): + match ["s"]: + case t(): + w = 0 + self.assertIsNone(w) + def test_accepts_positional_subpatterns_0(self): class Class: __match_args__ = () @@ -3350,6 +3446,124 @@ def test_class_pattern_not_type(self): w = 0 self.assertIsNone(w) + def test_class_or_union_not_specialform(self): + from typing import Literal + name = type(Literal).__name__ + msg = rf"called match pattern must be a class or types.UnionType \(got {name}\)" + w = None + with self.assertRaisesRegex(TypeError, msg): + match 1: + case Literal(): + w = 0 + self.assertIsNone(w) + + def test_legacy_union_type(self): + from typing import Union + IntOrStr = Union[int, str] + name = type(IntOrStr).__name__ + msg = rf"called match pattern must be a class or types.UnionType \(got {name}\)" + w = None + with self.assertRaisesRegex(TypeError, msg): + match 1: + case IntOrStr(): + w = 0 + self.assertIsNone(w) + + def test_expanded_union_mirrors_isinstance_success(self): + ListOfInt = list[int] + t = int | ListOfInt + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case int() | ListOfInt(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should get the same result + self.assertIs(result, True) + self.assertIs(reference, True) + + def test_expanded_union_mirrors_isinstance_failure(self): + ListOfInt = list[int] + t = ListOfInt | int + + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case ListOfInt() | int(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should get the same result + self.assertIsInstance(result, TypeError) + self.assertIsInstance(reference, TypeError) + + def test_union_type_mirrors_isinstance_success(self): + t = int | list[int] + 
try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case t(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should get the same result + self.assertIs(result, True) + self.assertIs(reference, True) + + def test_union_type_mirrors_isinstance_failure(self): + t = list[int] | int + + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case t(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should get the same result + self.assertIsInstance(result, TypeError) + self.assertIsInstance(reference, TypeError) + + def test_generic_union_type(self): + from collections.abc import Sequence, Set + t = Sequence[str] | Set[str] + w = None + with self.assertRaises(TypeError): + match ["s"]: + case t(): + w = 0 + self.assertIsNone(w) + def test_regular_protocol(self): from typing import Protocol class P(Protocol): ... 
@@ -3379,31 +3593,6 @@ class A: w = 0 self.assertIsNone(w) - def test_union_type_postional_subpattern(self): - IntOrStr = int | str - x = 1 - w = None - with self.assertRaises(TypeError): - match x: - case IntOrStr(x): - w = 0 - self.assertEqual(x, 1) - self.assertIsNone(w) - - def test_union_type_keyword_subpattern(self): - @dataclasses.dataclass - class Point2: - x: int - y: int - EitherPoint = Point | Point2 - x = Point(x=1, y=2) - w = None - with self.assertRaises(TypeError): - match x: - case EitherPoint(x=1, y=2): - w = 0 - self.assertIsNone(w) - class TestValueErrors(unittest.TestCase): diff --git a/Python/ceval.c b/Python/ceval.c index 24df578f158d30..a5d6f37aa11beb 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -726,9 +726,27 @@ PyObject* _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type, Py_ssize_t nargs, PyObject *kwargs) { - if (!PyType_Check(type) && !_PyUnion_Check(type)) { - const char *e = "called match pattern must be a class or a union"; - _PyErr_Format(tstate, PyExc_TypeError, e); + // Recurse on unions. 
+ if (_PyUnion_Check(type)) { + // get union members + PyObject *members = _Py_union_args(type); + const Py_ssize_t n = PyTuple_GET_SIZE(members); + + // iterate over union members and return first match + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *member = PyTuple_GET_ITEM(members, i); + PyObject *attrs = _PyEval_MatchClass(tstate, subject, member, nargs, kwargs); + // match found + if (attrs != NULL) { + return attrs; + } + } + // no match found + return NULL; + } + if (!PyType_Check(type)) { + const char *e = "called match pattern must be a class or types.UnionType (got %s)"; + _PyErr_Format(tstate, PyExc_TypeError, e, Py_TYPE(type)->tp_name); return NULL; } assert(PyTuple_CheckExact(kwargs)); @@ -736,16 +754,6 @@ _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type, if (PyObject_IsInstance(subject, type) <= 0) { return NULL; } - // Subpatterns are not supported for union types: - if (_PyUnion_Check(type)) { - // Return error if any positional or keyword arguments are given: - if (nargs || PyTuple_GET_SIZE(kwargs)) { - const char *e = "union types do not support sub-patterns"; - _PyErr_Format(tstate, PyExc_TypeError, e); - return NULL; - } - return PyTuple_New(0); - } // So far so good: PyObject *seen = PySet_New(NULL); if (seen == NULL) { From 2f9aa386a31f71276dee415557fe92993c20b97b Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 21 Jul 2025 11:35:48 +0200 Subject: [PATCH 3/7] Updated error string --- Lib/test/test_patma.py | 4 ++-- Python/ceval.c | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index c16f09a583fbba..52b5aa62d1f116 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -3449,7 +3449,7 @@ def test_class_pattern_not_type(self): def test_class_or_union_not_specialform(self): from typing import Literal name = type(Literal).__name__ - msg = rf"called match pattern must be a class or types.UnionType \(got {name}\)" + msg = rf"called 
match pattern must be a class or typing.Union of classes \(got {name}\)" w = None with self.assertRaisesRegex(TypeError, msg): match 1: @@ -3461,7 +3461,7 @@ def test_legacy_union_type(self): from typing import Union IntOrStr = Union[int, str] name = type(IntOrStr).__name__ - msg = rf"called match pattern must be a class or types.UnionType \(got {name}\)" + msg = rf"called match pattern must be a class or typing.Union of classes \(got {name}\)" w = None with self.assertRaisesRegex(TypeError, msg): match 1: diff --git a/Python/ceval.c b/Python/ceval.c index a5d6f37aa11beb..65d69ffe352ec5 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -745,7 +745,7 @@ _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type, return NULL; } if (!PyType_Check(type)) { - const char *e = "called match pattern must be a class or types.UnionType (got %s)"; + const char *e = "called match pattern must be a class or typing.Union of classes (got %s)"; _PyErr_Format(tstate, PyExc_TypeError, e, Py_TYPE(type)->tp_name); return NULL; } From 4829f23480950b5c831acbdb32695fd307321dcc Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 21 Jul 2025 11:38:15 +0200 Subject: [PATCH 4/7] changed test_legacy_union to test_typing_union. 
Changed behavior due to gh-105499 --- Lib/test/test_patma.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index 52b5aa62d1f116..acbf6aafadcddd 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -3457,17 +3457,14 @@ def test_class_or_union_not_specialform(self): w = 0 self.assertIsNone(w) - def test_legacy_union_type(self): + def test_typing_union(self): from typing import Union - IntOrStr = Union[int, str] - name = type(IntOrStr).__name__ - msg = rf"called match pattern must be a class or typing.Union of classes \(got {name}\)" - w = None - with self.assertRaisesRegex(TypeError, msg): - match 1: - case IntOrStr(): - w = 0 - self.assertIsNone(w) + IntOrStr = Union[int, str] # identical to int | str since gh-105499 + w = False + match 1: + case IntOrStr(): + w = True + self.assertIs(w, True) def test_expanded_union_mirrors_isinstance_success(self): ListOfInt = list[int] From 65f2a6437bbf69817b5d3ddd40ef98446a83c857 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 21 Jul 2025 14:38:11 +0200 Subject: [PATCH 5/7] Update Python/ceval.c Co-authored-by: sobolevn --- Python/ceval.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Python/ceval.c b/Python/ceval.c index 65d69ffe352ec5..fd102ea4516aab 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -745,7 +745,7 @@ _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type, return NULL; } if (!PyType_Check(type)) { - const char *e = "called match pattern must be a class or typing.Union of classes (got %s)"; + const char *e = "called match pattern must be a class or a union of classes (got %s)"; _PyErr_Format(tstate, PyExc_TypeError, e, Py_TYPE(type)->tp_name); return NULL; } From 3e12030d370c09c65640f867919bf9e504ed2007 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 21 Jul 2025 15:13:40 +0200 Subject: [PATCH 6/7] fixed test to match new error message --- 
Lib/test/test_patma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index acbf6aafadcddd..cb6662d3eaaff1 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -3449,7 +3449,7 @@ def test_class_pattern_not_type(self): def test_class_or_union_not_specialform(self): from typing import Literal name = type(Literal).__name__ - msg = rf"called match pattern must be a class or typing.Union of classes \(got {name}\)" + msg = rf"called match pattern must be a class or a union of classes \(got {name}\)" w = None with self.assertRaisesRegex(TypeError, msg): match 1: From 0febb806bca93ad8fadc550c13857b70c83beedb Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 21 Jul 2025 15:20:44 +0200 Subject: [PATCH 7/7] added to docs --- Doc/reference/compound_stmts.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Doc/reference/compound_stmts.rst b/Doc/reference/compound_stmts.rst index e95fa3a6424e23..36bd911c05f09a 100644 --- a/Doc/reference/compound_stmts.rst +++ b/Doc/reference/compound_stmts.rst @@ -1098,6 +1098,11 @@ The same keyword should not be repeated in class patterns. The following is the logical flow for matching a class pattern against a subject value: +#. If ``name_or_attr`` is a union type, apply the subsequent steps in order to + each of its members, returning the first successful match or raising the first + encountered exception. + This mirrors the behavior of :func:`isinstance` with union types. + #. If ``name_or_attr`` is not an instance of the builtin :class:`type` , raise :exc:`TypeError`.