Skip to content

Commit 4fc45d2

Browse files
committed
Add _SSLSocket.peer_certificate
1 parent 4f64afb commit 4fc45d2

File tree

1 file changed

+101
-29
lines changed

1 file changed

+101
-29
lines changed

vm/src/stdlib/ssl.rs

Lines changed: 101 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@ use crate::exceptions::PyBaseExceptionRef;
33
use crate::function::OptionalArg;
44
use crate::obj::objbytearray::PyByteArrayRef;
55
use crate::obj::objbyteinner::PyBytesLike;
6-
use crate::obj::objbytes::PyBytesRef;
7-
use crate::obj::objstr::{PyString, PyStringRef};
6+
use crate::obj::objstr::PyStringRef;
87
use crate::obj::{objtype::PyClassRef, objweakref::PyWeak};
98
use crate::pyobject::{
10-
Either, IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
9+
Either, IntoPyObject, ItemProtocol, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue,
1110
};
1211
use crate::types::create_type;
1312
use crate::VirtualMachine;
@@ -23,12 +22,12 @@ use openssl::{
2322
error::ErrorStack,
2423
nid::Nid,
2524
ssl::{self, SslContextBuilder, SslOptions, SslVerifyMode},
26-
x509::{X509Ref, X509},
25+
x509::{self, X509Ref, X509},
2726
};
2827

