Skip to content

gh-91603: Speed up UnionType instantiation #91865

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Speed up :class:`types.UnionType` instantiation. Patch provided by Yurii
Karabas.
235 changes: 148 additions & 87 deletions Objects/unionobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -137,102 +137,143 @@ union_richcompare(PyObject *a, PyObject *b, int op)
return result;
}

static PyObject*
flatten_args(PyObject* args)
static int
is_same(PyObject* left, PyObject* right)
{
Py_ssize_t arg_length = PyTuple_GET_SIZE(args);
Py_ssize_t total_args = 0;
// Get number of total args once it's flattened.
for (Py_ssize_t i = 0; i < arg_length; i++) {
PyObject *arg = PyTuple_GET_ITEM(args, i);
if (_PyUnion_Check(arg)) {
total_args += PyTuple_GET_SIZE(((unionobject*) arg)->args);
} else {
total_args++;
int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right);
return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right;
}

static int
is_unionable(PyObject *obj)
{
return (obj == Py_None ||
PyType_Check(obj) ||
_PyGenericAlias_Check(obj) ||
_PyUnion_Check(obj));
}

static int
args_contains(PyObject *args, PyObject *obj)
{
assert(PyTuple_CheckExact(args));
Py_ssize_t size = PyTuple_GET_SIZE(args);

for (int j = 0; j < size; j++) {
PyObject *left_arg = PyTuple_GET_ITEM(args, j);
int is_duplicate = is_same(left_arg, obj);
if (is_duplicate) {
return is_duplicate;
}
}
// Create new tuple of flattened args.
PyObject *flattened_args = PyTuple_New(total_args);
if (flattened_args == NULL) {

return 0;
}

static PyObject*
merge_union_and_union(PyObject *left, PyObject *right)
{
PyObject* left_args = ((unionobject *) left)->args;
PyObject* right_args = ((unionobject*) right)->args;
Py_ssize_t left_size = PyTuple_GET_SIZE(left_args);
Py_ssize_t right_size = PyTuple_GET_SIZE(right_args);
PyObject *tuple = PyTuple_New(left_size + right_size);

if (tuple == NULL) {
return NULL;
}
Py_ssize_t pos = 0;
for (Py_ssize_t i = 0; i < arg_length; i++) {
PyObject *arg = PyTuple_GET_ITEM(args, i);
if (_PyUnion_Check(arg)) {
PyObject* nested_args = ((unionobject*)arg)->args;
Py_ssize_t nested_arg_length = PyTuple_GET_SIZE(nested_args);
for (Py_ssize_t j = 0; j < nested_arg_length; j++) {
PyObject* nested_arg = PyTuple_GET_ITEM(nested_args, j);
Py_INCREF(nested_arg);
PyTuple_SET_ITEM(flattened_args, pos, nested_arg);
pos++;
}
} else {
if (arg == Py_None) {
arg = (PyObject *)&_PyNone_Type;
}

for (int i = 0; i < left_size; i++) {
PyObject *arg = PyTuple_GET_ITEM(left_args, i);
Py_INCREF(arg);
PyTuple_SET_ITEM(tuple, i, arg);
}

Py_ssize_t pos = left_size;
for (int i = 0; i < right_size; i++) {
PyObject *arg = PyTuple_GET_ITEM(right_args, i);
int is_duplicate = args_contains(left_args, arg);

if (is_duplicate < 0) {
Py_DECREF(tuple);
return NULL;
}
if (!is_duplicate) {
Py_INCREF(arg);
PyTuple_SET_ITEM(flattened_args, pos, arg);
PyTuple_SET_ITEM(tuple, pos, arg);
pos++;
}
}
assert(pos == total_args);
return flattened_args;

_PyTuple_Resize(&tuple, pos);
return tuple;
}

static PyObject*
dedup_and_flatten_args(PyObject* args)
merge_union_and_obj(PyObject *left, PyObject *right)
{
args = flatten_args(args);
if (args == NULL) {
PyObject* args = ((unionobject *) left)->args;
int is_duplicate = args_contains(args, right);

if (is_duplicate < 0) {
return NULL;
}
Py_ssize_t arg_length = PyTuple_GET_SIZE(args);
PyObject *new_args = PyTuple_New(arg_length);
if (new_args == NULL) {
Py_DECREF(args);
if (is_duplicate) {
Py_INCREF(args);
return args;
}

Py_ssize_t size = PyTuple_GET_SIZE(args);
PyObject *tuple = PyTuple_New(size + 1);

if (tuple == NULL) {
return NULL;
}
// Add unique elements to an array.
Py_ssize_t added_items = 0;
for (Py_ssize_t i = 0; i < arg_length; i++) {
int is_duplicate = 0;
PyObject* i_element = PyTuple_GET_ITEM(args, i);
for (Py_ssize_t j = 0; j < added_items; j++) {
PyObject* j_element = PyTuple_GET_ITEM(new_args, j);
int is_ga = _PyGenericAlias_Check(i_element) &&
_PyGenericAlias_Check(j_element);
// RichCompare to also deduplicate GenericAlias types (slower)
is_duplicate = is_ga ? PyObject_RichCompareBool(i_element, j_element, Py_EQ)
: i_element == j_element;
// Should only happen if RichCompare fails
if (is_duplicate < 0) {
Py_DECREF(args);
Py_DECREF(new_args);
return NULL;
}
if (is_duplicate)
break;

for (int i = 0; i < size; i++) {
PyObject *arg = PyTuple_GET_ITEM(args, i);
Py_INCREF(arg);
PyTuple_SET_ITEM(tuple, i, arg);
}

Py_INCREF(right);
PyTuple_SET_ITEM(tuple, size, right);

return tuple;
}

static PyObject*
merge_obj_and_union(PyObject *left, PyObject *right)
{
PyObject* args = ((unionobject *) right)->args;
Py_ssize_t size = PyTuple_GET_SIZE(args);
PyObject *tuple = PyTuple_New(size + 1);

if (tuple == NULL) {
return NULL;
}

Py_INCREF(left);
PyTuple_SET_ITEM(tuple, 0, left);
Py_ssize_t pos = 1;

for (int i = 0; i < size; i++) {
PyObject *arg = PyTuple_GET_ITEM(args, i);
int is_duplicate = is_same(left, arg);

if (is_duplicate < 0) {
Py_DECREF(tuple);
return NULL;
}
if (!is_duplicate) {
Py_INCREF(i_element);
PyTuple_SET_ITEM(new_args, added_items, i_element);
added_items++;
Py_INCREF(arg);
PyTuple_SET_ITEM(tuple, pos, arg);
pos++;
}
}
Py_DECREF(args);
_PyTuple_Resize(&new_args, added_items);
return new_args;
}

static int
is_unionable(PyObject *obj)
{
return (obj == Py_None ||
PyType_Check(obj) ||
_PyGenericAlias_Check(obj) ||
_PyUnion_Check(obj));
_PyTuple_Resize(&tuple, pos);
return tuple;
}

PyObject *
Expand All @@ -242,7 +283,38 @@ _Py_union_type_or(PyObject* self, PyObject* other)
Py_RETURN_NOTIMPLEMENTED;
}

PyObject *tuple = PyTuple_Pack(2, self, other);
if (self == Py_None) {
self = (PyObject *)&_PyNone_Type;
}
if (other == Py_None) {
other = (PyObject *)&_PyNone_Type;
}

PyObject *tuple;

if (_PyUnion_Check(self) && _PyUnion_Check(other)) {
tuple = merge_union_and_union(self, other);
}
else if (_PyUnion_Check(self)) {
tuple = merge_union_and_obj(self, other);
}
else if (_PyUnion_Check(other)) {
tuple = merge_obj_and_union(self, other);
}
else {
int is_duplicate = is_same(self, other);

if (is_duplicate < 0) {
return NULL;
}
if (is_duplicate) {
Py_INCREF(self);
return self;
}

tuple = PyTuple_Pack(2, self, other);
}

if (tuple == NULL) {
return NULL;
}
Expand Down Expand Up @@ -468,23 +540,12 @@ make_union(PyObject *args)
{
assert(PyTuple_CheckExact(args));

args = dedup_and_flatten_args(args);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need this logic for the callsite in union_getitem.

Consider this case:

>>> from typing import TypeVar
>>> T = TypeVar("T")
>>> (list[T] | list[int])
list[~T] | list[int]
>>> (list[T] | list[int])[int]
list[int]

If I'm reading your code correctly, it would no longer deduplicate the two members. It would be good to also add a test case based on this example.

We should probably put the deduping logic directly in union_getitem.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JelleZijlstra Actually, we don't need it, as far as such case already handled by tests:

self.assertEqual((list[T] | list[S])[int, int], list[int])

That's because after parameter substitution new UnionType is recreated by reducing newargs using bitwise or operator:

res = PyTuple_GET_ITEM(newargs, 0);
Py_INCREF(res);
for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) {
PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
Py_SETREF(res, PyNumber_Or(res, arg));
if (res == NULL) {
break;
}
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh good catch! That seems a bit inefficient too, but we don't need to fix that in this PR; in any case it's probably a less performance-critical path.

if (args == NULL) {
return NULL;
}
if (PyTuple_GET_SIZE(args) == 1) {
PyObject *result1 = PyTuple_GET_ITEM(args, 0);
Py_INCREF(result1);
Py_DECREF(args);
return result1;
}

unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
if (result == NULL) {
Py_DECREF(args);
return NULL;
}

Py_INCREF(args);
result->parameters = NULL;
result->args = args;
_PyObject_GC_TRACK(result);
Expand Down