diff --git a/tests/snippets/builtin_format.py b/tests/snippets/builtin_format.py new file mode 100644 index 0000000000..bb7e554b7a --- /dev/null +++ b/tests/snippets/builtin_format.py @@ -0,0 +1,8 @@ +assert format(5, "b") == "101" + +try: + format(2, 3) +except TypeError: + pass +else: + assert False, "TypeError not raised when format is called with a number" diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index dc82033b7f..260df02345 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -40,3 +40,9 @@ c = 'hallo' assert c.capitalize() == 'Hallo' + +# String Formatting +assert "{} {}".format(1,2) == "1 2" +assert "{0} {1}".format(2,3) == "2 3" +assert "--{:s>4}--".format(1) == "--sss1--" +assert "{keyword} {0}".format(1, keyword=2) == "2 1" diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 292107fd75..2614232e09 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -317,7 +317,15 @@ fn builtin_filter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_list(new_items)) } -// builtin_format +fn builtin_format(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(obj, None), (format_spec, Some(vm.ctx.str_type()))] + ); + + vm.call_method(obj, "__format__", vec![format_spec.clone()]) +} fn builtin_getattr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( @@ -756,6 +764,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { ctx.set_attr(&py_mod, "float", ctx.float_type()); ctx.set_attr(&py_mod, "frozenset", ctx.frozenset_type()); ctx.set_attr(&py_mod, "filter", ctx.new_rustfunc(builtin_filter)); + ctx.set_attr(&py_mod, "format", ctx.new_rustfunc(builtin_format)); ctx.set_attr(&py_mod, "getattr", ctx.new_rustfunc(builtin_getattr)); ctx.set_attr(&py_mod, "hasattr", ctx.new_rustfunc(builtin_hasattr)); ctx.set_attr(&py_mod, "hash", ctx.new_rustfunc(builtin_hash)); diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index df37b48d26..fb38a76978 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -86,6 +86,8 @@ pub struct ExceptionZoo { pub syntax_error: PyObjectRef, pub assertion_error: PyObjectRef, pub attribute_error: PyObjectRef, + pub index_error: PyObjectRef, + pub key_error: PyObjectRef, pub name_error: PyObjectRef, pub runtime_error: PyObjectRef, pub not_implemented_error: PyObjectRef, @@ -129,6 +131,18 @@ impl ExceptionZoo { &exception_type.clone(), &dict_type, ); + let index_error = create_type( + &String::from("IndexError"), + &type_type, + &exception_type.clone(), + &dict_type, + ); + let key_error = create_type( + &String::from("KeyError"), + &type_type, + &exception_type.clone(), + &dict_type, + ); let name_error = create_type( &String::from("NameError"), &type_type, @@ -184,6 +198,8 @@ impl ExceptionZoo { syntax_error: syntax_error, assertion_error: assertion_error, attribute_error: attribute_error, + index_error: index_error, + key_error: key_error, name_error: name_error, runtime_error: runtime_error, not_implemented_error: not_implemented_error, diff --git a/vm/src/format.rs b/vm/src/format.rs new file mode 100644 index 0000000000..37f8bfa57b --- /dev/null +++ b/vm/src/format.rs @@ -0,0 +1,658 @@ +use num_bigint::{BigInt, Sign}; +use num_traits::Signed; +use std::cmp; +use std::str::FromStr; + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum FormatAlign { + Left, + Right, + AfterSign, + Center, +} + +impl FormatAlign { + fn from_char(c: char) -> Option { + match c { + '<' => Some(FormatAlign::Left), + '>' => Some(FormatAlign::Right), + '=' => Some(FormatAlign::AfterSign), + '^' => Some(FormatAlign::Center), + _ => None, + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum FormatSign { + Plus, + Minus, + MinusOrSpace, +} + +#[derive(Debug, PartialEq)] +pub enum FormatGrouping { + Comma, + Underscore, +} + +#[derive(Debug, PartialEq)] +pub enum FormatType { + String, + Binary, + Character, + Decimal, + Octal, + HexLower, + HexUpper, + Number, + ExponentLower, + ExponentUpper, + GeneralFormatLower, + GeneralFormatUpper, + FixedPointLower, + FixedPointUpper, +} + +#[derive(Debug, PartialEq)] +pub struct FormatSpec { + fill: Option, + align: Option, + sign: Option, + alternate_form: bool, + width: Option, + grouping_option: Option, + precision: Option, + format_type: Option, +} + +fn get_num_digits(text: &str) -> usize { + for (index, character) in text.char_indices() { + if !character.is_digit(10) { + return index; + } + } + text.len() +} + +fn parse_align(text: &str) -> (Option, &str) { + let mut chars = text.chars(); + let maybe_align = chars.next().and_then(FormatAlign::from_char); + if maybe_align.is_some() { + (maybe_align, &chars.as_str()) + } else { + (None, text) + } +} + +fn parse_fill_and_align(text: &str) -> (Option, Option, &str) { + let char_indices: Vec<(usize, char)> = text.char_indices().take(3).collect(); + if char_indices.len() == 0 { + (None, None, text) + } else if char_indices.len() == 1 { + let (maybe_align, remaining) = parse_align(text); + (None, maybe_align, remaining) + } else { + let (maybe_align, remaining) = parse_align(&text[char_indices[1].0..]); + if maybe_align.is_some() { + (Some(char_indices[0].1), maybe_align, remaining) + } else { + let (only_align, only_align_remaining) = parse_align(text); + (None, only_align, only_align_remaining) + } + } +} + +fn parse_number(text: &str) -> (Option, &str) { + let num_digits: usize = get_num_digits(text); + if num_digits == 0 { + return (None, text); + } + // This should never fail + ( + Some(text[..num_digits].parse::().unwrap()), + &text[num_digits..], + ) +} + +fn parse_sign(text: &str) -> (Option, &str) { + let mut chars = text.chars(); + match chars.next() { + Some('-') => (Some(FormatSign::Minus), chars.as_str()), + Some('+') => (Some(FormatSign::Plus), chars.as_str()), + Some(' ') => (Some(FormatSign::MinusOrSpace), chars.as_str()), + _ => (None, text), + } +} + +fn parse_alternate_form(text: &str) -> (bool, &str) { + let mut chars = text.chars(); + match chars.next() { + Some('#') => (true, chars.as_str()), + _ => (false, text), + } +} + +fn parse_zero(text: &str) -> &str { + let mut chars = text.chars(); + match chars.next() { + Some('0') => chars.as_str(), + _ => text, + } +} + +fn parse_precision(text: &str) -> (Option, &str) { + let mut chars = text.chars(); + match chars.next() { + Some('.') => { + let (size, remaining) = parse_number(&chars.as_str()); + if size.is_some() { + (size, remaining) + } else { + (None, text) + } + } + _ => (None, text), + } +} + +fn parse_grouping_option(text: &str) -> (Option, &str) { + let mut chars = text.chars(); + match chars.next() { + Some('_') => (Some(FormatGrouping::Underscore), chars.as_str()), + Some(',') => (Some(FormatGrouping::Comma), chars.as_str()), + _ => (None, text), + } +} + +fn parse_format_type(text: &str) -> (Option, &str) { + let mut chars = text.chars(); + match chars.next() { + Some('b') => (Some(FormatType::Binary), chars.as_str()), + Some('c') => (Some(FormatType::Character), chars.as_str()), + Some('d') => (Some(FormatType::Decimal), chars.as_str()), + Some('o') => (Some(FormatType::Octal), chars.as_str()), + Some('x') => (Some(FormatType::HexLower), chars.as_str()), + Some('X') => (Some(FormatType::HexUpper), chars.as_str()), + Some('e') => (Some(FormatType::ExponentLower), chars.as_str()), + Some('E') => (Some(FormatType::ExponentUpper), chars.as_str()), + Some('f') => (Some(FormatType::FixedPointLower), chars.as_str()), + Some('F') => (Some(FormatType::FixedPointUpper), chars.as_str()), + Some('g') => (Some(FormatType::GeneralFormatLower), chars.as_str()), + Some('G') => (Some(FormatType::GeneralFormatUpper), chars.as_str()), + Some('n') => (Some(FormatType::Number), chars.as_str()), + _ => (None, text), + } +} + +fn parse_format_spec(text: &str) -> FormatSpec { + let (fill, align, after_align) = parse_fill_and_align(text); + let (sign, after_sign) = parse_sign(after_align); + let (alternate_form, after_alternate_form) = parse_alternate_form(after_sign); + let after_zero = parse_zero(after_alternate_form); + let (width, after_width) = parse_number(after_zero); + let (grouping_option, after_grouping_option) = parse_grouping_option(after_width); + let (precision, after_precision) = parse_precision(after_grouping_option); + let (format_type, _) = parse_format_type(after_precision); + + FormatSpec { + fill, + align, + sign, + alternate_form, + width, + grouping_option, + precision, + format_type, + } +} + +impl FormatSpec { + pub fn parse(text: &str) -> FormatSpec { + parse_format_spec(text) + } + + fn compute_fill_string(fill_char: char, fill_chars_needed: i32) -> String { + (0..fill_chars_needed) + .map(|_| fill_char) + .collect::() + } + + fn add_magnitude_separators_for_char( + magnitude_string: String, + interval: usize, + separator: char, + ) -> String { + let mut result = String::new(); + let mut remaining: usize = magnitude_string.len() % interval; + if remaining == 0 { + remaining = interval; + } + for c in magnitude_string.chars() { + result.push(c); + if remaining == 0 { + result.push(separator); + remaining = interval; + } + } + result + } + + fn get_separator_interval(&self) -> usize { + match self.format_type { + Some(FormatType::Binary) => 4, + Some(FormatType::Decimal) => 3, + Some(FormatType::Octal) => 4, + Some(FormatType::HexLower) => 4, + Some(FormatType::HexUpper) => 4, + Some(FormatType::Number) => 3, + None => 3, + _ => panic!("Separators only valid for numbers!"), + } + } + + fn add_magnitude_separators(&self, magnitude_string: String) -> String { + match self.grouping_option { + Some(FormatGrouping::Comma) => FormatSpec::add_magnitude_separators_for_char( + magnitude_string, + self.get_separator_interval(), + ',', + ), + Some(FormatGrouping::Underscore) => FormatSpec::add_magnitude_separators_for_char( + magnitude_string, + self.get_separator_interval(), + '_', + ), + None => magnitude_string, + } + } + + pub fn format_int(&self, num: &BigInt) -> Result { + let fill_char = self.fill.unwrap_or(' '); + let magnitude = num.abs(); + let prefix = if self.alternate_form { + match self.format_type { + Some(FormatType::Binary) => "0b", + Some(FormatType::Octal) => "0o", + Some(FormatType::HexLower) => "0x", + Some(FormatType::HexUpper) => "0x", + _ => "", + } + } else { + "" + }; + let raw_magnitude_string_result: Result = match self.format_type { + Some(FormatType::Binary) => Ok(magnitude.to_str_radix(2)), + Some(FormatType::Decimal) => Ok(magnitude.to_str_radix(10)), + Some(FormatType::Octal) => Ok(magnitude.to_str_radix(8)), + Some(FormatType::HexLower) => Ok(magnitude.to_str_radix(16)), + Some(FormatType::HexUpper) => { + let mut result = magnitude.to_str_radix(16); + result.make_ascii_uppercase(); + Ok(result) + } + Some(FormatType::Number) => Ok(magnitude.to_str_radix(10)), + Some(FormatType::String) => Err("Unknown format code 's' for object of type 'int'"), + Some(FormatType::Character) => Err("Unknown format code 'c' for object of type 'int'"), + Some(FormatType::GeneralFormatUpper) => { + Err("Unknown format code 'G' for object of type 'int'") + } + Some(FormatType::GeneralFormatLower) => { + Err("Unknown format code 'g' for object of type 'int'") + } + Some(FormatType::ExponentUpper) => { + Err("Unknown format code 'E' for object of type 'int'") + } + Some(FormatType::ExponentLower) => { + Err("Unknown format code 'e' for object of type 'int'") + } + Some(FormatType::FixedPointUpper) => { + Err("Unknown format code 'F' for object of type 'int'") + } + Some(FormatType::FixedPointLower) => { + Err("Unknown format code 'f' for object of type 'int'") + } + None => Ok(magnitude.to_str_radix(10)), + }; + if !raw_magnitude_string_result.is_ok() { + return raw_magnitude_string_result; + } + let magnitude_string = format!( + "{}{}", + prefix, + self.add_magnitude_separators(raw_magnitude_string_result.unwrap()) + ); + let align = self.align.unwrap_or(FormatAlign::Right); + + // Use the byte length as the string length since we're in ascii + let num_chars = magnitude_string.len(); + + let format_sign = self.sign.unwrap_or(FormatSign::Minus); + let sign_str = match num.sign() { + Sign::Minus => "-", + _ => match format_sign { + FormatSign::Plus => "+", + FormatSign::Minus => "", + FormatSign::MinusOrSpace => " ", + }, + }; + + let fill_chars_needed: i32 = self.width.map_or(0, |w| { + cmp::max(0, (w as i32) - (num_chars as i32) - (sign_str.len() as i32)) + }); + Ok(match align { + FormatAlign::Left => format!( + "{}{}{}", + sign_str, + magnitude_string, + FormatSpec::compute_fill_string(fill_char, fill_chars_needed) + ), + FormatAlign::Right => format!( + "{}{}{}", + FormatSpec::compute_fill_string(fill_char, fill_chars_needed), + sign_str, + magnitude_string + ), + FormatAlign::AfterSign => format!( + "{}{}{}", + sign_str, + FormatSpec::compute_fill_string(fill_char, fill_chars_needed), + magnitude_string + ), + FormatAlign::Center => { + let left_fill_chars_needed = fill_chars_needed / 2; + let right_fill_chars_needed = fill_chars_needed - left_fill_chars_needed; + let left_fill_string = + FormatSpec::compute_fill_string(fill_char, left_fill_chars_needed); + let right_fill_string = + FormatSpec::compute_fill_string(fill_char, right_fill_chars_needed); + format!( + "{}{}{}{}", + left_fill_string, sign_str, magnitude_string, right_fill_string + ) + } + }) + } +} + +#[derive(Debug, PartialEq)] +pub enum FormatParseError { + UnmatchedBracket, + MissingStartBracket, + UnescapedStartBracketInLiteral, +} + +impl FromStr for FormatSpec { + type Err = &'static str; + fn from_str(s: &str) -> Result { + Ok(FormatSpec::parse(s)) + } +} + +#[derive(Debug, PartialEq)] +pub enum FormatPart { + AutoSpec(String), + IndexSpec(usize, String), + KeywordSpec(String, String), + Literal(String), +} + +impl FormatPart { + pub fn is_auto(&self) -> bool { + match self { + FormatPart::AutoSpec(_) => true, + _ => false, + } + } + + pub fn is_index(&self) -> bool { + match self { + FormatPart::IndexSpec(_, _) => true, + _ => false, + } + } +} + +#[derive(Debug, PartialEq)] +pub struct FormatString { + pub format_parts: Vec, +} + +impl FormatString { + fn parse_literal_single(text: &str) -> Result<(char, &str), FormatParseError> { + let mut chars = text.chars(); + // This should never be called with an empty str + let first_char = chars.next().unwrap(); + if first_char == '{' || first_char == '}' { + let maybe_next_char = chars.next(); + // if we see a bracket, it has to be escaped by doubling up to be in a literal + if maybe_next_char.is_some() && maybe_next_char.unwrap() != first_char { + return Err(FormatParseError::UnescapedStartBracketInLiteral); + } else { + return Ok((first_char, chars.as_str())); + } + } + Ok((first_char, chars.as_str())) + } + + fn parse_literal(text: &str) -> Result<(FormatPart, &str), FormatParseError> { + let mut cur_text = text; + let mut result_string = String::new(); + while cur_text.len() > 0 { + match FormatString::parse_literal_single(cur_text) { + Ok((next_char, remaining)) => { + result_string.push(next_char); + cur_text = remaining; + } + Err(err) => { + if result_string.len() > 0 { + return Ok((FormatPart::Literal(result_string.to_string()), cur_text)); + } else { + return Err(err); + } + } + } + } + Ok((FormatPart::Literal(result_string), "")) + } + + fn parse_part_in_brackets(text: &str) -> Result { + let parts: Vec<&str> = text.splitn(2, ':').collect(); + // before the comma is a keyword or arg index, after the comma is maybe a spec. + let arg_part = parts[0]; + + let format_spec = if parts.len() > 1 { + parts[1].to_string() + } else { + String::new() + }; + + if arg_part.len() == 0 { + return Ok(FormatPart::AutoSpec(format_spec)); + } + + if let Ok(index) = arg_part.parse::() { + Ok(FormatPart::IndexSpec(index, format_spec)) + } else { + Ok(FormatPart::KeywordSpec(arg_part.to_string(), format_spec)) + } + } + + fn parse_spec(text: &str) -> Result<(FormatPart, &str), FormatParseError> { + let mut chars = text.chars(); + if chars.next() != Some('{') { + return Err(FormatParseError::MissingStartBracket); + } + + // Get remaining characters after opening bracket. + let cur_text = chars.as_str(); + // Find the matching bracket and parse the text within for a spec + match cur_text.find('}') { + Some(position) => { + let (left, right) = cur_text.split_at(position); + let format_part = FormatString::parse_part_in_brackets(left)?; + Ok((format_part, &right[1..])) + } + None => Err(FormatParseError::UnmatchedBracket), + } + } + + pub fn from_str(text: &str) -> Result { + let mut cur_text: &str = text; + let mut parts: Vec = Vec::new(); + while cur_text.len() > 0 { + // Try to parse both literals and bracketed format parts util we + // run out of text + cur_text = FormatString::parse_literal(cur_text) + .or_else(|_| FormatString::parse_spec(cur_text)) + .map(|(part, new_text)| { + parts.push(part); + new_text + })?; + } + Ok(FormatString { + format_parts: parts, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_fill_and_align() { + assert_eq!( + parse_fill_and_align(" <"), + (Some(' '), Some(FormatAlign::Left), "") + ); + assert_eq!( + parse_fill_and_align(" <22"), + (Some(' '), Some(FormatAlign::Left), "22") + ); + assert_eq!( + parse_fill_and_align("<22"), + (None, Some(FormatAlign::Left), "22") + ); + assert_eq!( + parse_fill_and_align(" ^^"), + (Some(' '), Some(FormatAlign::Center), "^") + ); + assert_eq!( + parse_fill_and_align("==="), + (Some('='), Some(FormatAlign::AfterSign), "=") + ); + } + + #[test] + fn test_width_only() { + let expected = FormatSpec { + fill: None, + align: None, + sign: None, + alternate_form: false, + width: Some(33), + grouping_option: None, + precision: None, + format_type: None, + }; + assert_eq!(parse_format_spec("33"), expected); + } + + #[test] + fn test_fill_and_width() { + let expected = FormatSpec { + fill: Some('<'), + align: Some(FormatAlign::Right), + sign: None, + alternate_form: false, + width: Some(33), + grouping_option: None, + precision: None, + format_type: None, + }; + assert_eq!(parse_format_spec("<>33"), expected); + } + + #[test] + fn test_all() { + let expected = FormatSpec { + fill: Some('<'), + align: Some(FormatAlign::Right), + sign: Some(FormatSign::Minus), + alternate_form: true, + width: Some(23), + grouping_option: Some(FormatGrouping::Comma), + precision: Some(11), + format_type: Some(FormatType::Binary), + }; + assert_eq!(parse_format_spec("<>-#23,.11b"), expected); + } + + #[test] + fn test_format_int() { + assert_eq!( + parse_format_spec("d").format_int(&BigInt::from_bytes_be(Sign::Plus, b"\x10")), + Ok("16".to_string()) + ); + assert_eq!( + parse_format_spec("x").format_int(&BigInt::from_bytes_be(Sign::Plus, b"\x10")), + Ok("10".to_string()) + ); + assert_eq!( + parse_format_spec("b").format_int(&BigInt::from_bytes_be(Sign::Plus, b"\x10")), + Ok("10000".to_string()) + ); + assert_eq!( + parse_format_spec("o").format_int(&BigInt::from_bytes_be(Sign::Plus, b"\x10")), + Ok("20".to_string()) + ); + assert_eq!( + parse_format_spec("+d").format_int(&BigInt::from_bytes_be(Sign::Plus, b"\x10")), + Ok("+16".to_string()) + ); + assert_eq!( + parse_format_spec("^ 5d").format_int(&BigInt::from_bytes_be(Sign::Minus, b"\x10")), + Ok(" -16 ".to_string()) + ); + assert_eq!( + parse_format_spec("0>+#10x").format_int(&BigInt::from_bytes_be(Sign::Plus, b"\x10")), + Ok("00000+0x10".to_string()) + ); + } + + #[test] + fn test_format_parse() { + let expected = Ok(FormatString { + format_parts: vec![ + FormatPart::Literal("abcd".to_string()), + FormatPart::IndexSpec(1, String::new()), + FormatPart::Literal(":".to_string()), + FormatPart::KeywordSpec("key".to_string(), String::new()), + ], + }); + + assert_eq!(FormatString::from_str("abcd{1}:{key}"), expected); + } + + #[test] + fn test_format_parse_fail() { + assert_eq!( + FormatString::from_str("{s"), + Err(FormatParseError::UnmatchedBracket) + ); + } + + #[test] + fn test_format_parse_escape() { + let expected = Ok(FormatString { + format_parts: vec![ + FormatPart::Literal("{".to_string()), + FormatPart::KeywordSpec("key".to_string(), String::new()), + FormatPart::Literal("}ddfe".to_string()), + ], + }); + + assert_eq!(FormatString::from_str("{{{key}}}ddfe"), expected); + } +} diff --git a/vm/src/lib.rs b/vm/src/lib.rs index cfcaf94339..7c1df250ec 100644 --- a/vm/src/lib.rs +++ b/vm/src/lib.rs @@ -31,6 +31,7 @@ pub mod bytecode; pub mod compile; pub mod eval; mod exceptions; +pub mod format; mod frame; pub mod import; pub mod obj; diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index 1150f4c5a4..7d2470f6aa 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -1,3 +1,4 @@ +use super::super::format::FormatSpec; use super::super::pyobject::{ FromPyObjectRef, PyContext, PyFuncArgs, PyObject, PyObjectKind, PyObjectRef, PyResult, TypeProtocol, @@ -220,6 +221,24 @@ fn int_floordiv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } } +fn int_format(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [ + (i, Some(vm.ctx.int_type())), + (format_spec, Some(vm.ctx.str_type())) + ] + ); + let string_value = objstr::get_value(format_spec); + let format_spec = FormatSpec::parse(&string_value); + let int_value = get_value(i); + match format_spec.format_int(&int_value) { + Ok(string) => Ok(vm.ctx.new_str(string)), + Err(err) => Err(vm.new_value_error(err.to_string())), + } +} + fn int_sub(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, @@ -413,6 +432,7 @@ pub fn init(context: &PyContext) { context.set_attr(&int_type, "__pow__", context.new_rustfunc(int_pow)); context.set_attr(&int_type, "__repr__", context.new_rustfunc(int_repr)); context.set_attr(&int_type, "__sub__", context.new_rustfunc(int_sub)); + context.set_attr(&int_type, "__format__", context.new_rustfunc(int_format)); context.set_attr(&int_type, "__truediv__", context.new_rustfunc(int_truediv)); context.set_attr(&int_type, "__xor__", context.new_rustfunc(int_xor)); context.set_attr( diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 12e01608a9..76e5c65384 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -1,3 +1,4 @@ +use super::super::format::{FormatParseError, FormatPart, FormatString}; use super::super::pyobject::{ PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, PyResult, TypeProtocol, }; @@ -7,7 +8,6 @@ use super::objsequence::PySliceableSequence; use super::objtype; use num_bigint::ToBigInt; use num_traits::ToPrimitive; -use std::fmt; use std::hash::{Hash, Hasher}; pub fn init(context: &PyContext) { @@ -27,6 +27,7 @@ pub fn init(context: &PyContext) { context.set_attr(&str_type, "__new__", context.new_rustfunc(str_new)); context.set_attr(&str_type, "__str__", context.new_rustfunc(str_str)); context.set_attr(&str_type, "__repr__", context.new_rustfunc(str_repr)); + context.set_attr(&str_type, "format", context.new_rustfunc(str_format)); context.set_attr(&str_type, "lower", context.new_rustfunc(str_lower)); context.set_attr(&str_type, "upper", context.new_rustfunc(str_upper)); context.set_attr( @@ -147,6 +148,101 @@ fn str_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } } +fn str_format(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + if args.args.len() == 0 { + return Err( + vm.new_type_error("descriptor 'format' of 'str' object needs an argument".to_string()) + ); + } + + let zelf = &args.args[0]; + if !objtype::isinstance(&zelf, &vm.ctx.str_type()) { + let zelf_typ = zelf.typ(); + let actual_type = vm.to_pystr(&zelf_typ)?; + return Err(vm.new_type_error(format!( + "descriptor 'format' requires a 'str' object but received a '{}'", + actual_type + ))); + } + let format_string_text = get_value(zelf); + match FormatString::from_str(format_string_text.as_str()) { + Ok(format_string) => perform_format(vm, &format_string, &args), + Err(err) => match err { + FormatParseError::UnmatchedBracket => { + Err(vm.new_value_error("expected '}' before end of string".to_string())) + } + _ => Err(vm.new_value_error("Unexpected error parsing format string".to_string())), + }, + } +} + +fn call_object_format( + vm: &mut VirtualMachine, + argument: PyObjectRef, + format_spec: &String, +) -> PyResult { + let returned_type = vm.ctx.new_str(format_spec.clone()); + let result = vm.call_method(&argument, "__format__", vec![returned_type])?; + if !objtype::isinstance(&result, &vm.ctx.str_type()) { + let result_type = result.typ(); + let actual_type = vm.to_pystr(&result_type)?; + return Err(vm.new_type_error(format!("__format__ must return a str, not {}", actual_type))); + } + Ok(result) +} + +fn perform_format( + vm: &mut VirtualMachine, + format_string: &FormatString, + arguments: &PyFuncArgs, +) -> PyResult { + let mut final_string = String::new(); + if format_string.format_parts.iter().any(FormatPart::is_auto) + && format_string.format_parts.iter().any(FormatPart::is_index) + { + return Err(vm.new_value_error( + "cannot switch from automatic field numbering to manual field specification" + .to_string(), + )); + } + let mut auto_argument_index: usize = 1; + for part in &format_string.format_parts { + let result_string: String = match part { + FormatPart::AutoSpec(format_spec) => { + let result = match arguments.args.get(auto_argument_index) { + Some(argument) => call_object_format(vm, argument.clone(), &format_spec)?, + None => { + return Err(vm.new_index_error("tuple index out of range".to_string())); + } + }; + auto_argument_index += 1; + get_value(&result) + } + FormatPart::IndexSpec(index, format_spec) => { + let result = match arguments.args.get(*index + 1) { + Some(argument) => call_object_format(vm, argument.clone(), &format_spec)?, + None => { + return Err(vm.new_index_error("tuple index out of range".to_string())); + } + }; + get_value(&result) + } + FormatPart::KeywordSpec(keyword, format_spec) => { + let result = match arguments.get_optional_kwarg(&keyword) { + Some(argument) => call_object_format(vm, argument.clone(), &format_spec)?, + None => { + return Err(vm.new_key_error(format!("'{}'", keyword))); + } + }; + get_value(&result) + } + FormatPart::Literal(literal) => literal.clone(), + }; + final_string.push_str(&result_string); + } + Ok(vm.ctx.new_str(final_string)) +} + fn str_hash(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, Some(vm.ctx.str_type()))]); let value = get_value(zelf); diff --git a/vm/src/vm.rs b/vm/src/vm.rs index c6902d4770..c2c7b561e7 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -107,6 +107,16 @@ impl VirtualMachine { self.new_exception(value_error, msg) } + pub fn new_key_error(&mut self, msg: String) -> PyObjectRef { + let key_error = self.ctx.exceptions.key_error.clone(); + self.new_exception(key_error, msg) + } + + pub fn new_index_error(&mut self, msg: String) -> PyObjectRef { + let index_error = self.ctx.exceptions.index_error.clone(); + self.new_exception(index_error, msg) + } + pub fn new_not_implemented_error(&mut self, msg: String) -> PyObjectRef { let value_error = self.ctx.exceptions.not_implemented_error.clone(); self.new_exception(value_error, msg)