@@ -12,19 +12,22 @@ use crate::pyobject::{
12
12
use crate :: types:: create_type;
13
13
use crate :: VirtualMachine ;
14
14
15
- use std:: cell:: { RefCell , RefMut } ;
15
+ use std:: cell:: { Ref , RefCell , RefMut } ;
16
16
use std:: convert:: TryFrom ;
17
17
use std:: ffi:: { CStr , CString } ;
18
18
use std:: fmt;
19
19
20
- use foreign_types_shared :: { ForeignType , ForeignTypeRef } ;
20
+ use foreign_types :: { ForeignType , ForeignTypeRef } ;
21
21
use openssl:: {
22
22
asn1:: { Asn1Object , Asn1ObjectRef } ,
23
+ error:: ErrorStack ,
23
24
nid:: Nid ,
24
25
ssl:: { self , SslContextBuilder , SslOptions , SslVerifyMode } ,
26
+ x509:: { X509Ref , X509 } ,
25
27
} ;
26
28
27
29
mod sys {
30
+ #![ allow( non_camel_case_types, unused) ]
28
31
use libc:: { c_char, c_double, c_int, c_void} ;
29
32
pub use openssl_sys:: * ;
30
33
extern "C" {
@@ -40,7 +43,54 @@ mod sys {
40
43
pub fn SSL_CTX_set_post_handshake_auth ( ctx : * mut SSL_CTX , val : c_int ) ;
41
44
pub fn RAND_add ( buf : * const c_void , num : c_int , randomness : c_double ) ;
42
45
pub fn RAND_pseudo_bytes ( buf : * const u8 , num : c_int ) -> c_int ;
46
+ pub fn X509_STORE_get0_objects ( ctx : * mut X509_STORE ) -> * mut stack_st_X509_OBJECT ;
47
+ pub fn X509_OBJECT_free ( a : * mut X509_OBJECT ) ;
43
48
}
49
+
50
+ pub enum stack_st_X509_OBJECT { }
51
+
52
+ pub type X509_LOOKUP_TYPE = c_int ;
53
+ pub const X509_LU_NONE : X509_LOOKUP_TYPE = 0 ;
54
+ pub const X509_LU_X509 : X509_LOOKUP_TYPE = 1 ;
55
+ pub const X509_LU_CRL : X509_LOOKUP_TYPE = 2 ;
56
+
57
+ #[ repr( C ) ]
58
+ pub struct X509_OBJECT {
59
+ pub r#type : X509_LOOKUP_TYPE ,
60
+ pub data : X509_OBJECT_data ,
61
+ }
62
+ #[ repr( C ) ]
63
+ pub union X509_OBJECT_data {
64
+ pub ptr : * mut c_char ,
65
+ pub x509 : * mut X509 ,
66
+ pub crl : * mut X509_CRL ,
67
+ pub pkey : * mut EVP_PKEY ,
68
+ }
69
+ }
70
+
71
+ // TODO: upstream this into rust-openssl
72
+ foreign_types:: foreign_type! {
73
+ type CType = sys:: X509_OBJECT ;
74
+ fn drop = sys:: X509_OBJECT_free ;
75
+
76
+ pub struct X509Object ;
77
+ pub struct X509ObjectRef ;
78
+ }
79
+
80
+ impl X509ObjectRef {
81
+ fn x509 ( & self ) -> Option < & X509Ref > {
82
+ let ptr = self . as_ptr ( ) ;
83
+ let ty = unsafe { ( * ptr) . r#type } ;
84
+ if ty == sys:: X509_LU_X509 {
85
+ Some ( unsafe { X509Ref :: from_ptr ( ( * ptr) . data . x509 ) } )
86
+ } else {
87
+ None
88
+ }
89
+ }
90
+ }
91
+
92
+ impl openssl:: stack:: Stackable for X509Object {
93
+ type StackType = sys:: stack_st_X509_OBJECT ;
44
94
}
45
95
46
96
#[ derive( num_enum:: IntoPrimitive , num_enum:: TryFromPrimitive , PartialEq ) ]
@@ -224,7 +274,7 @@ fn ssl_rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec<u8>, bool
224
274
let ret = unsafe { sys:: RAND_pseudo_bytes ( buf. as_mut_ptr ( ) , n) } ;
225
275
match ret {
226
276
0 | 1 => Ok ( ( buf, ret == 1 ) ) ,
227
- _ => Err ( convert_openssl_error ( vm, openssl :: error :: ErrorStack :: get ( ) ) ) ,
277
+ _ => Err ( convert_openssl_error ( vm, ErrorStack :: get ( ) ) ) ,
228
278
}
229
279
}
230
280
@@ -251,11 +301,11 @@ impl PySslContext {
251
301
fn builder ( & self ) -> RefMut < SslContextBuilder > {
252
302
self . ctx . borrow_mut ( )
253
303
}
254
- // fn ctx(&self) -> Ref<SslContextRef> {
255
- // Ref::map(self.ctx.borrow(), |ctx| unsafe {
256
- // SslContextRef::from_ptr(ctx.as_ptr() )
257
- // })
258
- // }
304
+ fn ctx ( & self ) -> Ref < ssl :: SslContextRef > {
305
+ Ref :: map ( self . ctx . borrow ( ) , |ctx| unsafe {
306
+ & * * ( ctx as * const SslContextBuilder as * const ssl :: SslContext )
307
+ } )
308
+ }
259
309
fn ptr ( & self ) -> * mut sys:: SSL_CTX {
260
310
self . ctx . borrow ( ) . as_ptr ( )
261
311
}
@@ -374,8 +424,23 @@ impl PySslContext {
374
424
) ;
375
425
}
376
426
377
- if let Some ( _cadata) = args. cadata {
378
- todo ! ( )
427
+ if let Some ( cadata) = args. cadata {
428
+ let cert = match cadata {
429
+ Either :: A ( s) => {
430
+ if !s. as_str ( ) . is_ascii ( ) {
431
+ return Err ( vm. new_type_error ( "Must be an ascii string" . to_owned ( ) ) ) ;
432
+ }
433
+ X509 :: from_pem ( s. as_str ( ) . as_bytes ( ) )
434
+ }
435
+ Either :: B ( b) => b. with_ref ( X509 :: from_der) ,
436
+ } ;
437
+ let cert = cert. map_err ( |e| convert_openssl_error ( vm, e) ) ?;
438
+ let ctx = self . ctx ( ) ;
439
+ let store = ctx. cert_store ( ) ;
440
+ let ret = unsafe { sys:: X509_STORE_add_cert ( store. as_ptr ( ) , cert. as_ptr ( ) ) } ;
441
+ if ret <= 0 {
442
+ return Err ( convert_openssl_error ( vm, ErrorStack :: get ( ) ) ) ;
443
+ }
379
444
}
380
445
381
446
if args. cafile . is_some ( ) || args. capath . is_some ( ) {
@@ -395,7 +460,7 @@ impl PySslContext {
395
460
let err = if errno != 0 {
396
461
super :: os:: errno_err ( vm)
397
462
} else {
398
- convert_openssl_error ( vm, openssl :: error :: ErrorStack :: get ( ) )
463
+ convert_openssl_error ( vm, ErrorStack :: get ( ) )
399
464
} ;
400
465
return Err ( err) ;
401
466
}
@@ -404,6 +469,32 @@ impl PySslContext {
404
469
Ok ( ( ) )
405
470
}
406
471
472
+ #[ pymethod]
473
+ fn get_ca_certs ( & self , binary_form : OptionalArg < bool > , vm : & VirtualMachine ) -> PyResult {
474
+ use openssl:: stack:: StackRef ;
475
+ let binary_form = binary_form. unwrap_or ( false ) ;
476
+ let certs = unsafe {
477
+ let stack = sys:: X509_STORE_get0_objects ( self . ctx ( ) . cert_store ( ) . as_ptr ( ) ) ;
478
+ assert ! ( !stack. is_null( ) ) ;
479
+ StackRef :: < X509Object > :: from_ptr ( stack)
480
+ } ;
481
+ let certs = certs
482
+ . iter ( )
483
+ . filter_map ( |cert| {
484
+ let cert = cert. x509 ( ) ?;
485
+ let obj = if binary_form {
486
+ cert. to_der ( )
487
+ . map ( |b| vm. ctx . new_bytes ( b) )
488
+ . map_err ( |e| convert_openssl_error ( vm, e) )
489
+ } else {
490
+ todo ! ( )
491
+ } ;
492
+ Some ( obj)
493
+ } )
494
+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
495
+ Ok ( vm. ctx . new_list ( certs) )
496
+ }
497
+
407
498
#[ pymethod]
408
499
fn _wrap_socket (
409
500
zelf : PyRef < Self > ,
@@ -475,7 +566,7 @@ struct LoadVerifyLocationsArgs {
475
566
#[ pyarg( positional_or_keyword, default = "None" ) ]
476
567
capath : Option < CString > ,
477
568
#[ pyarg( positional_or_keyword, default = "None" ) ]
478
- cadata : Option < PyStringRef > ,
569
+ cadata : Option < Either < PyStringRef , PyBytesLike > > ,
479
570
}
480
571
481
572
#[ pyclass( name = "_SSLSocket" ) ]
@@ -591,19 +682,18 @@ fn ssl_error(vm: &VirtualMachine) -> PyClassRef {
591
682
vm. class ( "_ssl" , "SSLError" )
592
683
}
593
684
594
- fn convert_openssl_error (
595
- vm : & VirtualMachine ,
596
- err : openssl:: error:: ErrorStack ,
597
- ) -> PyBaseExceptionRef {
685
+ fn convert_openssl_error ( vm : & VirtualMachine , err : ErrorStack ) -> PyBaseExceptionRef {
598
686
let cls = ssl_error ( vm) ;
599
687
match err. errors ( ) . first ( ) {
600
688
Some ( e) => {
601
- let no = "unknown" ;
602
- let msg = format ! (
603
- "openssl error code {}, from library {}, in function {}, on line {}, with reason {}, and extra data {}" ,
604
- e. code( ) , e. library( ) . unwrap_or( no) , e. function( ) . unwrap_or( no) , e. line( ) ,
605
- e. reason( ) . unwrap_or( no) , e. data( ) . unwrap_or( "none" ) ,
606
- ) ;
689
+ // let no = "unknown";
690
+ // let msg = format!(
691
+ // "openssl error code {}, from library {}, in function {}, on line {}, with reason {}, and extra data {}",
692
+ // e.code(), e.library().unwrap_or(no), e.function().unwrap_or(no), e.line(),
693
+ // e.reason().unwrap_or(no), e.data().unwrap_or("none"),
694
+ // );
695
+ // TODO: map the error codes to code names, e.g. "CERTIFICATE_VERIFY_FAILED", just requires a big hashmap/dict
696
+ let msg = e. to_string ( ) ;
607
697
vm. new_exception_msg ( cls, msg)
608
698
}
609
699
None => vm. new_exception_empty ( cls) ,
0 commit comments