Skip to content

Commit ebf6b51

Browse files
authored
Merge pull request #1911 from TheAnyKey/TheAnyKey/p39_dict_union_pr
Implement P3.9 style dict union (PEP584)
2 parents 15c88c7 + da09370 commit ebf6b51

File tree

3 files changed

+152
-0
lines changed

3 files changed

+152
-0
lines changed

tests/snippets/dict_union.py

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
2+
import testutils
3+
4+
def test_dunion_ior0():
5+
a={1:2,2:3}
6+
b={3:4,5:6}
7+
a|=b
8+
9+
assert a == {1:2,2:3,3:4,5:6}, f"wrong value assigned {a=}"
10+
assert b == {3:4,5:6}, f"right hand side modified, {b=}"
11+
12+
def test_dunion_or0():
13+
a={1:2,2:3}
14+
b={3:4,5:6}
15+
c=a|b
16+
17+
assert a == {1:2,2:3}, f"left hand side of non-assignment operator modified {a=}"
18+
assert b == {3:4,5:6}, f"right hand side of non-assignment operator modified, {b=}"
19+
assert c == {1:2,2:3, 3:4, 5:6}, f"unexpected result of dict union {c=}"
20+
21+
22+
def test_dunion_or1():
23+
a={1:2,2:3}
24+
b={3:4,5:6}
25+
c=a.__or__(b)
26+
27+
assert a == {1:2,2:3}, f"left hand side of non-assignment operator modified {a=}"
28+
assert b == {3:4,5:6}, f"right hand side of non-assignment operator modified, {b=}"
29+
assert c == {1:2,2:3, 3:4, 5:6}, f"unexpected result of dict union {c=}"
30+
31+
32+
def test_dunion_ror0():
33+
a={1:2,2:3}
34+
b={3:4,5:6}
35+
c=b.__ror__(a)
36+
37+
assert a == {1:2,2:3}, f"left hand side of non-assignment operator modified {a=}"
38+
assert b == {3:4,5:6}, f"right hand side of non-assignment operator modified, {b=}"
39+
assert c == {1:2,2:3, 3:4, 5:6}, f"unexpected result of dict union {c=}"
40+
41+
42+
def test_dunion_other_types():
43+
def perf_test_or(other_obj):
44+
d={1:2}
45+
try:
46+
d.__or__(other_obj)
47+
except:
48+
return True
49+
return False
50+
51+
def perf_test_ior(other_obj):
52+
d={1:2}
53+
try:
54+
d.__ior__(other_obj)
55+
except:
56+
return True
57+
return False
58+
59+
def perf_test_ror(other_obj):
60+
d={1:2}
61+
try:
62+
d.__ror__(other_obj)
63+
except:
64+
return True
65+
return False
66+
67+
test_fct={'__or__':perf_test_or, '__ror__':perf_test_ror, '__ior__':perf_test_ior}
68+
others=['FooBar', 42, [36], set([19]), ['aa'], None]
69+
for tfn,tf in test_fct.items():
70+
for other in others:
71+
assert tf(other), f"Failed: dict {tfn}, accepted {other}"
72+
73+
74+
75+
76+
testutils.skip_if_unsupported(3,9,test_dunion_ior0)
77+
testutils.skip_if_unsupported(3,9,test_dunion_or0)
78+
testutils.skip_if_unsupported(3,9,test_dunion_or1)
79+
testutils.skip_if_unsupported(3,9,test_dunion_ror0)
80+
testutils.skip_if_unsupported(3,9,test_dunion_other_types)
81+
82+
83+

tests/snippets/testutils.py

+26
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import platform
2+
import sys
3+
14
def assert_raises(expected, *args, _msg=None, **kw):
25
if args:
36
f, f_args = args[0], args[1:]
@@ -67,3 +70,26 @@ def assert_isinstance(obj, klass):
6770

6871
def assert_in(a, b):
6972
_assert_print(lambda: a in b, [a, 'in', b])
73+
74+
def skip_if_unsupported(req_maj_vers, req_min_vers, test_fct):
75+
def exec():
76+
test_fct()
77+
78+
if platform.python_implementation() == 'RustPython':
79+
exec()
80+
elif sys.version_info.major>=req_maj_vers and sys.version_info.minor>=req_min_vers:
81+
exec()
82+
else:
83+
print(f'Skipping test as a higher python version is required. Using {platform.python_implementation()} {platform.python_version()}')
84+
85+
def fail_if_unsupported(req_maj_vers, req_min_vers, test_fct):
86+
def exec():
87+
test_fct()
88+
89+
if platform.python_implementation() == 'RustPython':
90+
exec()
91+
elif sys.version_info.major>=req_maj_vers and sys.version_info.minor>=req_min_vers:
92+
exec()
93+
else:
94+
assert False, f'Test cannot performed on this python version. {platform.python_implementation()} {platform.python_version()}'
95+

vm/src/obj/objdict.rs

+43
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,17 @@ impl PyDictRef {
104104
Ok(())
105105
}
106106

107+
fn merge_dict(
108+
dict: &DictContentType,
109+
dict_other: PyDictRef,
110+
vm: &VirtualMachine,
111+
) -> PyResult<()> {
112+
for (key, value) in dict_other {
113+
dict.insert(vm, &key, value)?;
114+
}
115+
Ok(())
116+
}
117+
107118
#[pyclassmethod]
108119
fn fromkeys(
109120
class: PyClassRef,
@@ -320,6 +331,38 @@ impl PyDictRef {
320331
PyDictRef::merge(&self.entries, dict_obj, kwargs, vm)
321332
}
322333

334+
#[pymethod(name = "__ior__")]
335+
fn ior(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
336+
let dicted: Result<PyDictRef, _> = other.clone().downcast();
337+
if let Ok(other) = dicted {
338+
PyDictRef::merge_dict(&self.entries, other, vm)?;
339+
return Ok(self.into_object());
340+
}
341+
Err(vm.new_type_error("__ior__ not implemented for non-dict type".to_owned()))
342+
}
343+
344+
#[pymethod(name = "__ror__")]
345+
fn ror(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyDict> {
346+
let dicted: Result<PyDictRef, _> = other.clone().downcast();
347+
if let Ok(other) = dicted {
348+
let other_cp = other.copy();
349+
PyDictRef::merge_dict(&other_cp.entries, self, vm)?;
350+
return Ok(other_cp);
351+
}
352+
Err(vm.new_type_error("__ror__ not implemented for non-dict type".to_owned()))
353+
}
354+
355+
#[pymethod(name = "__or__")]
356+
fn or(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyDict> {
357+
let dicted: Result<PyDictRef, _> = other.clone().downcast();
358+
if let Ok(other) = dicted {
359+
let self_cp = self.copy();
360+
PyDictRef::merge_dict(&self_cp.entries, other, vm)?;
361+
return Ok(self_cp);
362+
}
363+
Err(vm.new_type_error("__or__ not implemented for non-dict type".to_owned()))
364+
}
365+
323366
#[pymethod]
324367
fn pop(
325368
self,

0 commit comments

Comments
 (0)