@@ -13,7 +13,6 @@ use super::objfloat;
13
13
use super :: objmemory:: PyMemoryView ;
14
14
use super :: objstr:: { PyString , PyStringRef } ;
15
15
use super :: objtype:: { self , PyClassRef } ;
16
- use crate :: exceptions:: PyBaseExceptionRef ;
17
16
use crate :: format:: FormatSpec ;
18
17
use crate :: function:: { OptionalArg , PyFuncArgs } ;
19
18
use crate :: pyhash;
@@ -724,26 +723,21 @@ struct IntToByteArgs {
724
723
725
724
// Casting function:
726
725
pub fn to_int ( vm : & VirtualMachine , obj : & PyObjectRef , base : & BigInt ) -> PyResult < BigInt > {
727
- let base_u32 = match base. to_u32 ( ) {
728
- Some ( base_u32) => base_u32,
729
- None => {
730
- return Err ( vm. new_value_error ( "int() base must be >= 2 and <= 36, or 0" . to_owned ( ) ) )
731
- }
726
+ let base = match base. to_u32 ( ) {
727
+ Some ( base) if base == 0 || ( 2 ..=36 ) . contains ( & base) => base,
728
+ _ => return Err ( vm. new_value_error ( "int() base must be >= 2 and <= 36, or 0" . to_owned ( ) ) ) ,
732
729
} ;
733
- if base_u32 != 0 && ( base_u32 < 2 || base_u32 > 36 ) {
734
- return Err ( vm. new_value_error ( "int() base must be >= 2 and <= 36, or 0" . to_owned ( ) ) ) ;
735
- }
736
730
737
731
let bytes_to_int = |bytes : & [ u8 ] | {
738
- let s = std:: str:: from_utf8 ( bytes)
739
- . map_err ( |e| vm . new_value_error ( format ! ( "utf8 decode error: {}" , e ) ) ) ? ;
740
- str_to_int ( vm , s, base)
732
+ std:: str:: from_utf8 ( bytes)
733
+ . ok ( )
734
+ . and_then ( |s| str_to_int ( s, base) )
741
735
} ;
742
736
743
- match_class ! ( match obj. clone( ) {
737
+ let opt = match_class ! ( match obj. clone( ) {
744
738
string @ PyString => {
745
739
let s = string. as_str( ) ;
746
- str_to_int( vm , & s, base)
740
+ str_to_int( & s, base)
747
741
}
748
742
bytes @ PyBytes => {
749
743
let bytes = bytes. get_value( ) ;
@@ -770,36 +764,39 @@ pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, base: &BigInt) -> PyResult
770
764
)
771
765
} ) ?;
772
766
let result = vm. invoke( & method, PyFuncArgs :: default ( ) ) ?;
773
- match result. payload:: <PyInt >( ) {
767
+ return match result. payload:: <PyInt >( ) {
774
768
Some ( int_obj) => Ok ( int_obj. as_bigint( ) . clone( ) ) ,
775
769
None => Err ( vm. new_type_error( format!(
776
770
"TypeError: __int__ returned non-int (type '{}')" ,
777
771
result. class( ) . name
778
772
) ) ) ,
779
- }
773
+ } ;
780
774
}
781
- } )
775
+ } ) ;
776
+ match opt {
777
+ Some ( int) => Ok ( int) ,
778
+ None => Err ( vm. new_value_error ( format ! (
779
+ "invalid literal for int() with base {}: {}" ,
780
+ base,
781
+ vm. to_repr( obj) ?,
782
+ ) ) ) ,
783
+ }
782
784
}
783
785
784
- fn str_to_int ( vm : & VirtualMachine , literal : & str , base : & BigInt ) -> PyResult < BigInt > {
785
- let mut buf = validate_literal ( vm , literal, base ) ? ;
786
+ fn str_to_int ( literal : & str , mut base : u32 ) -> Option < BigInt > {
787
+ let mut buf = validate_literal ( literal) ? . to_owned ( ) ;
786
788
let is_signed = buf. starts_with ( '+' ) || buf. starts_with ( '-' ) ;
787
789
let radix_range = if is_signed { 1 ..3 } else { 0 ..2 } ;
788
790
let radix_candidate = buf. get ( radix_range. clone ( ) ) ;
789
791
790
- let mut base_u32 = match base. to_u32 ( ) {
791
- Some ( base_u32) => base_u32,
792
- None => return Err ( invalid_literal ( vm, literal, base) ) ,
793
- } ;
794
-
795
792
// try to find base
796
793
if let Some ( radix_candidate) = radix_candidate {
797
794
if let Some ( matched_radix) = detect_base ( & radix_candidate) {
798
- if base_u32 == 0 || base_u32 == matched_radix {
795
+ if base == 0 || base == matched_radix {
799
796
/* If base is 0 or equal radix number, it means radix is validate
800
797
* So change base to radix number and remove radix from literal
801
798
*/
802
- base_u32 = matched_radix;
799
+ base = matched_radix;
803
800
buf. drain ( radix_range) ;
804
801
805
802
/* first underscore with radix is validate
@@ -808,49 +805,50 @@ fn str_to_int(vm: &VirtualMachine, literal: &str, base: &BigInt) -> PyResult<Big
808
805
if buf. starts_with ( '_' ) {
809
806
buf. remove ( 0 ) ;
810
807
}
811
- } else if ( matched_radix == 2 && base_u32 < 12 )
812
- || ( matched_radix == 8 && base_u32 < 25 )
813
- || ( matched_radix == 16 && base_u32 < 34 )
808
+ } else if ( matched_radix == 2 && base < 12 )
809
+ || ( matched_radix == 8 && base < 25 )
810
+ || ( matched_radix == 16 && base < 34 )
814
811
{
815
- return Err ( invalid_literal ( vm , literal , base ) ) ;
812
+ return None ;
816
813
}
817
814
}
818
815
}
819
816
820
817
// base still not found, try to use default
821
- if base_u32 == 0 {
818
+ if base == 0 {
822
819
if buf. starts_with ( '0' ) {
823
- return Err ( invalid_literal ( vm, literal, base) ) ;
820
+ if buf. chars ( ) . all ( |c| matches ! ( c, '+' | '-' | '0' | '_' ) ) {
821
+ return Some ( BigInt :: zero ( ) ) ;
822
+ }
823
+ return None ;
824
824
}
825
825
826
- base_u32 = 10 ;
826
+ base = 10 ;
827
827
}
828
828
829
- BigInt :: from_str_radix ( & buf, base_u32 ) . map_err ( |_err| invalid_literal ( vm , literal , base ) )
829
+ BigInt :: from_str_radix ( & buf, base ) . ok ( )
830
830
}
831
831
832
- fn validate_literal ( vm : & VirtualMachine , literal : & str , base : & BigInt ) -> PyResult < String > {
832
+ fn validate_literal ( literal : & str ) -> Option < & str > {
833
833
let trimmed = literal. trim ( ) ;
834
834
if trimmed. starts_with ( '_' ) || trimmed. ends_with ( '_' ) {
835
- return Err ( invalid_literal ( vm , literal , base ) ) ;
835
+ return None ;
836
836
}
837
837
838
- let mut buf = String :: with_capacity ( trimmed. len ( ) ) ;
839
838
let mut last_tok = None ;
840
839
for c in trimmed. chars ( ) {
841
840
if !( c. is_ascii_alphanumeric ( ) || c == '_' || c == '+' || c == '-' ) {
842
- return Err ( invalid_literal ( vm , literal , base ) ) ;
841
+ return None ;
843
842
}
844
843
845
844
if c == '_' && Some ( c) == last_tok {
846
- return Err ( invalid_literal ( vm , literal , base ) ) ;
845
+ return None ;
847
846
}
848
847
849
848
last_tok = Some ( c) ;
850
- buf. push ( c) ;
851
849
}
852
850
853
- Ok ( buf )
851
+ Some ( trimmed )
854
852
}
855
853
856
854
fn detect_base ( literal : & str ) -> Option < u32 > {
@@ -862,13 +860,6 @@ fn detect_base(literal: &str) -> Option<u32> {
862
860
}
863
861
}
864
862
865
- fn invalid_literal ( vm : & VirtualMachine , literal : & str , base : & BigInt ) -> PyBaseExceptionRef {
866
- vm. new_value_error ( format ! (
867
- "invalid literal for int() with base {}: '{}'" ,
868
- base, literal
869
- ) )
870
- }
871
-
872
863
// Retrieve inner int value:
873
864
pub fn get_value ( obj : & PyObjectRef ) -> & BigInt {
874
865
& obj. payload :: < PyInt > ( ) . unwrap ( ) . value
0 commit comments