@@ -105,9 +105,9 @@ fn nid2obj(nid: Nid) -> Option<Asn1Object> {
105
105
unsafe { ptr2obj ( sys:: OBJ_nid2obj ( nid. as_raw ( ) ) ) }
106
106
}
107
107
fn obj2txt ( obj : & Asn1ObjectRef , no_name : bool ) -> Option < String > {
108
- unsafe {
109
- let no_name = if no_name { 1 } else { 0 } ;
110
- let ptr = obj . as_ptr ( ) ;
108
+ let no_name = if no_name { 1 } else { 0 } ;
109
+ let ptr = obj . as_ptr ( ) ;
110
+ let s = unsafe {
111
111
let buflen = sys:: OBJ_obj2txt ( std:: ptr:: null_mut ( ) , 0 , ptr, no_name) ;
112
112
assert ! ( buflen >= 0 ) ;
113
113
if buflen == 0 {
@@ -116,10 +116,10 @@ fn obj2txt(obj: &Asn1ObjectRef, no_name: bool) -> Option<String> {
116
116
let mut buf = vec ! [ 0u8 ; buflen as usize ] ;
117
117
let ret = sys:: OBJ_obj2txt ( buf. as_mut_ptr ( ) as * mut libc:: c_char , buflen, ptr, no_name) ;
118
118
assert ! ( ret >= 0 ) ;
119
- let s = String :: from_utf8 ( buf)
120
- . unwrap_or_else ( |e| String :: from_utf8_lossy ( e. as_bytes ( ) ) . into_owned ( ) ) ;
121
- Some ( s )
122
- }
119
+ String :: from_utf8 ( buf)
120
+ . unwrap_or_else ( |e| String :: from_utf8_lossy ( e. as_bytes ( ) ) . into_owned ( ) )
121
+ } ;
122
+ Some ( s )
123
123
}
124
124
125
125
type PyNid = ( libc:: c_int , String , String , Option < String > ) ;
@@ -232,9 +232,8 @@ fn _ssl_rand_bytes(n: i32, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
232
232
return Err ( vm. new_value_error ( "num must be positive" . to_owned ( ) ) ) ;
233
233
}
234
234
let mut buf = vec ! [ 0 ; n as usize ] ;
235
- openssl:: rand:: rand_bytes ( & mut buf)
236
- . map ( |( ) | buf)
237
- . map_err ( |e| convert_openssl_error ( vm, e) )
235
+ openssl:: rand:: rand_bytes ( & mut buf) . map_err ( |e| convert_openssl_error ( vm, e) ) ?;
236
+ Ok ( buf)
238
237
}
239
238
240
239
fn _ssl_rand_pseudo_bytes ( n : i32 , vm : & VirtualMachine ) -> PyResult < ( Vec < u8 > , bool ) > {
@@ -642,67 +641,70 @@ struct LoadCertChainArgs {
642
641
password : Option < Either < PyStrRef , ArgCallable > > ,
643
642
}
644
643
645
- struct SocketTimeout {
646
- // Err is true if the socket is blocking
647
- deadline : Result < Instant , bool > ,
648
- }
649
- impl SocketTimeout {
650
- fn get ( s : & SocketStream ) -> Self {
651
- let deadline = s. 0 . get_timeout ( ) . map ( |d| Instant :: now ( ) + d) ;
652
- Self { deadline }
653
- }
654
- }
644
+ // Err is true if the socket is blocking
645
+ type SocketDeadline = Result < Instant , bool > ;
646
+
655
647
enum SelectRet {
656
648
Nonblocking ,
657
649
TimedOut ,
658
650
IsBlocking ,
659
651
Closed ,
660
652
Ok ,
661
653
}
662
- fn ssl_select ( sock : & SocketStream , needs : SslNeeds , timeout : & SocketTimeout ) -> SelectRet {
663
- let sock = match sock. 0 . sock_opt ( ) {
664
- Some ( s) => s,
665
- None => return SelectRet :: Closed ,
666
- } ;
667
- let timeout = match & timeout. deadline {
668
- Ok ( deadline) => match deadline. checked_duration_since ( Instant :: now ( ) ) {
669
- Some ( timeout) => timeout,
670
- None => return SelectRet :: TimedOut ,
671
- } ,
672
- Err ( true ) => return SelectRet :: IsBlocking ,
673
- Err ( false ) => return SelectRet :: Nonblocking ,
674
- } ;
675
- let res = socket:: sock_select (
676
- & sock,
677
- match needs {
678
- SslNeeds :: Read => socket:: SelectKind :: Read ,
679
- SslNeeds :: Write => socket:: SelectKind :: Write ,
680
- } ,
681
- Some ( timeout) ,
682
- ) ;
683
- match res {
684
- Ok ( true ) => SelectRet :: TimedOut ,
685
- _ => SelectRet :: Ok ,
686
- }
687
- }
654
+
688
655
#[ derive( Clone , Copy ) ]
689
656
enum SslNeeds {
690
657
Read ,
691
658
Write ,
692
659
}
693
660
694
- fn socket_needs (
695
- err : & ssl:: Error ,
696
- sock : & SocketStream ,
697
- timeout : & SocketTimeout ,
698
- ) -> ( Option < SslNeeds > , SelectRet ) {
699
- let needs = match err. code ( ) {
700
- ssl:: ErrorCode :: WANT_READ => Some ( SslNeeds :: Read ) ,
701
- ssl:: ErrorCode :: WANT_WRITE => Some ( SslNeeds :: Write ) ,
702
- _ => None ,
703
- } ;
704
- let state = needs. map_or ( SelectRet :: Ok , |needs| ssl_select ( sock, needs, timeout) ) ;
705
- ( needs, state)
661
+ struct SocketStream ( PyRef < PySocket > ) ;
662
+
663
+ impl SocketStream {
664
+ fn timeout_deadline ( & self ) -> SocketDeadline {
665
+ self . 0 . get_timeout ( ) . map ( |d| Instant :: now ( ) + d)
666
+ }
667
+
668
+ fn select ( & self , needs : SslNeeds , deadline : & SocketDeadline ) -> SelectRet {
669
+ let sock = match self . 0 . sock_opt ( ) {
670
+ Some ( s) => s,
671
+ None => return SelectRet :: Closed ,
672
+ } ;
673
+ let deadline = match & deadline {
674
+ Ok ( deadline) => match deadline. checked_duration_since ( Instant :: now ( ) ) {
675
+ Some ( deadline) => deadline,
676
+ None => return SelectRet :: TimedOut ,
677
+ } ,
678
+ Err ( true ) => return SelectRet :: IsBlocking ,
679
+ Err ( false ) => return SelectRet :: Nonblocking ,
680
+ } ;
681
+ let res = socket:: sock_select (
682
+ & sock,
683
+ match needs {
684
+ SslNeeds :: Read => socket:: SelectKind :: Read ,
685
+ SslNeeds :: Write => socket:: SelectKind :: Write ,
686
+ } ,
687
+ Some ( deadline) ,
688
+ ) ;
689
+ match res {
690
+ Ok ( true ) => SelectRet :: TimedOut ,
691
+ _ => SelectRet :: Ok ,
692
+ }
693
+ }
694
+
695
+ fn socket_needs (
696
+ & self ,
697
+ err : & ssl:: Error ,
698
+ deadline : & SocketDeadline ,
699
+ ) -> ( Option < SslNeeds > , SelectRet ) {
700
+ let needs = match err. code ( ) {
701
+ ssl:: ErrorCode :: WANT_READ => Some ( SslNeeds :: Read ) ,
702
+ ssl:: ErrorCode :: WANT_WRITE => Some ( SslNeeds :: Write ) ,
703
+ _ => None ,
704
+ } ;
705
+ let state = needs. map_or ( SelectRet :: Ok , |needs| self . select ( needs, deadline) ) ;
706
+ ( needs, state)
707
+ }
706
708
}
707
709
708
710
fn socket_closed_error ( vm : & VirtualMachine ) -> PyBaseExceptionRef {
@@ -788,38 +790,37 @@ impl PySslSocket {
788
790
. map ( cipher_to_tuple)
789
791
}
790
792
793
+ #[ cfg( osslconf = "OPENSSL_NO_COMP" ) ]
791
794
#[ pymethod]
792
795
fn compression ( & self ) -> Option < & ' static str > {
793
- #[ cfg( osslconf = "OPENSSL_NO_COMP" ) ]
794
- {
795
- None
796
+ None
797
+ }
798
+ #[ cfg( not( osslconf = "OPENSSL_NO_COMP" ) ) ]
799
+ #[ pymethod]
800
+ fn compression ( & self ) -> Option < & ' static str > {
801
+ let stream = self . stream . read ( ) ;
802
+ let comp_method = unsafe { sys:: SSL_get_current_compression ( stream. ssl ( ) . as_ptr ( ) ) } ;
803
+ if comp_method. is_null ( ) {
804
+ return None ;
796
805
}
797
- #[ cfg( not( osslconf = "OPENSSL_NO_COMP" ) ) ]
798
- {
799
- let stream = self . stream . read ( ) ;
800
- let comp_method = unsafe { sys:: SSL_get_current_compression ( stream. ssl ( ) . as_ptr ( ) ) } ;
801
- if comp_method. is_null ( ) {
802
- return None ;
803
- }
804
- let typ = unsafe { sys:: COMP_get_type ( comp_method) } ;
805
- let nid = Nid :: from_raw ( typ) ;
806
- if nid == Nid :: UNDEF {
807
- return None ;
808
- }
809
- nid. short_name ( ) . ok ( )
806
+ let typ = unsafe { sys:: COMP_get_type ( comp_method) } ;
807
+ let nid = Nid :: from_raw ( typ) ;
808
+ if nid == Nid :: UNDEF {
809
+ return None ;
810
810
}
811
+ nid. short_name ( ) . ok ( )
811
812
}
812
813
813
814
#[ pymethod]
814
815
fn do_handshake ( & self , vm : & VirtualMachine ) -> PyResult < ( ) > {
815
816
let mut stream = self . stream . write ( ) ;
816
- let timeout = SocketTimeout :: get ( stream. get_ref ( ) ) ;
817
+ let timeout = stream. get_ref ( ) . timeout_deadline ( ) ;
817
818
loop {
818
819
let err = match stream. do_handshake ( ) {
819
820
Ok ( ( ) ) => return Ok ( ( ) ) ,
820
821
Err ( e) => e,
821
822
} ;
822
- let ( needs, state) = socket_needs ( & err , & stream. get_ref ( ) , & timeout) ;
823
+ let ( needs, state) = stream. get_ref ( ) . socket_needs ( & err , & timeout) ;
823
824
match state {
824
825
SelectRet :: TimedOut => {
825
826
return Err ( socket:: timeout_error_msg (
@@ -844,8 +845,8 @@ impl PySslSocket {
844
845
let mut stream = self . stream . write ( ) ;
845
846
let data = data. borrow_buf ( ) ;
846
847
let data = & * data;
847
- let timeout = SocketTimeout :: get ( stream. get_ref ( ) ) ;
848
- let state = ssl_select ( stream. get_ref ( ) , SslNeeds :: Write , & timeout) ;
848
+ let timeout = stream. get_ref ( ) . timeout_deadline ( ) ;
849
+ let state = stream. get_ref ( ) . select ( SslNeeds :: Write , & timeout) ;
849
850
match state {
850
851
SelectRet :: TimedOut => {
851
852
return Err ( socket:: timeout_error_msg (
@@ -861,7 +862,7 @@ impl PySslSocket {
861
862
Ok ( len) => return Ok ( len) ,
862
863
Err ( e) => e,
863
864
} ;
864
- let ( needs, state) = socket_needs ( & err , stream. get_ref ( ) , & timeout) ;
865
+ let ( needs, state) = stream. get_ref ( ) . socket_needs ( & err , & timeout) ;
865
866
match state {
866
867
SelectRet :: TimedOut => {
867
868
return Err ( socket:: timeout_error_msg (
@@ -902,7 +903,7 @@ impl PySslSocket {
902
903
Some ( b) => b,
903
904
None => buf,
904
905
} ;
905
- let timeout = SocketTimeout :: get ( stream. get_ref ( ) ) ;
906
+ let timeout = stream. get_ref ( ) . timeout_deadline ( ) ;
906
907
let count = loop {
907
908
let err = match stream. ssl_read ( buf) {
908
909
Ok ( count) => break count,
@@ -913,7 +914,7 @@ impl PySslSocket {
913
914
{
914
915
break 0 ;
915
916
}
916
- let ( needs, state) = socket_needs ( & err , stream. get_ref ( ) , & timeout) ;
917
+ let ( needs, state) = stream. get_ref ( ) . socket_needs ( & err , & timeout) ;
917
918
match state {
918
919
SelectRet :: TimedOut => {
919
920
return Err ( socket:: timeout_error_msg (
@@ -996,10 +997,9 @@ fn cipher_to_tuple(cipher: &ssl::SslCipherRef) -> CipherTuple {
996
997
}
997
998
998
999
fn cert_to_py ( vm : & VirtualMachine , cert : & X509Ref , binary : bool ) -> PyResult {
999
- if binary {
1000
- cert. to_der ( )
1001
- . map ( |b| vm. ctx . new_bytes ( b) )
1002
- . map_err ( |e| convert_openssl_error ( vm, e) )
1000
+ let r = if binary {
1001
+ let b = cert. to_der ( ) . map_err ( |e| convert_openssl_error ( vm, e) ) ?;
1002
+ vm. ctx . new_bytes ( b)
1003
1003
} else {
1004
1004
let dict = vm. ctx . new_dict ( ) ;
1005
1005
@@ -1073,8 +1073,9 @@ fn cert_to_py(vm: &VirtualMachine, cert: &X509Ref, binary: bool) -> PyResult {
1073
1073
dict. set_item ( "subjectAltName" , vm. ctx . new_tuple ( san) , vm) ?;
1074
1074
} ;
1075
1075
1076
- Ok ( dict. into_object ( ) )
1077
- }
1076
+ dict. into_object ( )
1077
+ } ;
1078
+ Ok ( r)
1078
1079
}
1079
1080
1080
1081
#[ allow( non_snake_case) ]
@@ -1238,8 +1239,6 @@ fn extend_module_platform_specific(module: &PyObjectRef, vm: &VirtualMachine) {
1238
1239
#[ cfg( not( windows) ) ]
1239
1240
fn extend_module_platform_specific ( _module : & PyObjectRef , _vm : & VirtualMachine ) { }
1240
1241
1241
- struct SocketStream ( PyRef < PySocket > ) ;
1242
-
1243
1242
impl std:: io:: Read for SocketStream {
1244
1243
fn read ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < usize > {
1245
1244
<& socket2:: Socket as std:: io:: Read >:: read ( & mut & * self . 0 . sock_io ( ) ?, buf)
0 commit comments