From 3f3fa2e5af79cac17b4b8b55b34546299600c2da Mon Sep 17 00:00:00 2001 From: Marcin Pajkowski Date: Fri, 16 Aug 2019 21:36:05 +0200 Subject: [PATCH 1/4] Fix panics with int() + add automatic radix detection based on given literal if optional arg base is set to 0 (CPython behavior) + tolerate underscore separators --- vm/src/obj/objint.rs | 73 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 60 insertions(+), 13 deletions(-) diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index 478671257c..a651d4b914 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -1,8 +1,9 @@ use std::fmt; +use std::str; use num_bigint::{BigInt, Sign}; use num_integer::Integer; -use num_traits::{One, Pow, Signed, ToPrimitive, Zero}; +use num_traits::{Num, One, Pow, Signed, ToPrimitive, Zero}; use crate::format::FormatSpec; use crate::function::{KwArgs, OptionalArg, PyFuncArgs}; @@ -713,7 +714,9 @@ impl IntOptions { fn get_int_value(self, vm: &VirtualMachine) -> PyResult { if let OptionalArg::Present(val) = self.val_options { let base = if let OptionalArg::Present(base) = self.base { - if !objtype::isinstance(&val, &vm.ctx.str_type()) { + if !(objtype::isinstance(&val, &vm.ctx.str_type()) + || objtype::isinstance(&val, &vm.ctx.bytes_type())) + { return Err(vm.new_type_error( "int() can't convert non-string with explicit base".to_string(), )); @@ -736,21 +739,22 @@ fn int_new(cls: PyClassRef, options: IntOptions, vm: &VirtualMachine) -> PyResul } // Casting function: -pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, mut base: u32) -> PyResult<BigInt> { - if base == 0 { - base = 10 - } else if base < 2 || base > 36 { +pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, base: u32) -> PyResult<BigInt> { + if base != 0 && (base < 2 || base > 36) { return Err(vm.new_value_error("int() base must be >= 2 and <= 36, or 0".to_string())); } match_class!(obj.clone(), - s @ PyString => { - i32::from_str_radix(s.as_str().trim(), base) - .map(BigInt::from) -
.map_err(|_|vm.new_value_error(format!( - "invalid literal for int() with base {}: '{}'", - base, s - ))) + string @ PyString => { + let s = string.value.as_str().trim(); + str_to_int(vm, s, base) + }, + bytes @ PyBytes => { + let bytes = bytes.get_value(); + let s = std::str::from_utf8(bytes) + .map(|s| s.trim()) + .map_err(|_| invalid_literal(vm, &bytes.iter().map(|&c| c as char).collect::<String>(), base))?; + str_to_int(vm, s, base) }, obj => { let method = vm.get_method_or_type_error(obj.clone(), "__int__", || { @@ -766,6 +770,49 @@ pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, mut base: u32) -> PyResult ) } +fn str_to_int(vm: &VirtualMachine, literal: &str, mut base: u32) -> PyResult<BigInt> { + let mut buf = literal.chars().filter(|&c| c != '_').collect::<String>(); + + let is_signed = buf.starts_with('+') || buf.starts_with('-'); + let radix_range = if is_signed { 1..3 } else { 0..2 }; + let radix_candidate = buf.get(radix_range.clone()); + + // try to find base + if let Some(radix_candidate) = radix_candidate { + if let Some(matched_radix) = detect_base(&radix_candidate) { + if base != 0 && base != matched_radix { + return Err(invalid_literal(vm, literal, base)); + } else { + base = matched_radix; + } + buf.drain(radix_range); + } + } + + // base still not found, use default + if base == 0 { + base = 10; + } + + BigInt::from_str_radix(&buf, base).map_err(|_err| invalid_literal(vm, literal, base)) +} + +fn detect_base(literal: &str) -> Option<u32> { + match literal { + "0x" | "0X" => Some(16), + "0o" | "0O" => Some(8), + "0b" | "0B" => Some(2), + _ => None, + } +} + +fn invalid_literal(vm: &VirtualMachine, literal: &str, base: u32) -> PyObjectRef { + vm.new_value_error(format!( + "invalid literal for int() with base {}: '{}'", + base, literal + )) +} + // Retrieve inner int value: pub fn get_value(obj: &PyObjectRef) -> &BigInt { &get_py_int(obj).value From d07edee3f696fb89f480f9af9ae8dc800297c5c0 Mon Sep 17 00:00:00 2001 From: Marcin Pajkowski Date: Fri, 16 Aug 2019 21:39:50
+0200 Subject: [PATCH 2/4] Add more tests for int() --- tests/snippets/ints.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/tests/snippets/ints.py b/tests/snippets/ints.py index 0e89d44f32..ab3d646a26 100644 --- a/tests/snippets/ints.py +++ b/tests/snippets/ints.py @@ -97,14 +97,40 @@ assert -10 // -4 == 2 assert int() == 0 -assert int("101", 2) == 5 -assert int("101", base=2) == 5 assert int(1) == 1 + +# implied base +assert int('1', base=0) == 1 +assert int('123', base=0) == 123 +assert int('0b101', base=0) == 5 +assert int('0B101', base=0) == 5 +assert int('0o100', base=0) == 64 +assert int('0O100', base=0) == 64 +assert int('0xFF', base=0) == 255 +assert int('0XFF', base=0) == 255 +with assertRaises(ValueError): + int('0xFF', base=10) +with assertRaises(ValueError): + int('0oFF', base=10) +with assertRaises(ValueError): + int('0bFF', base=10) + +# underscore +assert int('0xFF_FF_FF', base=16) == 16_777_215 + +# signed +assert int('-123') == -123 +assert int('+0b101', base=2) == +5 + +# trailing spaces assert int(' 1') == 1 assert int('1 ') == 1 assert int(' 1 ') == 1 assert int('10', base=0) == 10 +# type byte, signed, implied base +assert int(b' -0XFF ', base=0) == -255 + assert int.from_bytes(b'\x00\x10', 'big') == 16 assert int.from_bytes(b'\x00\x10', 'little') == 4096 assert int.from_bytes(b'\xfc\x00', 'big', signed=True) == -1024 @@ -179,4 +205,4 @@ def __int__(self): assert_raises(TypeError, lambda: (0).__round__(None)) assert_raises(TypeError, lambda: (1).__round__(None)) assert_raises(TypeError, lambda: (0).__round__(0.0)) -assert_raises(TypeError, lambda: (1).__round__(0.0)) \ No newline at end of file +assert_raises(TypeError, lambda: (1).__round__(0.0)) From 5e37adcf681192b6bd4a4f73a0d1d11a51d945b1 Mon Sep 17 00:00:00 2001 From: Marcin Pajkowski Date: Sat, 17 Aug 2019 16:19:02 +0200 Subject: [PATCH 3/4] Apply review comments --- tests/snippets/ints.py | 6 ++++++ vm/src/obj/objint.rs | 2 +- 2 
files changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/snippets/ints.py b/tests/snippets/ints.py index ab3d646a26..23a1c9a923 100644 --- a/tests/snippets/ints.py +++ b/tests/snippets/ints.py @@ -98,6 +98,8 @@ assert int() == 0 assert int(1) == 1 +assert int("101", 2) == 5 +assert int("101", base=2) == 5 # implied base assert int('1', base=0) == 1 @@ -114,6 +116,10 @@ int('0oFF', base=10) with assertRaises(ValueError): int('0bFF', base=10) +with assertRaises(ValueError): + int(b"F\xc3\xb8\xc3\xb6\xbbB\xc3\xa5r") +with assertRaises(ValueError): + int(b"F\xc3\xb8\xc3\xb6\xbbB\xc3\xa5r") # underscore assert int('0xFF_FF_FF', base=16) == 16_777_215 diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index a651d4b914..02d1ab1955 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -753,7 +753,7 @@ pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, base: u32) -> PyResult<Big let bytes = bytes.get_value(); let s = std::str::from_utf8(bytes) .map(|s| s.trim()) - .map_err(|_| invalid_literal(vm, &bytes.iter().map(|&c| c as char).collect::<String>(), base))?; + .map_err(|e| vm.new_value_error(format!("utf8 decode error: {}", e)))?; str_to_int(vm, s, base) }, obj => { From 6d618c5e5a07833d8ff1ff91425666a2d24417e0 Mon Sep 17 00:00:00 2001 From: Marcin Pajkowski Date: Mon, 19 Aug 2019 23:47:10 +0200 Subject: [PATCH 4/4] Add changes suggested by @seeturtle + Add tests covering those changes --- tests/snippets/ints.py | 11 +++++++++++ vm/src/obj/objint.rs | 33 ++++++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/tests/snippets/ints.py b/tests/snippets/ints.py index 23a1c9a923..4bb1a51272 100644 --- a/tests/snippets/ints.py +++ b/tests/snippets/ints.py @@ -116,6 +116,8 @@ int('0oFF', base=10) with assertRaises(ValueError): int('0bFF', base=10) +with assertRaises(ValueError): + int('0bFF', base=10) with assertRaises(ValueError): int(b"F\xc3\xb8\xc3\xb6\xbbB\xc3\xa5r") with assertRaises(ValueError): @@ -123,6 +125,14 @@ # underscore assert int('0xFF_FF_FF', base=16) == 16_777_215 +with assertRaises(ValueError): + int("_123_") +with assertRaises(ValueError): + int("123_") +with
assertRaises(ValueError): + int("_123") +with assertRaises(ValueError): + int("1__23") # signed assert int('-123') == -123 @@ -137,6 +147,7 @@ # type byte, signed, implied base assert int(b' -0XFF ', base=0) == -255 + assert int.from_bytes(b'\x00\x10', 'big') == 16 assert int.from_bytes(b'\x00\x10', 'little') == 4096 assert int.from_bytes(b'\xfc\x00', 'big', signed=True) == -1024 diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index 02d1ab1955..d8bcd380f0 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -771,8 +771,7 @@ pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, base: u32) -> PyResult<Big } fn str_to_int(vm: &VirtualMachine, literal: &str, mut base: u32) -> PyResult<BigInt> { - let mut buf = literal.chars().filter(|&c| c != '_').collect::<String>(); - + let mut buf = validate_literal(vm, literal, base)?; let is_signed = buf.starts_with('+') || buf.starts_with('-'); let radix_range = if is_signed { 1..3 } else { 0..2 }; let radix_candidate = buf.get(radix_range.clone()); @@ -785,18 +784,46 @@ fn str_to_int(vm: &VirtualMachine, literal: &str, mut base: u32) -> PyResult PyResult { + if literal.starts_with('_') || literal.ends_with('_') { + return Err(invalid_literal(vm, literal, base)); + } + + let mut buf = String::with_capacity(literal.len()); + let mut last_tok = None; + for c in literal.chars() { + if !(c.is_ascii_alphanumeric() || c == '_' || c == '+' || c == '-') { + return Err(invalid_literal(vm, literal, base)); + } + + if c == '_' && Some(c) == last_tok { + return Err(invalid_literal(vm, literal, base)); + } + + last_tok = Some(c); + buf.push(c); + } + + Ok(buf) +} + fn detect_base(literal: &str) -> Option<u32> { match literal { "0x" | "0X" => Some(16),