1
1
use std:: fmt;
2
+ use std:: str;
2
3
3
4
use num_bigint:: { BigInt , Sign } ;
4
5
use num_integer:: Integer ;
5
- use num_traits:: { One , Pow , Signed , ToPrimitive , Zero } ;
6
+ use num_traits:: { Num , One , Pow , Signed , ToPrimitive , Zero } ;
6
7
7
8
use crate :: format:: FormatSpec ;
8
9
use crate :: function:: { KwArgs , OptionalArg , PyFuncArgs } ;
@@ -713,7 +714,9 @@ impl IntOptions {
713
714
fn get_int_value ( self , vm : & VirtualMachine ) -> PyResult < BigInt > {
714
715
if let OptionalArg :: Present ( val) = self . val_options {
715
716
let base = if let OptionalArg :: Present ( base) = self . base {
716
- if !objtype:: isinstance ( & val, & vm. ctx . str_type ( ) ) {
717
+ if !( objtype:: isinstance ( & val, & vm. ctx . str_type ( ) )
718
+ || objtype:: isinstance ( & val, & vm. ctx . bytes_type ( ) ) )
719
+ {
717
720
return Err ( vm. new_type_error (
718
721
"int() can't convert non-string with explicit base" . to_string ( ) ,
719
722
) ) ;
@@ -736,21 +739,22 @@ fn int_new(cls: PyClassRef, options: IntOptions, vm: &VirtualMachine) -> PyResul
736
739
}
737
740
738
741
// Casting function:
739
- pub fn to_int ( vm : & VirtualMachine , obj : & PyObjectRef , mut base : u32 ) -> PyResult < BigInt > {
740
- if base == 0 {
741
- base = 10
742
- } else if base < 2 || base > 36 {
742
+ pub fn to_int ( vm : & VirtualMachine , obj : & PyObjectRef , base : u32 ) -> PyResult < BigInt > {
743
+ if base != 0 && ( base < 2 || base > 36 ) {
743
744
return Err ( vm. new_value_error ( "int() base must be >= 2 and <= 36, or 0" . to_string ( ) ) ) ;
744
745
}
745
746
746
747
match_class ! ( obj. clone( ) ,
747
- s @ PyString => {
748
- i32 :: from_str_radix( s. as_str( ) . trim( ) , base)
749
- . map( BigInt :: from)
750
- . map_err( |_|vm. new_value_error( format!(
751
- "invalid literal for int() with base {}: '{}'" ,
752
- base, s
753
- ) ) )
748
+ string @ PyString => {
749
+ let s = string. value. as_str( ) . trim( ) ;
750
+ str_to_int( vm, s, base)
751
+ } ,
752
+ bytes @ PyBytes => {
753
+ let bytes = bytes. get_value( ) ;
754
+ let s = std:: str :: from_utf8( bytes)
755
+ . map( |s| s. trim( ) )
756
+ . map_err( |e| vm. new_value_error( format!( "utf8 decode error: {}" , e) ) ) ?;
757
+ str_to_int( vm, s, base)
754
758
} ,
755
759
obj => {
756
760
let method = vm. get_method_or_type_error( obj. clone( ) , "__int__" , || {
@@ -766,6 +770,76 @@ pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, mut base: u32) -> PyResult
766
770
)
767
771
}
768
772
773
+ fn str_to_int ( vm : & VirtualMachine , literal : & str , mut base : u32 ) -> PyResult < BigInt > {
774
+ let mut buf = validate_literal ( vm, literal, base) ?;
775
+ let is_signed = buf. starts_with ( '+' ) || buf. starts_with ( '-' ) ;
776
+ let radix_range = if is_signed { 1 ..3 } else { 0 ..2 } ;
777
+ let radix_candidate = buf. get ( radix_range. clone ( ) ) ;
778
+
779
+ // try to find base
780
+ if let Some ( radix_candidate) = radix_candidate {
781
+ if let Some ( matched_radix) = detect_base ( & radix_candidate) {
782
+ if base != 0 && base != matched_radix {
783
+ return Err ( invalid_literal ( vm, literal, base) ) ;
784
+ } else {
785
+ base = matched_radix;
786
+ }
787
+
788
+ buf. drain ( radix_range) ;
789
+ }
790
+ }
791
+
792
+ // base still not found, try to use default
793
+ if base == 0 {
794
+ if buf. starts_with ( '0' ) {
795
+ return Err ( invalid_literal ( vm, literal, base) ) ;
796
+ }
797
+
798
+ base = 10 ;
799
+ }
800
+
801
+ BigInt :: from_str_radix ( & buf, base) . map_err ( |_err| invalid_literal ( vm, literal, base) )
802
+ }
803
+
804
+ fn validate_literal ( vm : & VirtualMachine , literal : & str , base : u32 ) -> PyResult < String > {
805
+ if literal. starts_with ( '_' ) || literal. ends_with ( '_' ) {
806
+ return Err ( invalid_literal ( vm, literal, base) ) ;
807
+ }
808
+
809
+ let mut buf = String :: with_capacity ( literal. len ( ) ) ;
810
+ let mut last_tok = None ;
811
+ for c in literal. chars ( ) {
812
+ if !( c. is_ascii_alphanumeric ( ) || c == '_' || c == '+' || c == '-' ) {
813
+ return Err ( invalid_literal ( vm, literal, base) ) ;
814
+ }
815
+
816
+ if c == '_' && Some ( c) == last_tok {
817
+ return Err ( invalid_literal ( vm, literal, base) ) ;
818
+ }
819
+
820
+ last_tok = Some ( c) ;
821
+ buf. push ( c) ;
822
+ }
823
+
824
+ Ok ( buf)
825
+ }
826
+
827
/// Maps a two-character base prefix to its radix.
///
/// Returns `Some(16)` for `0x`/`0X`, `Some(8)` for `0o`/`0O`, `Some(2)` for
/// `0b`/`0B`, and `None` for any other input (including longer strings that
/// merely start with a prefix).
fn detect_base(literal: &str) -> Option<u32> {
    const PREFIXES: [(&str, u32); 3] = [("0x", 16), ("0o", 8), ("0b", 2)];

    PREFIXES
        .iter()
        .find(|(prefix, _)| literal.eq_ignore_ascii_case(prefix))
        .map(|&(_, radix)| radix)
}
835
+
836
+ fn invalid_literal ( vm : & VirtualMachine , literal : & str , base : u32 ) -> PyObjectRef {
837
+ vm. new_value_error ( format ! (
838
+ "invalid literal for int() with base {}: '{}'" ,
839
+ base, literal
840
+ ) )
841
+ }
842
+
769
843
// Retrieve inner int value:
770
844
pub fn get_value ( obj : & PyObjectRef ) -> & BigInt {
771
845
& get_py_int ( obj) . value
0 commit comments