diff --git a/tests/snippets/numbers.py b/tests/snippets/numbers.py index 7b01f6473c..da4ff50400 100644 --- a/tests/snippets/numbers.py +++ b/tests/snippets/numbers.py @@ -9,6 +9,23 @@ class A(int): assert x == 7 assert type(x) is A +assert int(2).__index__() == 2 +assert int(2).__trunc__() == 2 +assert int(2).__ceil__() == 2 +assert int(2).__floor__() == 2 +assert int(2).__round__() == 2 +assert int(2).__round__(3) == 2 +assert int(-2).__index__() == -2 +assert int(-2).__trunc__() == -2 +assert int(-2).__ceil__() == -2 +assert int(-2).__floor__() == -2 +assert int(-2).__round__() == -2 +assert int(-2).__round__(3) == -2 + +assert round(10) == 10 +assert round(10, 2) == 10 +assert round(10, -1) == 10 + assert int(2).__bool__() == True assert int(0.5).__bool__() == False assert int(-1).__bool__() == True diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index a08e70fb70..31f377b28b 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -664,7 +664,24 @@ fn builtin_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { vm.to_repr(obj) } // builtin_reversed -// builtin_round + +fn builtin_round(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(number, Some(vm.ctx.object()))], + optional = [(ndigits, None)] + ); + if let Some(ndigits) = ndigits { + let ndigits = vm.call_method(ndigits, "__int__", vec![])?; + let rounded = vm.call_method(number, "__round__", vec![ndigits])?; + Ok(rounded) + } else { + // without a parameter, the result type is coerced to int + let rounded = &vm.call_method(number, "__round__", vec![])?; + Ok(vm.ctx.new_int(objint::get_value(rounded))) + } +} fn builtin_setattr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( @@ -777,6 +794,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { ctx.set_attr(&py_mod, "property", ctx.property_type()); ctx.set_attr(&py_mod, "range", ctx.range_type()); ctx.set_attr(&py_mod, "repr", ctx.new_rustfunc(builtin_repr)); + ctx.set_attr(&py_mod, "round", ctx.new_rustfunc(builtin_round)); ctx.set_attr(&py_mod, "set", ctx.set_type()); ctx.set_attr(&py_mod, "setattr", ctx.new_rustfunc(builtin_setattr)); ctx.set_attr(&py_mod, "staticmethod", ctx.staticmethod_type()); diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index a00f1d5b0f..75e3b3dec1 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -294,6 +294,21 @@ fn int_floordiv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } } +fn int_round(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(i, Some(vm.ctx.int_type()))], + optional = [(_precision, None)] + ); + Ok(vm.ctx.new_int(get_value(i))) +} + +fn int_pass_value(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(i, Some(vm.ctx.int_type()))]); + Ok(vm.ctx.new_int(get_value(i))) +} + fn int_format(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, @@ -512,6 +527,12 @@ pub fn init(context: &PyContext) { context.set_attr(&int_type, "__and__", context.new_rustfunc(int_and)); context.set_attr(&int_type, "__divmod__", context.new_rustfunc(int_divmod)); context.set_attr(&int_type, "__float__", context.new_rustfunc(int_float)); + context.set_attr(&int_type, "__round__", context.new_rustfunc(int_round)); + context.set_attr(&int_type, "__ceil__", context.new_rustfunc(int_pass_value)); + context.set_attr(&int_type, "__floor__", context.new_rustfunc(int_pass_value)); + context.set_attr(&int_type, "__index__", context.new_rustfunc(int_pass_value)); + context.set_attr(&int_type, "__trunc__", context.new_rustfunc(int_pass_value)); + context.set_attr(&int_type, "__int__", context.new_rustfunc(int_pass_value)); context.set_attr( &int_type, "__floordiv__",