diff --git a/tests/snippets/dict_union.py b/tests/snippets/dict_union.py new file mode 100644 index 0000000000..29e0718d45 --- /dev/null +++ b/tests/snippets/dict_union.py @@ -0,0 +1,83 @@ + +import testutils + +def test_dunion_ior0(): + a={1:2,2:3} + b={3:4,5:6} + a|=b + + assert a == {1:2,2:3,3:4,5:6}, f"wrong value assigned {a=}" + assert b == {3:4,5:6}, f"right hand side modified, {b=}" + +def test_dunion_or0(): + a={1:2,2:3} + b={3:4,5:6} + c=a|b + + assert a == {1:2,2:3}, f"left hand side of non-assignment operator modified {a=}" + assert b == {3:4,5:6}, f"right hand side of non-assignment operator modified, {b=}" + assert c == {1:2,2:3, 3:4, 5:6}, f"unexpected result of dict union {c=}" + + +def test_dunion_or1(): + a={1:2,2:3} + b={3:4,5:6} + c=a.__or__(b) + + assert a == {1:2,2:3}, f"left hand side of non-assignment operator modified {a=}" + assert b == {3:4,5:6}, f"right hand side of non-assignment operator modified, {b=}" + assert c == {1:2,2:3, 3:4, 5:6}, f"unexpected result of dict union {c=}" + + +def test_dunion_ror0(): + a={1:2,2:3} + b={3:4,5:6} + c=b.__ror__(a) + + assert a == {1:2,2:3}, f"left hand side of non-assignment operator modified {a=}" + assert b == {3:4,5:6}, f"right hand side of non-assignment operator modified, {b=}" + assert c == {1:2,2:3, 3:4, 5:6}, f"unexpected result of dict union {c=}" + + +def test_dunion_other_types(): + def perf_test_or(other_obj): + d={1:2} + try: + d.__or__(other_obj) + except: + return True + return False + + def perf_test_ior(other_obj): + d={1:2} + try: + d.__ior__(other_obj) + except: + return True + return False + + def perf_test_ror(other_obj): + d={1:2} + try: + d.__ror__(other_obj) + except: + return True + return False + + test_fct={'__or__':perf_test_or, '__ror__':perf_test_ror, '__ior__':perf_test_ior} + others=['FooBar', 42, [36], set([19]), ['aa'], None] + for tfn,tf in test_fct.items(): + for other in others: + assert tf(other), f"Failed: dict {tfn}, accepted {other}" + + + + +testutils.skip_if_unsupported(3,9,test_dunion_ior0) +testutils.skip_if_unsupported(3,9,test_dunion_or0) +testutils.skip_if_unsupported(3,9,test_dunion_or1) +testutils.skip_if_unsupported(3,9,test_dunion_ror0) +testutils.skip_if_unsupported(3,9,test_dunion_other_types) + + + diff --git a/tests/snippets/testutils.py b/tests/snippets/testutils.py index 8a9fdddb2f..c779d2c898 100644 --- a/tests/snippets/testutils.py +++ b/tests/snippets/testutils.py @@ -1,3 +1,6 @@ +import platform +import sys + def assert_raises(expected, *args, _msg=None, **kw): if args: f, f_args = args[0], args[1:] @@ -67,3 +70,26 @@ def assert_isinstance(obj, klass): def assert_in(a, b): _assert_print(lambda: a in b, [a, 'in', b]) + +def skip_if_unsupported(req_maj_vers, req_min_vers, test_fct): + def exec(): + test_fct() + + if platform.python_implementation() == 'RustPython': + exec() + elif sys.version_info.major>=req_maj_vers and sys.version_info.minor>=req_min_vers: + exec() + else: + print(f'Skipping test as a higher python version is required. Using {platform.python_implementation()} {platform.python_version()}') + +def fail_if_unsupported(req_maj_vers, req_min_vers, test_fct): + def exec(): + test_fct() + + if platform.python_implementation() == 'RustPython': + exec() + elif sys.version_info.major>=req_maj_vers and sys.version_info.minor>=req_min_vers: + exec() + else: + assert False, f'Test cannot performed on this python version. {platform.python_implementation()} {platform.python_version()}' + diff --git a/vm/src/obj/objdict.rs b/vm/src/obj/objdict.rs index a5330bdbcf..e2cdb108db 100644 --- a/vm/src/obj/objdict.rs +++ b/vm/src/obj/objdict.rs @@ -104,6 +104,17 @@ impl PyDictRef { Ok(()) } + fn merge_dict( + dict: &DictContentType, + dict_other: PyDictRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + for (key, value) in dict_other { + dict.insert(vm, &key, value)?; + } + Ok(()) + } + #[pyclassmethod] fn fromkeys( class: PyClassRef, @@ -320,6 +331,38 @@ impl PyDictRef { PyDictRef::merge(&self.entries, dict_obj, kwargs, vm) } + #[pymethod(name = "__ior__")] + fn ior(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let dicted: Result = other.clone().downcast(); + if let Ok(other) = dicted { + PyDictRef::merge_dict(&self.entries, other, vm)?; + return Ok(self.into_object()); + } + Err(vm.new_type_error("__ior__ not implemented for non-dict type".to_owned())) + } + + #[pymethod(name = "__ror__")] + fn ror(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let dicted: Result = other.clone().downcast(); + if let Ok(other) = dicted { + let other_cp = other.copy(); + PyDictRef::merge_dict(&other_cp.entries, self, vm)?; + return Ok(other_cp); + } + Err(vm.new_type_error("__ror__ not implemented for non-dict type".to_owned())) + } + + #[pymethod(name = "__or__")] + fn or(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let dicted: Result = other.clone().downcast(); + if let Ok(other) = dicted { + let self_cp = self.copy(); + PyDictRef::merge_dict(&self_cp.entries, other, vm)?; + return Ok(self_cp); + } + Err(vm.new_type_error("__or__ not implemented for non-dict type".to_owned())) + } + #[pymethod] fn pop( self,