diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py index 4879b28c09..d3a0452f4f 100644 --- a/Lib/test/test_int.py +++ b/Lib/test/test_int.py @@ -2,8 +2,8 @@ import unittest from test import support -# from test.test_grammar import (VALID_UNDERSCORE_LITERALS, -# INVALID_UNDERSCORE_LITERALS) +from test.test_grammar import (VALID_UNDERSCORE_LITERALS, + INVALID_UNDERSCORE_LITERALS) L = [ ('0', 0), @@ -31,7 +31,6 @@ class IntSubclass(int): class IntTestCases(unittest.TestCase): - @unittest.skip("TODO: RUSTPYTHON") def test_basic(self): self.assertEqual(int(314), 314) self.assertEqual(int(3.14), 3) @@ -215,7 +214,6 @@ def test_basic(self): self.assertEqual(int('2br45qc', 35), 4294967297) self.assertEqual(int('1z141z5', 36), 4294967297) - @unittest.skip("TODO: RUSTPYTHON") def test_underscores(self): for lit in VALID_UNDERSCORE_LITERALS: if any(ch in lit for ch in '.eEjJ'): @@ -481,8 +479,6 @@ def __trunc__(self): self.assertEqual(n, 1) self.assertIs(type(n), IntSubclass) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_error_message(self): def check(s, base=None): with self.assertRaises(ValueError, diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index a8aca1bbdc..746df5c7b4 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -13,7 +13,6 @@ use super::objfloat; use super::objmemory::PyMemoryView; use super::objstr::{PyString, PyStringRef}; use super::objtype::{self, PyClassRef}; -use crate::exceptions::PyBaseExceptionRef; use crate::format::FormatSpec; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyhash; @@ -724,26 +723,21 @@ struct IntToByteArgs { // Casting function: pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, base: &BigInt) -> PyResult { - let base_u32 = match base.to_u32() { - Some(base_u32) => base_u32, - None => { - return Err(vm.new_value_error("int() base must be >= 2 and <= 36, or 0".to_owned())) - } + let base = match base.to_u32() { + Some(base) if base == 0 || (2..=36).contains(&base) => base, + _ => return Err(vm.new_value_error("int() base must be >= 2 and <= 36, or 0".to_owned())), }; - if base_u32 != 0 && (base_u32 < 2 || base_u32 > 36) { - return Err(vm.new_value_error("int() base must be >= 2 and <= 36, or 0".to_owned())); - } let bytes_to_int = |bytes: &[u8]| { - let s = std::str::from_utf8(bytes) - .map_err(|e| vm.new_value_error(format!("utf8 decode error: {}", e)))?; - str_to_int(vm, s, base) + std::str::from_utf8(bytes) + .ok() + .and_then(|s| str_to_int(s, base)) }; - match_class!(match obj.clone() { + let opt = match_class!(match obj.clone() { string @ PyString => { let s = string.as_str(); - str_to_int(vm, &s, base) + str_to_int(&s, base) } bytes @ PyBytes => { let bytes = bytes.get_value(); @@ -770,36 +764,39 @@ pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, base: &BigInt) -> PyResult ) })?; let result = vm.invoke(&method, PyFuncArgs::default())?; - match result.payload::() { + return match result.payload::() { Some(int_obj) => Ok(int_obj.as_bigint().clone()), None => Err(vm.new_type_error(format!( "TypeError: __int__ returned non-int (type '{}')", result.class().name ))), - } + }; } - }) + }); + match opt { + Some(int) => Ok(int), + None => Err(vm.new_value_error(format!( + "invalid literal for int() with base {}: {}", + base, + vm.to_repr(obj)?, + ))), + } } -fn str_to_int(vm: &VirtualMachine, literal: &str, base: &BigInt) -> PyResult { - let mut buf = validate_literal(vm, literal, base)?; +fn str_to_int(literal: &str, mut base: u32) -> Option { + let mut buf = validate_literal(literal)?.to_owned(); let is_signed = buf.starts_with('+') || buf.starts_with('-'); let radix_range = if is_signed { 1..3 } else { 0..2 }; let radix_candidate = buf.get(radix_range.clone()); - let mut base_u32 = match base.to_u32() { - Some(base_u32) => base_u32, - None => return Err(invalid_literal(vm, literal, base)), - }; - // try to find base if let Some(radix_candidate) = radix_candidate { if let Some(matched_radix) = detect_base(&radix_candidate) { - if base_u32 == 0 || base_u32 == matched_radix { + if base == 0 || base == matched_radix { /* If base is 0 or equal radix number, it means radix is validate * So change base to radix number and remove radix from literal */ - base_u32 = matched_radix; + base = matched_radix; buf.drain(radix_range); /* first underscore with radix is validate @@ -808,49 +805,50 @@ fn str_to_int(vm: &VirtualMachine, literal: &str, base: &BigInt) -> PyResult PyResult { +fn validate_literal(literal: &str) -> Option<&str> { let trimmed = literal.trim(); if trimmed.starts_with('_') || trimmed.ends_with('_') { - return Err(invalid_literal(vm, literal, base)); + return None; } - let mut buf = String::with_capacity(trimmed.len()); let mut last_tok = None; for c in trimmed.chars() { if !(c.is_ascii_alphanumeric() || c == '_' || c == '+' || c == '-') { - return Err(invalid_literal(vm, literal, base)); + return None; } if c == '_' && Some(c) == last_tok { - return Err(invalid_literal(vm, literal, base)); + return None; } last_tok = Some(c); - buf.push(c); } - Ok(buf) + Some(trimmed) } fn detect_base(literal: &str) -> Option { @@ -862,13 +860,6 @@ fn detect_base(literal: &str) -> Option { } } -fn invalid_literal(vm: &VirtualMachine, literal: &str, base: &BigInt) -> PyBaseExceptionRef { - vm.new_value_error(format!( - "invalid literal for int() with base {}: '{}'", - base, literal - )) -} - // Retrieve inner int value: pub fn get_value(obj: &PyObjectRef) -> &BigInt { &obj.payload::().unwrap().value