2928
mod sys {
3029
#![allow(non_camel_case_types, unused)]
31-
use libc::{c_char, c_double, c_int, c_void};
30+
use libc::{c_char, c_double, c_int, c_long, c_void};
3231
pub use openssl_sys::*;
3332
extern "C" {
3433
pub fn OBJ_txt2obj(s: *const c_char, no_name: c_int) -> *mut ASN1_OBJECT;
@@ -45,6 +44,8 @@ mod sys {
4544
pub fn RAND_pseudo_bytes(buf: *const u8, num: c_int) -> c_int;
4645
pub fn X509_STORE_get0_objects(ctx: *mut X509_STORE) -> *mut stack_st_X509_OBJECT;
4746
pub fn X509_OBJECT_free(a: *mut X509_OBJECT);
47+
pub fn SSL_is_init_finished(ssl: *const SSL) -> c_int;
48+
pub fn X509_get_version(x: *const X509) -> c_long;
4849
}
4950

5051
pub enum stack_st_X509_OBJECT {}
@@ -317,7 +318,7 @@ impl PySslContext {
317318
let method = match proto {
318319
SslVersion::Ssl2 => todo!(),
319320
SslVersion::Ssl3 => todo!(),
320-
SslVersion::Tls => unsafe { ssl::SslMethod::from_ptr(sys::TLS_method()) },
321+
SslVersion::Tls => ssl::SslMethod::tls(),
321322
SslVersion::Tls1 => todo!(),
322323
// TODO: Tls1_1, Tls1_2 ?
323324
SslVersion::TlsClient => unsafe { ssl::SslMethod::from_ptr(sys::TLS_client_method()) },
@@ -482,14 +483,7 @@ impl PySslContext {
482483
.iter()
483484
.filter_map(|cert| {
484485
let cert = cert.x509()?;
485-
let obj = if binary_form {
486-
cert.to_der()
487-
.map(|b| vm.ctx.new_bytes(b))
488-
.map_err(|e| convert_openssl_error(vm, e))
489-
} else {
490-
todo!()
491-
};
492-
Some(obj)
486+
Some(cert_to_py(vm, cert, binary_form))
493487
})
494488
.collect::<Result<Vec<_>, _>>()?;
495489
Ok(vm.ctx.new_list(certs))
@@ -501,18 +495,6 @@ impl PySslContext {
501495
args: WrapSocketArgs,
502496
vm: &VirtualMachine,
503497
) -> PyResult<PySslSocket> {
504-
let server_hostname = args
505-
.server_hostname
506-
.map(|s| {
507-
vm.encode(
508-
s.into_object(),
509-
Some(PyString::from("ascii").into_ref(vm)),
510-
None,
511-
)
512-
.and_then(|res| PyBytesRef::try_from_object(vm, res))
513-
})
514-
.transpose()?;
515-
516498
let ssl = {
517499
let ptr = zelf.ptr();
518500
let ctx = unsafe { ssl::SslContext::from_ptr(ptr) };
@@ -538,7 +520,7 @@ impl PySslContext {
538520
ctx: zelf,
539521
stream: RefCell::new(Some(stream)),
540522
socket_type,
541-
server_hostname,
523+
server_hostname: args.server_hostname,
542524
owner: RefCell::new(args.owner.as_ref().map(PyWeak::downgrade)),
543525
})
544526
}
@@ -574,7 +556,7 @@ struct PySslSocket {
574556
ctx: PyRef<PySslContext>,
575557
stream: RefCell<Option<ssl::SslStreamBuilder<PySocketRef>>>,
576558
socket_type: SslServerOrClient,
577-
server_hostname: Option<PyBytesRef>,
559+
server_hostname: Option<PyStringRef>,
578560
owner: RefCell<Option<PyWeak>>,
579561
}
580562

@@ -625,10 +607,28 @@ impl PySslSocket {
625607
self.ctx.clone()
626608
}
627609
#[pyproperty]
628-
fn server_hostname(&self) -> Option<PyBytesRef> {
610+
fn server_hostname(&self) -> Option<PyStringRef> {
629611
self.server_hostname.clone()
630612
}
631613

614+
#[pymethod]
615+
fn peer_certificate(
616+
&self,
617+
binary: OptionalArg<bool>,
618+
vm: &VirtualMachine,
619+
) -> PyResult<Option<PyObjectRef>> {
620+
let binary = binary.unwrap_or(false);
621+
let init_finished = unsafe { sys::SSL_is_init_finished(self.stream().ssl().as_ptr()) } != 0;
622+
if !init_finished {
623+
return Err(vm.new_value_error("handshake not done yet".to_owned()));
624+
}
625+
self.stream()
626+
.ssl()
627+
.peer_certificate()
628+
.map(|cert| cert_to_py(vm, &cert, binary))
629+
.transpose()
630+
}
631+
632632
#[pymethod]
633633
fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> {
634634
// Either a stream builder or a mid-handshake stream from WANT_READ or WANT_WRITE
@@ -706,6 +706,78 @@ fn convert_ssl_error(vm: &VirtualMachine, e: ssl::Error) -> PyBaseExceptionRef {
706706
}
707707
}
708708

709+
fn cert_to_py(vm: &VirtualMachine, cert: &X509Ref, binary: bool) -> PyResult {
710+
if binary {
711+
cert.to_der()
712+
.map(|b| vm.ctx.new_bytes(b))
713+
.map_err(|e| convert_openssl_error(vm, e))
714+
} else {
715+
let dict = vm.ctx.new_dict();
716+
717+
let name_to_py = |name: &x509::X509NameRef| {
718+
name.entries()
719+
.map(|entry| {
720+
let txt = match obj2txt(entry.object(), false) {
721+
Some(s) => vm.new_str(s),
722+
None => vm.get_none(),
723+
};
724+
let data = vm.new_str(entry.data().as_utf8()?.to_owned());
725+
Ok(vm.ctx.new_tuple(vec![vm.ctx.new_tuple(vec![txt, data])]))
726+
})
727+
.collect::<Result<_, _>>()
728+
.map(|list| vm.ctx.new_tuple(list))
729+
.map_err(|e| convert_openssl_error(vm, e))
730+
};
731+
732+
dict.set_item("subject", name_to_py(cert.subject_name())?, vm)?;
733+
dict.set_item("issuer", name_to_py(cert.issuer_name())?, vm)?;
734+
735+
let version = unsafe { sys::X509_get_version(cert.as_ptr()) };
736+
dict.set_item("version", vm.new_int(version), vm)?;
737+
738+
let serial_num = cert
739+
.serial_number()
740+
.to_bn()
741+
.and_then(|bn| bn.to_hex_str())
742+
.map_err(|e| convert_openssl_error(vm, e))?;
743+
dict.set_item("serialNumber", vm.new_str(serial_num.to_owned()), vm)?;
744+
745+
dict.set_item("notBefore", vm.new_str(cert.not_before().to_string()), vm)?;
746+
dict.set_item("notAfter", vm.new_str(cert.not_after().to_string()), vm)?;
747+
748+
if let Some(names) = cert.subject_alt_names() {
749+
let san = names
750+
.iter()
751+
.filter_map(|gen_name| {
752+
if let Some(email) = gen_name.email() {
753+
Some(vm.ctx.new_tuple(vec![
754+
vm.new_str("email".to_owned()),
755+
vm.new_str(email.to_owned()),
756+
]))
757+
} else if let Some(dnsname) = gen_name.dnsname() {
758+
Some(vm.ctx.new_tuple(vec![
759+
vm.new_str("DNS".to_owned()),
760+
vm.new_str(dnsname.to_owned()),
761+
]))
762+
} else if let Some(ip) = gen_name.ipaddress() {
763+
Some(vm.ctx.new_tuple(vec![
764+
vm.new_str("IP Address".to_owned()),
765+
vm.new_str(String::from_utf8_lossy(ip).into_owned()),
766+
]))
767+
} else {
768+
// TODO: convert every type of general name:
769+
// https://github.com/python/cpython/blob/3.6/Modules/_ssl.c#L1092-L1231
770+
None
771+
}
772+
})
773+
.collect();
774+
dict.set_item("subjectAltName", vm.ctx.new_tuple(san), vm)?;
775+
};
776+
777+
Ok(dict.into_object())
778+
}
779+
}
780+
709781
fn parse_version_info(mut n: i64) -> (u8, u8, u8, u8, u8) {
710782
let status = (n & 0xF) as u8;
711783
n >>= 4;

0 commit comments

Comments
 (0)