Skip to content

Fix int respect sys.set_int_max_str_digits value #6094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions Lib/test/test_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,8 +712,7 @@ def test_denial_of_service_prevented_int_to_str(self):
self.assertIn('conversion', str(err.exception))
self.assertLess(sw_fail_extra_huge.seconds, sw_convert.seconds/2)

# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.skip('TODO: RUSTPYTHON; flaky test')
def test_denial_of_service_prevented_str_to_int(self):
"""Regression test: ensure we fail before performing O(N**2) work."""
maxdigits = sys.get_int_max_str_digits()
Expand Down Expand Up @@ -761,8 +760,6 @@ def test_power_of_two_bases_unlimited(self):
assert maxdigits < 100_000
self.int_class('1' * 100_000, base)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_underscores_ignored(self):
maxdigits = sys.get_int_max_str_digits()

Expand Down
7 changes: 5 additions & 2 deletions Lib/test/test_json/test_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ def test_negative_index(self):
d = self.json.JSONDecoder()
self.assertRaises(ValueError, d.raw_decode, 'a'*42, -50000)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_limit_int(self):
maxdigits = 5000
with support.adjust_int_max_str_digits(maxdigits):
Expand All @@ -144,3 +142,8 @@ class TestCDecode(TestDecode, CTest):
@unittest.expectedFailure
def test_keys_reuse(self):
return super().test_keys_reuse()

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_limit_int(self):
return super().test_limit_int()
244 changes: 147 additions & 97 deletions common/src/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,113 +29,119 @@ pub fn float_to_ratio(value: f64) -> Option<(BigInt, BigInt)> {
})
}

pub fn bytes_to_int(lit: &[u8], mut base: u32) -> Option<BigInt> {
#[derive(Debug, Eq, PartialEq)]
pub enum BytesToIntError {
InvalidLiteral { base: u32 },
InvalidBase,
DigitLimit { got: usize, limit: usize },
}

// https://github.com/python/cpython/blob/4e665351082c50018fb31d80db25b4693057393e/Objects/longobject.c#L2977
// https://github.com/python/cpython/blob/4e665351082c50018fb31d80db25b4693057393e/Objects/longobject.c#L2884
pub fn bytes_to_int(
buf: &[u8],
mut base: u32,
digit_limit: usize,
) -> Result<BigInt, BytesToIntError> {
if base != 0 && !(2..=36).contains(&base) {
return Err(BytesToIntError::InvalidBase);
}

let mut buf = buf.trim_ascii();

// split sign
let mut lit = lit.trim_ascii();
let sign = match lit.first()? {
b'+' => Some(Sign::Plus),
b'-' => Some(Sign::Minus),
let sign = match buf.first() {
Some(b'+') => Some(Sign::Plus),
Some(b'-') => Some(Sign::Minus),
None => return Err(BytesToIntError::InvalidLiteral { base }),
_ => None,
};

if sign.is_some() {
lit = &lit[1..];
buf = &buf[1..];
}

// split radix
let first = *lit.first()?;
let has_radix = if first == b'0' {
match base {
0 => {
if let Some(parsed) = lit.get(1).and_then(detect_base) {
base = parsed;
true
} else {
if let [_first, others @ .., last] = lit {
let is_zero =
others.iter().all(|&c| c == b'0' || c == b'_') && *last == b'0';
if !is_zero {
return None;
}
}
return Some(BigInt::zero());
}
let mut error_if_nonzero = false;
if base == 0 {
match (buf.first(), buf.get(1)) {
(Some(v), _) if *v != b'0' => base = 10,
(_, Some(b'x' | b'X')) => base = 16,
(_, Some(b'o' | b'O')) => base = 8,
(_, Some(b'b' | b'B')) => base = 2,
(_, _) => {
// "old" (C-style) octal literal, now invalid. it might still be zero though
base = 10;
error_if_nonzero = true;
}
16 => lit.get(1).is_some_and(|&b| matches!(b, b'x' | b'X')),
2 => lit.get(1).is_some_and(|&b| matches!(b, b'b' | b'B')),
8 => lit.get(1).is_some_and(|&b| matches!(b, b'o' | b'O')),
_ => false,
}
} else {
if base == 0 {
base = 10;
}
false
};
if has_radix {
lit = &lit[2..];
if lit.first()? == &b'_' {
lit = &lit[1..];
}
}

// remove zeroes
let mut last = *lit.first()?;
if last == b'0' {
let mut count = 0;
for &cur in &lit[1..] {
if cur == b'_' {
if last == b'_' {
return None;
}
} else if cur != b'0' {
break;
};
count += 1;
last = cur;
}
let prefix_last = lit[count];
lit = &lit[count + 1..];
if lit.is_empty() && prefix_last == b'_' {
return None;
if error_if_nonzero {
if let [_first, others @ .., last] = buf {
let is_zero = *last == b'0' && others.iter().all(|&c| c == b'0' || c == b'_');
if !is_zero {
return Err(BytesToIntError::InvalidLiteral { base });
}
}
return Ok(BigInt::zero());
}

// validate
for c in lit {
let c = *c;
if !(c.is_ascii_alphanumeric() || c == b'_') {
return None;
if buf.first().is_some_and(|&v| v == b'0')
&& buf.get(1).is_some_and(|&v| {
(base == 16 && (v == b'x' || v == b'X'))
|| (base == 8 && (v == b'o' || v == b'O'))
|| (base == 2 && (v == b'b' || v == b'B'))
})
{
buf = &buf[2..];

// One underscore allowed here
if buf.first().is_some_and(|&v| v == b'_') {
buf = &buf[1..];
}
}

if c == b'_' && last == b'_' {
return None;
// Reject empty strings
let mut prev = *buf
.first()
.ok_or(BytesToIntError::InvalidLiteral { base })?;

// Leading underscore not allowed
if prev == b'_' || !prev.is_ascii_alphanumeric() {
return Err(BytesToIntError::InvalidLiteral { base });
}

// Verify all characters are digits and underscores
let mut digits = 1;
for &cur in buf.iter().skip(1) {
if cur == b'_' {
// Double underscore not allowed
if prev == b'_' {
return Err(BytesToIntError::InvalidLiteral { base });
}
} else if cur.is_ascii_alphanumeric() {
digits += 1;
} else {
return Err(BytesToIntError::InvalidLiteral { base });
}

last = c;
prev = cur;
}
if last == b'_' {
return None;

// Trailing underscore not allowed
if prev == b'_' {
return Err(BytesToIntError::InvalidLiteral { base });
}

// parse
let number = if lit.is_empty() {
BigInt::zero()
} else {
let uint = BigUint::parse_bytes(lit, base)?;
BigInt::from_biguint(sign.unwrap_or(Sign::Plus), uint)
};
Some(number)
}
if digit_limit > 0 && !base.is_power_of_two() && digits > digit_limit {
return Err(BytesToIntError::DigitLimit {
got: digits,
limit: digit_limit,
});
}

#[inline]
pub const fn detect_base(c: &u8) -> Option<u32> {
let base = match c {
b'x' | b'X' => 16,
b'b' | b'B' => 2,
b'o' | b'O' => 8,
_ => return None,
};
Some(base)
let uint = BigUint::parse_bytes(buf, base).ok_or(BytesToIntError::InvalidLiteral { base })?;
Ok(BigInt::from_biguint(sign.unwrap_or(Sign::Plus), uint))
}

// num-bigint now returns Some(inf) for to_f64() in some cases, so just keep that the same for now
Expand All @@ -144,15 +150,59 @@ pub fn bigint_to_finite_float(int: &BigInt) -> Option<f64> {
int.to_f64().filter(|f| f.is_finite())
}

#[test]
fn test_bytes_to_int() {
assert_eq!(bytes_to_int(&b"0b101"[..], 2).unwrap(), BigInt::from(5));
assert_eq!(bytes_to_int(&b"0x_10"[..], 16).unwrap(), BigInt::from(16));
assert_eq!(bytes_to_int(&b"0b"[..], 16).unwrap(), BigInt::from(11));
assert_eq!(bytes_to_int(&b"+0b101"[..], 2).unwrap(), BigInt::from(5));
assert_eq!(bytes_to_int(&b"0_0_0"[..], 10).unwrap(), BigInt::from(0));
assert_eq!(bytes_to_int(&b"09_99"[..], 0), None);
assert_eq!(bytes_to_int(&b"000"[..], 0).unwrap(), BigInt::from(0));
assert_eq!(bytes_to_int(&b"0_"[..], 0), None);
assert_eq!(bytes_to_int(&b"0_100"[..], 10).unwrap(), BigInt::from(100));
#[cfg(test)]
mod tests {
use super::*;

const DIGIT_LIMIT: usize = 4300; // Default of Cpython

#[test]
fn bytes_to_int_valid() {
for ((buf, base), expected) in [
(("0b101", 2), BigInt::from(5)),
(("0x_10", 16), BigInt::from(16)),
(("0b", 16), BigInt::from(11)),
(("+0b101", 2), BigInt::from(5)),
(("0_0_0", 10), BigInt::from(0)),
(("000", 0), BigInt::from(0)),
(("0_100", 10), BigInt::from(100)),
] {
assert_eq!(
bytes_to_int(buf.as_bytes(), base, DIGIT_LIMIT),
Ok(expected)
);
}
}

#[test]
fn bytes_to_int_invalid_literal() {
for ((buf, base), expected) in [
(("09_99", 0), BytesToIntError::InvalidLiteral { base: 10 }),
(("0_", 0), BytesToIntError::InvalidLiteral { base: 10 }),
(("0_", 2), BytesToIntError::InvalidLiteral { base: 2 }),
] {
assert_eq!(
bytes_to_int(buf.as_bytes(), base, DIGIT_LIMIT),
Err(expected)
)
}
}

#[test]
fn bytes_to_int_invalid_base() {
for base in [1, 37] {
assert_eq!(
bytes_to_int("012345".as_bytes(), base, DIGIT_LIMIT),
Err(BytesToIntError::InvalidBase)
)
}
}

#[test]
fn bytes_to_int_digit_limit() {
assert_eq!(
bytes_to_int("012345".as_bytes(), 10, 5),
Err(BytesToIntError::DigitLimit { got: 6, limit: 5 })
);
}
}
30 changes: 10 additions & 20 deletions vm/src/builtins/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
ArgByteOrder, ArgIntoBool, OptionalArg, OptionalOption, PyArithmeticValue,
PyComparisonValue,
},
protocol::PyNumberMethods,
protocol::{PyNumberMethods, handle_bytes_to_int_err},
types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable},
};
use malachite_bigint::{BigInt, Sign};
Expand Down Expand Up @@ -829,33 +829,23 @@ struct IntToByteArgs {
}

fn try_int_radix(obj: &PyObject, base: u32, vm: &VirtualMachine) -> PyResult<BigInt> {
debug_assert!(base == 0 || (2..=36).contains(&base));

let opt = match_class!(match obj.to_owned() {
match_class!(match obj.to_owned() {
string @ PyStr => {
let s = string.as_wtf8().trim();
bytes_to_int(s.as_bytes(), base)
bytes_to_int(s.as_bytes(), base, vm.state.int_max_str_digits.load())
.map_err(|e| handle_bytes_to_int_err(e, obj, vm))
}
bytes @ PyBytes => {
let bytes = bytes.as_bytes();
bytes_to_int(bytes, base)
bytes_to_int(bytes.as_bytes(), base, vm.state.int_max_str_digits.load())
.map_err(|e| handle_bytes_to_int_err(e, obj, vm))
}
bytearray @ PyByteArray => {
let inner = bytearray.borrow_buf();
bytes_to_int(&inner, base)
}
_ => {
return Err(vm.new_type_error("int() can't convert non-string with explicit base"));
bytes_to_int(&inner, base, vm.state.int_max_str_digits.load())
.map_err(|e| handle_bytes_to_int_err(e, obj, vm))
}
});
match opt {
Some(int) => Ok(int),
None => Err(vm.new_value_error(format!(
"invalid literal for int() with base {}: {}",
base,
obj.repr(vm)?,
))),
}
_ => Err(vm.new_type_error("int() can't convert non-string with explicit base")),
})
}

// Retrieve inner int value:
Expand Down
2 changes: 1 addition & 1 deletion vm/src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ pub use iter::{PyIter, PyIterIter, PyIterReturn};
pub use mapping::{PyMapping, PyMappingMethods};
pub use number::{
PyNumber, PyNumberBinaryFunc, PyNumberBinaryOp, PyNumberMethods, PyNumberSlots,
PyNumberTernaryOp, PyNumberUnaryFunc,
PyNumberTernaryOp, PyNumberUnaryFunc, handle_bytes_to_int_err,
};
pub use sequence::{PySequence, PySequenceMethods};
Loading
Loading