diff --git a/vm/src/macros.rs b/vm/src/macros.rs index a631867976..47d3c531ab 100644 --- a/vm/src/macros.rs +++ b/vm/src/macros.rs @@ -147,3 +147,89 @@ macro_rules! extend_class { )* } } + +/// Macro to match on the built-in class of a Python object. +/// +/// Like `match`, `match_class!` must be exhaustive, so a default arm with +/// the uncasted object is required. +/// +/// # Examples +/// +/// ``` +/// use num_bigint::ToBigInt; +/// use num_traits::Zero; +/// +/// use rustpython_vm::VirtualMachine; +/// use rustpython_vm::match_class; +/// use rustpython_vm::obj::objfloat::PyFloat; +/// use rustpython_vm::obj::objint::PyInt; +/// use rustpython_vm::pyobject::PyValue; +/// +/// let vm = VirtualMachine::new(); +/// let obj = PyInt::new(0).into_ref(&vm).into_object(); +/// assert_eq!( +/// "int", +/// match_class!(obj.clone(), +/// PyInt => "int", +/// PyFloat => "float", +/// _ => "neither", +/// ) +/// ); +/// +/// ``` +/// +/// With a binding to the downcasted type: +/// +/// ``` +/// use num_bigint::ToBigInt; +/// use num_traits::Zero; +/// +/// use rustpython_vm::VirtualMachine; +/// use rustpython_vm::match_class; +/// use rustpython_vm::obj::objfloat::PyFloat; +/// use rustpython_vm::obj::objint::PyInt; +/// use rustpython_vm::pyobject::PyValue; +/// +/// let vm = VirtualMachine::new(); +/// let obj = PyInt::new(0).into_ref(&vm).into_object(); +/// +/// let int_value = match_class!(obj, +/// i @ PyInt => i.as_bigint().clone(), +/// f @ PyFloat => f.to_f64().to_bigint().unwrap(), +/// obj => panic!("non-numeric object {}", obj), +/// ); +/// +/// assert!(int_value.is_zero()); +/// ``` +#[macro_export] +macro_rules! match_class { + // The default arm. + ($obj:expr, _ => $default:expr $(,)?) => { + $default + }; + + // The default arm, binding the original object to the specified identifier. + ($obj:expr, $binding:ident => $default:expr $(,)?) => {{ + let $binding = $obj; + $default + }}; + + // An arm taken when the object is an instance of the specified built-in + // class and binding the downcasted object to the specified identifier. + ($obj:expr, $binding:ident @ $class:ty => $expr:expr, $($rest:tt)*) => { + match $obj.downcast::<$class>() { + Ok($binding) => $expr, + Err(_obj) => match_class!(_obj, $($rest)*), + } + }; + + // An arm taken when the object is an instance of the specified built-in + // class. + ($obj:expr, $class:ty => $expr:expr, $($rest:tt)*) => { + if $obj.payload_is::<$class>() { + $expr + } else { + match_class!($obj, $($rest)*) + } + }; +} diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index ca6bad9d22..4ef7123eee 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -16,6 +16,12 @@ pub struct PyFloat { value: f64, } +impl PyFloat { + pub fn to_f64(&self) -> f64 { + self.value + } +} + impl PyValue for PyFloat { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.float_type() diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index 08175dc07f..d7b6a04206 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -12,8 +12,8 @@ use crate::pyobject::{ }; use crate::vm::VirtualMachine; -use super::objfloat; -use super::objstr; +use super::objfloat::{self, PyFloat}; +use super::objstr::{PyString, PyStringRef}; use super::objtype; use crate::obj::objtype::PyClassRef; @@ -351,7 +351,7 @@ impl PyIntRef { self.value.to_string() } - fn format(self, spec: PyRef, vm: &VirtualMachine) -> PyResult { + fn format(self, spec: PyStringRef, vm: &VirtualMachine) -> PyResult { let format_spec = FormatSpec::parse(&spec.value); match format_spec.format_int(&self.value) { Ok(string) => Ok(string), @@ -407,30 +407,24 @@ fn int_new(cls: PyClassRef, options: IntOptions, vm: &VirtualMachine) -> PyResul } // Casting function: +// TODO: this should just call `__int__` on the object pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, base: u32) -> PyResult { - let val = if objtype::isinstance(obj, &vm.ctx.int_type()) { - get_value(obj).clone() - } else if objtype::isinstance(obj, &vm.ctx.float_type()) { - objfloat::get_value(obj).to_bigint().unwrap() - } else if objtype::isinstance(obj, &vm.ctx.str_type()) { - let s = objstr::get_value(obj); - match i32::from_str_radix(&s, base) { - Ok(v) => v.to_bigint().unwrap(), - Err(err) => { - trace!("Error occurred during int conversion {:?}", err); - return Err(vm.new_value_error(format!( + match_class!(obj.clone(), + i @ PyInt => Ok(i.as_bigint().clone()), + f @ PyFloat => Ok(f.to_f64().to_bigint().unwrap()), + s @ PyString => { + i32::from_str_radix(s.as_str(), base) + .map(|i| BigInt::from(i)) + .map_err(|_|vm.new_value_error(format!( "invalid literal for int() with base {}: '{}'", base, s - ))); - } - } - } else { - return Err(vm.new_type_error(format!( + ))) + }, + obj => Err(vm.new_type_error(format!( "int() argument must be a string or a number, not '{}'", obj.class().name - ))); - }; - Ok(val) + ))) + ) } // Retrieve inner int value: diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index aa62812a24..8ac23c1607 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -25,6 +25,13 @@ pub struct PyString { // TODO: shouldn't be public pub value: String, } + +impl PyString { + pub fn as_str(&self) -> &str { + &self.value + } +} + pub type PyStringRef = PyRef; impl fmt::Display for PyString {