@@ -22,33 +22,50 @@ use openssl::{
22
22
error:: ErrorStack ,
23
23
nid:: Nid ,
24
24
ssl:: { self , SslContextBuilder , SslOptions , SslVerifyMode } ,
25
- x509:: { self , X509Object , X509Ref , X509 } ,
25
+ x509:: { self , X509Ref , X509 } ,
26
26
} ;
27
27
use std:: convert:: TryFrom ;
28
28
use std:: ffi:: CStr ;
29
29
use std:: fmt;
30
30
use std:: io:: { Read , Write } ;
31
31
use std:: time:: Instant ;
32
32
33
- mod sys {
34
- #![ allow( non_camel_case_types, unused) ]
35
- use libc:: { c_char, c_double, c_int, c_long, c_void} ;
36
- pub use openssl_sys:: * ;
37
- extern "C" {
38
- pub fn OBJ_txt2obj ( s : * const c_char , no_name : c_int ) -> * mut ASN1_OBJECT ;
39
- pub fn OBJ_nid2obj ( n : c_int ) -> * mut ASN1_OBJECT ;
40
- pub fn X509_get_default_cert_file_env ( ) -> * const c_char ;
41
- pub fn X509_get_default_cert_file ( ) -> * const c_char ;
42
- pub fn X509_get_default_cert_dir_env ( ) -> * const c_char ;
43
- pub fn X509_get_default_cert_dir ( ) -> * const c_char ;
44
- #[ cfg( ossl111) ]
45
- pub fn SSL_CTX_set_post_handshake_auth ( ctx : * mut SSL_CTX , val : c_int ) ;
46
- pub fn RAND_add ( buf : * const c_void , num : c_int , randomness : c_double ) ;
47
- pub fn RAND_pseudo_bytes ( buf : * const u8 , num : c_int ) -> c_int ;
48
- pub fn X509_get_version ( x : * const X509 ) -> c_long ;
49
- pub fn SSLv3_method ( ) -> * const SSL_METHOD ;
50
- pub fn TLSv1_method ( ) -> * const SSL_METHOD ;
51
- pub fn COMP_get_type ( meth : * const COMP_METHOD ) -> i32 ;
33
+ use openssl_sys as sys;
34
+
35
+ mod bio {
36
+ //! based off rust-openssl's private `bio` module
37
+
38
+ use super :: * ;
39
+
40
+ use libc:: c_int;
41
+ use std:: marker:: PhantomData ;
42
+
43
+ pub struct MemBioSlice < ' a > ( * mut sys:: BIO , PhantomData < & ' a [ u8 ] > ) ;
44
+
45
+ impl < ' a > Drop for MemBioSlice < ' a > {
46
+ fn drop ( & mut self ) {
47
+ unsafe {
48
+ sys:: BIO_free_all ( self . 0 ) ;
49
+ }
50
+ }
51
+ }
52
+
53
+ impl < ' a > MemBioSlice < ' a > {
54
+ pub fn new ( buf : & ' a [ u8 ] ) -> Result < MemBioSlice < ' a > , ErrorStack > {
55
+ openssl:: init ( ) ;
56
+
57
+ assert ! ( buf. len( ) <= c_int:: max_value( ) as usize ) ;
58
+ let bio = unsafe { sys:: BIO_new_mem_buf ( buf. as_ptr ( ) as * const _ , buf. len ( ) as c_int ) } ;
59
+ if bio. is_null ( ) {
60
+ return Err ( ErrorStack :: get ( ) ) ;
61
+ }
62
+
63
+ Ok ( MemBioSlice ( bio, PhantomData ) )
64
+ }
65
+
66
+ pub fn as_ptr ( & self ) -> * mut sys:: BIO {
67
+ self . 0
68
+ }
52
69
}
53
70
}
54
71
@@ -119,9 +136,17 @@ fn obj2txt(obj: &Asn1ObjectRef, no_name: bool) -> Option<String> {
119
136
if buflen == 0 {
120
137
return None ;
121
138
}
122
- let mut buf = vec ! [ 0u8 ; buflen as usize ] ;
123
- let ret = sys:: OBJ_obj2txt ( buf. as_mut_ptr ( ) as * mut libc:: c_char , buflen, ptr, no_name) ;
139
+ let buflen = buflen as usize ;
140
+ let mut buf = Vec :: < u8 > :: with_capacity ( buflen + 1 ) ;
141
+ let ret = sys:: OBJ_obj2txt (
142
+ buf. as_mut_ptr ( ) as * mut libc:: c_char ,
143
+ buf. capacity ( ) as _ ,
144
+ ptr,
145
+ no_name,
146
+ ) ;
124
147
assert ! ( ret >= 0 ) ;
148
+ // SAFETY: set_len is safe when capacity is enoguh and all values are already initialized
149
+ buf. set_len ( buflen) ;
125
150
String :: from_utf8 ( buf)
126
151
. unwrap_or_else ( |e| String :: from_utf8_lossy ( e. as_bytes ( ) ) . into_owned ( ) )
127
152
} ;
@@ -243,7 +268,7 @@ fn _ssl_rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec<u8>, boo
243
268
return Err ( vm. new_value_error ( "num must be positive" . to_owned ( ) ) ) ;
244
269
}
245
270
let mut buf = vec ! [ 0 ; n as usize ] ;
246
- let ret = unsafe { sys:: RAND_pseudo_bytes ( buf. as_mut_ptr ( ) , n) } ;
271
+ let ret = unsafe { sys:: RAND_bytes ( buf. as_mut_ptr ( ) , n) } ;
247
272
match ret {
248
273
0 | 1 => Ok ( ( buf, ret == 1 ) ) ,
249
274
_ => Err ( convert_openssl_error ( vm, ErrorStack :: get ( ) ) ) ,
@@ -277,7 +302,6 @@ impl SlotConstructor for PySslContext {
277
302
let method = match proto {
278
303
// SslVersion::Ssl3 => unsafe { ssl::SslMethod::from_ptr(sys::SSLv3_method()) },
279
304
SslVersion :: Tls => ssl:: SslMethod :: tls ( ) ,
280
- SslVersion :: Tls1 => unsafe { ssl:: SslMethod :: from_ptr ( sys:: TLSv1_method ( ) ) } ,
281
305
// TODO: Tls1_1, Tls1_2 ?
282
306
SslVersion :: TlsClient => ssl:: SslMethod :: tls_client ( ) ,
283
307
SslVersion :: TlsServer => ssl:: SslMethod :: tls_server ( ) ,
@@ -461,22 +485,22 @@ impl PySslContext {
461
485
}
462
486
463
487
if let Some ( cadata) = args. cadata {
464
- let cert = match cadata {
488
+ let certs = match cadata {
465
489
Either :: A ( s) => {
466
- if !s. as_str ( ) . is_ascii ( ) {
490
+ if !s. is_ascii ( ) {
467
491
return Err ( vm. new_type_error ( "Must be an ascii string" . to_owned ( ) ) ) ;
468
492
}
469
- X509 :: from_pem ( s. as_str ( ) . as_bytes ( ) )
493
+ X509 :: stack_from_pem ( s. as_str ( ) . as_bytes ( ) )
470
494
}
471
- Either :: B ( b) => b. with_ref ( X509 :: from_der ) ,
495
+ Either :: B ( b) => b. with_ref ( x509_stack_from_der ) ,
472
496
} ;
473
- let cert = cert . map_err ( |e| convert_openssl_error ( vm, e) ) ?;
474
- let ret = self . exec_ctx ( |ctx| {
475
- let store = ctx. cert_store ( ) ;
476
- unsafe { sys :: X509_STORE_add_cert ( store . as_ptr ( ) , cert . as_ptr ( ) ) }
477
- } ) ;
478
- if ret <= 0 {
479
- return Err ( convert_openssl_error ( vm, ErrorStack :: get ( ) ) ) ;
497
+ let certs = certs . map_err ( |e| convert_openssl_error ( vm, e) ) ?;
498
+ let mut ctx = self . builder ( ) ;
499
+ let store = ctx. cert_store_mut ( ) ;
500
+ for cert in certs {
501
+ store
502
+ . add_cert ( cert )
503
+ . map_err ( |e| convert_openssl_error ( vm, e ) ) ? ;
480
504
}
481
505
}
482
506
@@ -511,22 +535,17 @@ impl PySslContext {
511
535
512
536
#[ pymethod]
513
537
fn get_ca_certs ( & self , binary_form : OptionalArg < bool > , vm : & VirtualMachine ) -> PyResult {
514
- use openssl:: stack:: StackRef ;
515
538
let binary_form = binary_form. unwrap_or ( false ) ;
516
- let certs = unsafe {
517
- let stack =
518
- sys:: X509_STORE_get0_objects ( self . exec_ctx ( |ctx| ctx. cert_store ( ) . as_ptr ( ) ) ) ;
519
- assert ! ( !stack. is_null( ) ) ;
520
- StackRef :: < X509Object > :: from_ptr ( stack)
521
- } ;
522
- let certs = certs
523
- . iter ( )
524
- . filter_map ( |cert| {
525
- let cert = cert. x509 ( ) ?;
526
- Some ( cert_to_py ( vm, cert, binary_form) )
527
- } )
528
- . collect :: < Result < Vec < _ > , _ > > ( ) ?;
529
- Ok ( vm. ctx . new_list ( certs) )
539
+ self . exec_ctx ( |ctx| {
540
+ let certs = ctx
541
+ . cert_store ( )
542
+ . objects ( )
543
+ . iter ( )
544
+ . filter_map ( |obj| obj. x509 ( ) )
545
+ . map ( |cert| cert_to_py ( vm, cert, binary_form) )
546
+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
547
+ Ok ( vm. ctx . new_list ( certs) )
548
+ } )
530
549
}
531
550
532
551
#[ pymethod]
@@ -949,19 +968,30 @@ fn ssl_error(vm: &VirtualMachine) -> PyTypeRef {
949
968
vm. class ( "_ssl" , "SSLError" )
950
969
}
951
970
971
+ #[ track_caller]
952
972
fn convert_openssl_error ( vm : & VirtualMachine , err : ErrorStack ) -> PyBaseExceptionRef {
953
973
let cls = ssl_error ( vm) ;
954
974
match err. errors ( ) . last ( ) {
955
975
Some ( e) => {
976
+ let caller = std:: panic:: Location :: caller ( ) ;
977
+ let ( file, line) = ( caller. file ( ) , caller. line ( ) ) ;
978
+ let file = file
979
+ . rsplit_once ( & [ '/' , '\\' ] [ ..] )
980
+ . map_or ( file, |( _, basename) | basename) ;
956
981
// TODO: map the error codes to code names, e.g. "CERTIFICATE_VERIFY_FAILED", just requires a big hashmap/dict
957
982
let errstr = e. reason ( ) . unwrap_or ( "unknown error" ) ;
958
- let msg = format ! ( "{} (_ssl.c:{})" , errstr, e. line( ) ) ;
983
+ let msg = if let Some ( lib) = e. library ( ) {
984
+ format ! ( "[{}] {} ({}:{})" , lib, errstr, file, line)
985
+ } else {
986
+ format ! ( "{} ({}:{})" , errstr, file, line)
987
+ } ;
959
988
let reason = sys:: ERR_GET_REASON ( e. code ( ) ) ;
960
989
vm. new_exception ( cls, vec ! [ vm. ctx. new_int( reason) , vm. ctx. new_utf8_str( msg) ] )
961
990
}
962
991
None => vm. new_exception_empty ( cls) ,
963
992
}
964
993
}
994
+ #[ track_caller]
965
995
fn convert_ssl_error (
966
996
vm : & VirtualMachine ,
967
997
e : impl std:: borrow:: Borrow < ssl:: Error > ,
@@ -992,6 +1022,33 @@ fn convert_ssl_error(
992
1022
vm. new_exception_msg ( cls, msg. to_owned ( ) )
993
1023
}
994
1024
1025
+ fn x509_stack_from_der ( der : & [ u8 ] ) -> Result < Vec < X509 > , ErrorStack > {
1026
+ unsafe {
1027
+ openssl:: init ( ) ;
1028
+ let bio = bio:: MemBioSlice :: new ( der) ?;
1029
+
1030
+ let mut certs = vec ! [ ] ;
1031
+ loop {
1032
+ let r = sys:: d2i_X509_bio ( bio. as_ptr ( ) , std:: ptr:: null_mut ( ) ) ;
1033
+ if r. is_null ( ) {
1034
+ let err = sys:: ERR_peek_last_error ( ) ;
1035
+ if sys:: ERR_GET_LIB ( err) == sys:: ERR_LIB_ASN1
1036
+ && sys:: ERR_GET_REASON ( err) == sys:: ASN1_R_HEADER_TOO_LONG
1037
+ {
1038
+ sys:: ERR_clear_error ( ) ;
1039
+ break ;
1040
+ }
1041
+
1042
+ return Err ( ErrorStack :: get ( ) ) ;
1043
+ } else {
1044
+ certs. push ( X509 :: from_ptr ( r) ) ;
1045
+ }
1046
+ }
1047
+
1048
+ Ok ( certs)
1049
+ }
1050
+ }
1051
+
995
1052
type CipherTuple = ( & ' static str , & ' static str , i32 ) ;
996
1053
997
1054
fn cipher_to_tuple ( cipher : & ssl:: SslCipherRef ) -> CipherTuple {
@@ -1019,9 +1076,7 @@ fn cert_to_py(vm: &VirtualMachine, cert: &X509Ref, binary: bool) -> PyResult {
1019
1076
1020
1077
dict. set_item ( "subject" , name_to_py ( cert. subject_name ( ) ) ?, vm) ?;
1021
1078
dict. set_item ( "issuer" , name_to_py ( cert. issuer_name ( ) ) ?, vm) ?;
1022
-
1023
- let version = unsafe { sys:: X509_get_version ( cert. as_ptr ( ) ) } ;
1024
- dict. set_item ( "version" , vm. ctx . new_int ( version) , vm) ?;
1079
+ dict. set_item ( "version" , vm. ctx . new_int ( cert. version ( ) ) , vm) ?;
1025
1080
1026
1081
let serial_num = cert
1027
1082
. serial_number ( )
0 commit comments