diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py index 3ad4218610..1ab7a1fb6d 100644 --- a/Lib/test/test_int.py +++ b/Lib/test/test_int.py @@ -712,8 +712,7 @@ def test_denial_of_service_prevented_int_to_str(self): self.assertIn('conversion', str(err.exception)) self.assertLess(sw_fail_extra_huge.seconds, sw_convert.seconds/2) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.skip('TODO: RUSTPYTHON; flaky test') def test_denial_of_service_prevented_str_to_int(self): """Regression test: ensure we fail before performing O(N**2) work.""" maxdigits = sys.get_int_max_str_digits() @@ -761,8 +760,6 @@ def test_power_of_two_bases_unlimited(self): assert maxdigits < 100_000 self.int_class('1' * 100_000, base) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_underscores_ignored(self): maxdigits = sys.get_int_max_str_digits() diff --git a/Lib/test/test_json/test_decode.py b/Lib/test/test_json/test_decode.py index 4d83f0ec9d..f07f7d5533 100644 --- a/Lib/test/test_json/test_decode.py +++ b/Lib/test/test_json/test_decode.py @@ -127,8 +127,6 @@ def test_negative_index(self): d = self.json.JSONDecoder() self.assertRaises(ValueError, d.raw_decode, 'a'*42, -50000) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_limit_int(self): maxdigits = 5000 with support.adjust_int_max_str_digits(maxdigits): @@ -144,3 +142,8 @@ class TestCDecode(TestDecode, CTest): @unittest.expectedFailure def test_keys_reuse(self): return super().test_keys_reuse() + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_limit_int(self): + return super().test_limit_int() diff --git a/common/src/int.rs b/common/src/int.rs index 9ec9e01498..ed09cc01a0 100644 --- a/common/src/int.rs +++ b/common/src/int.rs @@ -29,113 +29,119 @@ pub fn float_to_ratio(value: f64) -> Option<(BigInt, BigInt)> { }) } -pub fn bytes_to_int(lit: &[u8], mut base: u32) -> Option { +#[derive(Debug, Eq, PartialEq)] +pub enum BytesToIntError { + InvalidLiteral { base: u32 }, + InvalidBase, + DigitLimit { got: usize, limit: usize }, +} + +// https://github.com/python/cpython/blob/4e665351082c50018fb31d80db25b4693057393e/Objects/longobject.c#L2977 +// https://github.com/python/cpython/blob/4e665351082c50018fb31d80db25b4693057393e/Objects/longobject.c#L2884 +pub fn bytes_to_int( + buf: &[u8], + mut base: u32, + digit_limit: usize, +) -> Result { + if base != 0 && !(2..=36).contains(&base) { + return Err(BytesToIntError::InvalidBase); + } + + let mut buf = buf.trim_ascii(); + // split sign - let mut lit = lit.trim_ascii(); - let sign = match lit.first()? { - b'+' => Some(Sign::Plus), - b'-' => Some(Sign::Minus), + let sign = match buf.first() { + Some(b'+') => Some(Sign::Plus), + Some(b'-') => Some(Sign::Minus), + None => return Err(BytesToIntError::InvalidLiteral { base }), _ => None, }; + if sign.is_some() { - lit = &lit[1..]; + buf = &buf[1..]; } - // split radix - let first = *lit.first()?; - let has_radix = if first == b'0' { - match base { - 0 => { - if let Some(parsed) = lit.get(1).and_then(detect_base) { - base = parsed; - true - } else { - if let [_first, others @ .., last] = lit { - let is_zero = - others.iter().all(|&c| c == b'0' || c == b'_') && *last == b'0'; - if !is_zero { - return None; - } - } - return Some(BigInt::zero()); - } + let mut error_if_nonzero = false; + if base == 0 { + match (buf.first(), buf.get(1)) { + (Some(v), _) if *v != b'0' => base = 10, + (_, Some(b'x' | b'X')) => base = 16, + (_, Some(b'o' | b'O')) => base = 8, + (_, Some(b'b' | b'B')) => base = 2, + (_, _) => { + // "old" (C-style) octal literal, now invalid. it might still be zero though + base = 10; + error_if_nonzero = true; } - 16 => lit.get(1).is_some_and(|&b| matches!(b, b'x' | b'X')), - 2 => lit.get(1).is_some_and(|&b| matches!(b, b'b' | b'B')), - 8 => lit.get(1).is_some_and(|&b| matches!(b, b'o' | b'O')), - _ => false, - } - } else { - if base == 0 { - base = 10; - } - false - }; - if has_radix { - lit = &lit[2..]; - if lit.first()? == &b'_' { - lit = &lit[1..]; } } - // remove zeroes - let mut last = *lit.first()?; - if last == b'0' { - let mut count = 0; - for &cur in &lit[1..] { - if cur == b'_' { - if last == b'_' { - return None; - } - } else if cur != b'0' { - break; - }; - count += 1; - last = cur; - } - let prefix_last = lit[count]; - lit = &lit[count + 1..]; - if lit.is_empty() && prefix_last == b'_' { - return None; + if error_if_nonzero { + if let [_first, others @ .., last] = buf { + let is_zero = *last == b'0' && others.iter().all(|&c| c == b'0' || c == b'_'); + if !is_zero { + return Err(BytesToIntError::InvalidLiteral { base }); + } } + return Ok(BigInt::zero()); } - // validate - for c in lit { - let c = *c; - if !(c.is_ascii_alphanumeric() || c == b'_') { - return None; + if buf.first().is_some_and(|&v| v == b'0') + && buf.get(1).is_some_and(|&v| { + (base == 16 && (v == b'x' || v == b'X')) + || (base == 8 && (v == b'o' || v == b'O')) + || (base == 2 && (v == b'b' || v == b'B')) + }) + { + buf = &buf[2..]; + + // One underscore allowed here + if buf.first().is_some_and(|&v| v == b'_') { + buf = &buf[1..]; } + } - if c == b'_' && last == b'_' { - return None; + // Reject empty strings + let mut prev = *buf + .first() + .ok_or(BytesToIntError::InvalidLiteral { base })?; + + // Leading underscore not allowed + if prev == b'_' || !prev.is_ascii_alphanumeric() { + return Err(BytesToIntError::InvalidLiteral { base }); + } + + // Verify all characters are digits and underscores + let mut digits = 1; + for &cur in buf.iter().skip(1) { + if cur == b'_' { + // Double underscore not allowed + if prev == b'_' { + return Err(BytesToIntError::InvalidLiteral { base }); + } + } else if cur.is_ascii_alphanumeric() { + digits += 1; + } else { + return Err(BytesToIntError::InvalidLiteral { base }); } - last = c; + prev = cur; } - if last == b'_' { - return None; + + // Trailing underscore not allowed + if prev == b'_' { + return Err(BytesToIntError::InvalidLiteral { base }); } - // parse - let number = if lit.is_empty() { - BigInt::zero() - } else { - let uint = BigUint::parse_bytes(lit, base)?; - BigInt::from_biguint(sign.unwrap_or(Sign::Plus), uint) - }; - Some(number) -} + if digit_limit > 0 && !base.is_power_of_two() && digits > digit_limit { + return Err(BytesToIntError::DigitLimit { + got: digits, + limit: digit_limit, + }); + } -#[inline] -pub const fn detect_base(c: &u8) -> Option { - let base = match c { - b'x' | b'X' => 16, - b'b' | b'B' => 2, - b'o' | b'O' => 8, - _ => return None, - }; - Some(base) + let uint = BigUint::parse_bytes(buf, base).ok_or(BytesToIntError::InvalidLiteral { base })?; + Ok(BigInt::from_biguint(sign.unwrap_or(Sign::Plus), uint)) } // num-bigint now returns Some(inf) for to_f64() in some cases, so just keep that the same for now @@ -144,15 +150,59 @@ pub fn bigint_to_finite_float(int: &BigInt) -> Option { int.to_f64().filter(|f| f.is_finite()) } -#[test] -fn test_bytes_to_int() { - assert_eq!(bytes_to_int(&b"0b101"[..], 2).unwrap(), BigInt::from(5)); - assert_eq!(bytes_to_int(&b"0x_10"[..], 16).unwrap(), BigInt::from(16)); - assert_eq!(bytes_to_int(&b"0b"[..], 16).unwrap(), BigInt::from(11)); - assert_eq!(bytes_to_int(&b"+0b101"[..], 2).unwrap(), BigInt::from(5)); - assert_eq!(bytes_to_int(&b"0_0_0"[..], 10).unwrap(), BigInt::from(0)); - assert_eq!(bytes_to_int(&b"09_99"[..], 0), None); - assert_eq!(bytes_to_int(&b"000"[..], 0).unwrap(), BigInt::from(0)); - assert_eq!(bytes_to_int(&b"0_"[..], 0), None); - assert_eq!(bytes_to_int(&b"0_100"[..], 10).unwrap(), BigInt::from(100)); +#[cfg(test)] +mod tests { + use super::*; + + const DIGIT_LIMIT: usize = 4300; // Default of Cpython + + #[test] + fn bytes_to_int_valid() { + for ((buf, base), expected) in [ + (("0b101", 2), BigInt::from(5)), + (("0x_10", 16), BigInt::from(16)), + (("0b", 16), BigInt::from(11)), + (("+0b101", 2), BigInt::from(5)), + (("0_0_0", 10), BigInt::from(0)), + (("000", 0), BigInt::from(0)), + (("0_100", 10), BigInt::from(100)), + ] { + assert_eq!( + bytes_to_int(buf.as_bytes(), base, DIGIT_LIMIT), + Ok(expected) + ); + } + } + + #[test] + fn bytes_to_int_invalid_literal() { + for ((buf, base), expected) in [ + (("09_99", 0), BytesToIntError::InvalidLiteral { base: 10 }), + (("0_", 0), BytesToIntError::InvalidLiteral { base: 10 }), + (("0_", 2), BytesToIntError::InvalidLiteral { base: 2 }), + ] { + assert_eq!( + bytes_to_int(buf.as_bytes(), base, DIGIT_LIMIT), + Err(expected) + ) + } + } + + #[test] + fn bytes_to_int_invalid_base() { + for base in [1, 37] { + assert_eq!( + bytes_to_int("012345".as_bytes(), base, DIGIT_LIMIT), + Err(BytesToIntError::InvalidBase) + ) + } + } + + #[test] + fn bytes_to_int_digit_limit() { + assert_eq!( + bytes_to_int("012345".as_bytes(), 10, 5), + Err(BytesToIntError::DigitLimit { got: 6, limit: 5 }) + ); + } } diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index ebeb1638fd..f93d0f31f7 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -15,7 +15,7 @@ use crate::{ ArgByteOrder, ArgIntoBool, OptionalArg, OptionalOption, PyArithmeticValue, PyComparisonValue, }, - protocol::PyNumberMethods, + protocol::{PyNumberMethods, handle_bytes_to_int_err}, types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable}, }; use malachite_bigint::{BigInt, Sign}; @@ -829,33 +829,23 @@ struct IntToByteArgs { } fn try_int_radix(obj: &PyObject, base: u32, vm: &VirtualMachine) -> PyResult { - debug_assert!(base == 0 || (2..=36).contains(&base)); - - let opt = match_class!(match obj.to_owned() { + match_class!(match obj.to_owned() { string @ PyStr => { let s = string.as_wtf8().trim(); - bytes_to_int(s.as_bytes(), base) + bytes_to_int(s.as_bytes(), base, vm.state.int_max_str_digits.load()) + .map_err(|e| handle_bytes_to_int_err(e, obj, vm)) } bytes @ PyBytes => { - let bytes = bytes.as_bytes(); - bytes_to_int(bytes, base) + bytes_to_int(bytes.as_bytes(), base, vm.state.int_max_str_digits.load()) + .map_err(|e| handle_bytes_to_int_err(e, obj, vm)) } bytearray @ PyByteArray => { let inner = bytearray.borrow_buf(); - bytes_to_int(&inner, base) - } - _ => { - return Err(vm.new_type_error("int() can't convert non-string with explicit base")); + bytes_to_int(&inner, base, vm.state.int_max_str_digits.load()) + .map_err(|e| handle_bytes_to_int_err(e, obj, vm)) } - }); - match opt { - Some(int) => Ok(int), - None => Err(vm.new_value_error(format!( - "invalid literal for int() with base {}: {}", - base, - obj.repr(vm)?, - ))), - } + _ => Err(vm.new_type_error("int() can't convert non-string with explicit base")), + }) } // Retrieve inner int value: diff --git a/vm/src/protocol/mod.rs b/vm/src/protocol/mod.rs index 989cca73d8..d5c7e239a2 100644 --- a/vm/src/protocol/mod.rs +++ b/vm/src/protocol/mod.rs @@ -12,6 +12,6 @@ pub use iter::{PyIter, PyIterIter, PyIterReturn}; pub use mapping::{PyMapping, PyMappingMethods}; pub use number::{ PyNumber, PyNumberBinaryFunc, PyNumberBinaryOp, PyNumberMethods, PyNumberSlots, - PyNumberTernaryOp, PyNumberUnaryFunc, + PyNumberTernaryOp, PyNumberUnaryFunc, handle_bytes_to_int_err, }; pub use sequence::{PySequence, PySequenceMethods}; diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs index b103fdddd6..671877057d 100644 --- a/vm/src/protocol/number.rs +++ b/vm/src/protocol/number.rs @@ -5,8 +5,10 @@ use crossbeam_utils::atomic::AtomicCell; use crate::{ AsObject, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine, - builtins::{PyByteArray, PyBytes, PyComplex, PyFloat, PyInt, PyIntRef, PyStr, int}, - common::int::bytes_to_int, + builtins::{ + PyBaseExceptionRef, PyByteArray, PyBytes, PyComplex, PyFloat, PyInt, PyIntRef, PyStr, int, + }, + common::int::{BytesToIntError, bytes_to_int}, function::ArgBytesLike, object::{Traverse, TraverseFn}, stdlib::warnings, @@ -45,15 +47,10 @@ impl PyObject { pub fn try_int(&self, vm: &VirtualMachine) -> PyResult { fn try_convert(obj: &PyObject, lit: &[u8], vm: &VirtualMachine) -> PyResult { let base = 10; - let i = bytes_to_int(lit, base).ok_or_else(|| { - let repr = match obj.repr(vm) { - Ok(repr) => repr, - Err(err) => return err, - }; - vm.new_value_error(format!( - "invalid literal for int() with base {base}: {repr}", - )) - })?; + let digit_limit = vm.state.int_max_str_digits.load(); + + let i = bytes_to_int(lit, base, digit_limit) + .map_err(|e| handle_bytes_to_int_err(e, obj, vm))?; Ok(PyInt::from(i).into_ref(&vm.ctx)) } @@ -559,3 +556,25 @@ impl PyNumber<'_> { }) } } + +pub fn handle_bytes_to_int_err( + e: BytesToIntError, + obj: &PyObject, + vm: &VirtualMachine, +) -> PyBaseExceptionRef { + match e { + BytesToIntError::InvalidLiteral { base } => vm.new_value_error(format!( + "invalid literal for int() with base {base}: {}", + match obj.repr(vm) { + Ok(v) => v, + Err(err) => return err, + }, + )), + BytesToIntError::InvalidBase => { + vm.new_value_error("int() base must be >= 2 and <= 36, or 0") + } + BytesToIntError::DigitLimit { got, limit } => vm.new_value_error(format!( +"Exceeds the limit ({limit} digits) for integer string conversion: value has {got} digits; use sys.set_int_max_str_digits() to increase the limit" + )), + } +}