diff --git a/tests/snippets/dict.py b/tests/snippets/dict.py new file mode 100644 index 0000000000..ca9d61aa68 --- /dev/null +++ b/tests/snippets/dict.py @@ -0,0 +1,16 @@ +def dict_eq(d1, d2): + return (all(k in d2 and d1[k] == d2[k] for k in d1) + and all(k in d1 and d1[k] == d2[k] for k in d2)) + + +assert dict_eq(dict(a=2, b=3), {'a': 2, 'b': 3}) +assert dict_eq(dict({'a': 2, 'b': 3}, b=4), {'a': 2, 'b': 4}) +assert dict_eq(dict([('a', 2), ('b', 3)]), {'a': 2, 'b': 3}) + +a = {'g': 5} +b = {'a': a, 'd': 9} +c = dict(b) +c['d'] = 3 +c['a']['g'] = 2 +assert dict_eq(a, {'g': 2}) +assert dict_eq(b, {'a': a, 'd': 9}) diff --git a/vm/src/obj/objdict.rs b/vm/src/obj/objdict.rs index ef13e02762..16b1890b04 100644 --- a/vm/src/obj/objdict.rs +++ b/vm/src/obj/objdict.rs @@ -2,6 +2,7 @@ use super::super::pyobject::{ PyContext, PyFuncArgs, PyObject, PyObjectKind, PyObjectRef, PyResult, TypeProtocol, }; use super::super::vm::VirtualMachine; +use super::objiter; use super::objstr; use super::objtype; use num_bigint::ToBigInt; @@ -113,8 +114,43 @@ pub fn content_contains_key_str(elements: &DictContentType, key: &str) -> bool { // Python dict methods: -fn dict_new(_vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - Ok(new(args.args[0].clone())) +fn dict_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(_ty, Some(vm.ctx.type_type()))], + optional = [(dict_obj, None)] + ); + let dict = vm.ctx.new_dict(); + if let Some(dict_obj) = dict_obj { + if objtype::isinstance(&dict_obj, &vm.ctx.dict_type()) { + for (needle, value) in get_key_value_pairs(&dict_obj) { + set_item(&dict, &needle, &value); + } + } else { + let iter = objiter::get_iter(vm, dict_obj)?; + loop { + fn err(vm: &mut VirtualMachine) -> PyObjectRef { + vm.new_type_error("Iterator must have exactly two elements".to_string()) + } + let element = match objiter::get_next_object(vm, &iter)? { + Some(obj) => obj, + None => break, + }; + let elem_iter = objiter::get_iter(vm, &element)?; + let needle = objiter::get_next_object(vm, &elem_iter)?.ok_or_else(|| err(vm))?; + let value = objiter::get_next_object(vm, &elem_iter)?.ok_or_else(|| err(vm))?; + if let Some(_) = objiter::get_next_object(vm, &elem_iter)? { + return Err(err(vm)); + } + set_item(&dict, &needle, &value); + } + } + } + for (needle, value) in args.kwargs { + set_item(&dict, &vm.new_str(needle), &value); + } + Ok(dict) } fn dict_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {