Skip to content

Fix panics with int() #1290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions tests/snippets/ints.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,57 @@
assert -10 // -4 == 2

assert int() == 0
assert int(1) == 1
assert int("101", 2) == 5
assert int("101", base=2) == 5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its good to keep this check with an explicit base given.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I'll fix it in a moment

assert int(1) == 1

# implied base
assert int('1', base=0) == 1
assert int('123', base=0) == 123
assert int('0b101', base=0) == 5
assert int('0B101', base=0) == 5
assert int('0o100', base=0) == 64
assert int('0O100', base=0) == 64
assert int('0xFF', base=0) == 255
assert int('0XFF', base=0) == 255
with assertRaises(ValueError):
int('0xFF', base=10)
with assertRaises(ValueError):
int('0oFF', base=10)
with assertRaises(ValueError):
int('0bFF', base=10)
with assertRaises(ValueError):
int('0bFF', base=10)
with assertRaises(ValueError):
int(b"F\xc3\xb8\xc3\xb6\xbbB\xc3\xa5r")
with assertRaises(ValueError):
int(b"F\xc3\xb8\xc3\xb6\xbbB\xc3\xa5r")

# underscore
assert int('0xFF_FF_FF', base=16) == 16_777_215
with assertRaises(ValueError):
int("_123_")
with assertRaises(ValueError):
int("123_")
with assertRaises(ValueError):
int("_123")
with assertRaises(ValueError):
int("1__23")

# signed
assert int('-123') == -123
assert int('+0b101', base=2) == +5

# trailing spaces
assert int(' 1') == 1
assert int('1 ') == 1
assert int(' 1 ') == 1
assert int('10', base=0) == 10

# type byte, signed, implied base
assert int(b' -0XFF ', base=0) == -255


assert int.from_bytes(b'\x00\x10', 'big') == 16
assert int.from_bytes(b'\x00\x10', 'little') == 4096
assert int.from_bytes(b'\xfc\x00', 'big', signed=True) == -1024
Expand Down Expand Up @@ -179,4 +222,4 @@ def __int__(self):
assert_raises(TypeError, lambda: (0).__round__(None))
assert_raises(TypeError, lambda: (1).__round__(None))
assert_raises(TypeError, lambda: (0).__round__(0.0))
assert_raises(TypeError, lambda: (1).__round__(0.0))
assert_raises(TypeError, lambda: (1).__round__(0.0))
100 changes: 87 additions & 13 deletions vm/src/obj/objint.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::fmt;
use std::str;

use num_bigint::{BigInt, Sign};
use num_integer::Integer;
use num_traits::{One, Pow, Signed, ToPrimitive, Zero};
use num_traits::{Num, One, Pow, Signed, ToPrimitive, Zero};

use crate::format::FormatSpec;
use crate::function::{KwArgs, OptionalArg, PyFuncArgs};
Expand Down Expand Up @@ -713,7 +714,9 @@ impl IntOptions {
fn get_int_value(self, vm: &VirtualMachine) -> PyResult<BigInt> {
if let OptionalArg::Present(val) = self.val_options {
let base = if let OptionalArg::Present(base) = self.base {
if !objtype::isinstance(&val, &vm.ctx.str_type()) {
if !(objtype::isinstance(&val, &vm.ctx.str_type())
|| objtype::isinstance(&val, &vm.ctx.bytes_type()))
{
return Err(vm.new_type_error(
"int() can't convert non-string with explicit base".to_string(),
));
Expand All @@ -736,21 +739,22 @@ fn int_new(cls: PyClassRef, options: IntOptions, vm: &VirtualMachine) -> PyResul
}

// Casting function:
pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, mut base: u32) -> PyResult<BigInt> {
if base == 0 {
base = 10
} else if base < 2 || base > 36 {
pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, base: u32) -> PyResult<BigInt> {
if base != 0 && (base < 2 || base > 36) {
return Err(vm.new_value_error("int() base must be >= 2 and <= 36, or 0".to_string()));
}

match_class!(obj.clone(),
s @ PyString => {
i32::from_str_radix(s.as_str().trim(), base)
.map(BigInt::from)
.map_err(|_|vm.new_value_error(format!(
"invalid literal for int() with base {}: '{}'",
base, s
)))
string @ PyString => {
let s = string.value.as_str().trim();
str_to_int(vm, s, base)
},
bytes @ PyBytes => {
let bytes = bytes.get_value();
let s = std::str::from_utf8(bytes)
.map(|s| s.trim())
.map_err(|e| vm.new_value_error(format!("utf8 decode error: {}", e)))?;
str_to_int(vm, s, base)
},
obj => {
let method = vm.get_method_or_type_error(obj.clone(), "__int__", || {
Expand All @@ -766,6 +770,76 @@ pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, mut base: u32) -> PyResult
)
}

fn str_to_int(vm: &VirtualMachine, literal: &str, mut base: u32) -> PyResult<BigInt> {
let mut buf = validate_literal(vm, literal, base)?;
let is_signed = buf.starts_with('+') || buf.starts_with('-');
let radix_range = if is_signed { 1..3 } else { 0..2 };
let radix_candidate = buf.get(radix_range.clone());

// try to find base
if let Some(radix_candidate) = radix_candidate {
if let Some(matched_radix) = detect_base(&radix_candidate) {
if base != 0 && base != matched_radix {
return Err(invalid_literal(vm, literal, base));
} else {
base = matched_radix;
}

buf.drain(radix_range);
}
}

// base still not found, try to use default
if base == 0 {
if buf.starts_with('0') {
return Err(invalid_literal(vm, literal, base));
}

base = 10;
}

BigInt::from_str_radix(&buf, base).map_err(|_err| invalid_literal(vm, literal, base))
}

fn validate_literal(vm: &VirtualMachine, literal: &str, base: u32) -> PyResult<String> {
if literal.starts_with('_') || literal.ends_with('_') {
return Err(invalid_literal(vm, literal, base));
}

let mut buf = String::with_capacity(literal.len());
let mut last_tok = None;
for c in literal.chars() {
if !(c.is_ascii_alphanumeric() || c == '_' || c == '+' || c == '-') {
return Err(invalid_literal(vm, literal, base));
}

if c == '_' && Some(c) == last_tok {
return Err(invalid_literal(vm, literal, base));
}

last_tok = Some(c);
buf.push(c);
}

Ok(buf)
}

fn detect_base(literal: &str) -> Option<u32> {
match literal {
"0x" | "0X" => Some(16),
"0o" | "0O" => Some(8),
"0b" | "0B" => Some(2),
_ => None,
}
}

fn invalid_literal(vm: &VirtualMachine, literal: &str, base: u32) -> PyObjectRef {
vm.new_value_error(format!(
"invalid literal for int() with base {}: '{}'",
base, literal
))
}

// Retrieve inner int value:
pub fn get_value(obj: &PyObjectRef) -> &BigInt {
&get_py_int(obj).value
Expand Down