Skip to content

Commit 80dd6ff

Browse files
committed
Add _SSLSocket.peer_certificate
1 parent 64ed2b9 commit 80dd6ff

File tree

1 file changed

+100
-28
lines changed

1 file changed

+100
-28
lines changed

vm/src/stdlib/ssl.rs

Lines changed: 100 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@ use crate::exceptions::PyBaseExceptionRef;
33
use crate::function::OptionalArg;
44
use crate::obj::objbytearray::PyByteArrayRef;
55
use crate::obj::objbyteinner::PyBytesLike;
6-
use crate::obj::objbytes::PyBytesRef;
7-
use crate::obj::objstr::{PyString, PyStringRef};
6+
use crate::obj::objstr::PyStringRef;
87
use crate::obj::{objtype::PyClassRef, objweakref::PyWeak};
98
use crate::pyobject::{
10-
Either, IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
9+
Either, IntoPyObject, ItemProtocol, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue,
1110
};
1211
use crate::types::create_type;
1312
use crate::VirtualMachine;
@@ -23,12 +22,12 @@ use openssl::{
2322
error::ErrorStack,
2423
nid::Nid,
2524
ssl::{self, SslContextBuilder, SslOptions, SslVerifyMode},
26-
x509::{X509Ref, X509},
25+
x509::{self, X509Ref, X509},
2726
};
2827

2928
mod sys {
3029
#![allow(non_camel_case_types, unused)]
31-
use libc::{c_char, c_double, c_int, c_void};
30+
use libc::{c_char, c_double, c_int, c_long, c_void};
3231
pub use openssl_sys::*;
3332
extern "C" {
3433
pub fn OBJ_txt2obj(s: *const c_char, no_name: c_int) -> *mut ASN1_OBJECT;
@@ -45,6 +44,8 @@ mod sys {
4544
pub fn RAND_pseudo_bytes(buf: *const u8, num: c_int) -> c_int;
4645
pub fn X509_STORE_get0_objects(ctx: *mut X509_STORE) -> *mut stack_st_X509_OBJECT;
4746
pub fn X509_OBJECT_free(a: *mut X509_OBJECT);
47+
pub fn SSL_is_init_finished(ssl: *const SSL) -> c_int;
48+
pub fn X509_get_version(x: *const X509) -> c_long;
4849
}
4950

5051
pub enum stack_st_X509_OBJECT {}
@@ -484,14 +485,7 @@ impl PySslContext {
484485
.iter()
485486
.filter_map(|cert| {
486487
let cert = cert.x509()?;
487-
let obj = if binary_form {
488-
cert.to_der()
489-
.map(|b| vm.ctx.new_bytes(b))
490-
.map_err(|e| convert_openssl_error(vm, e))
491-
} else {
492-
todo!()
493-
};
494-
Some(obj)
488+
Some(cert_to_py(vm, cert, binary_form))
495489
})
496490
.collect::<Result<Vec<_>, _>>()?;
497491
Ok(vm.ctx.new_list(certs))
@@ -503,18 +497,6 @@ impl PySslContext {
503497
args: WrapSocketArgs,
504498
vm: &VirtualMachine,
505499
) -> PyResult<PySslSocket> {
506-
let server_hostname = args
507-
.server_hostname
508-
.map(|s| {
509-
vm.encode(
510-
s.into_object(),
511-
Some(PyString::from("ascii").into_ref(vm)),
512-
None,
513-
)
514-
.and_then(|res| PyBytesRef::try_from_object(vm, res))
515-
})
516-
.transpose()?;
517-
518500
let ssl = {
519501
let ptr = zelf.ptr();
520502
let ctx = unsafe { ssl::SslContext::from_ptr(ptr) };
@@ -540,7 +522,7 @@ impl PySslContext {
540522
ctx: zelf,
541523
stream: RefCell::new(Some(stream)),
542524
socket_type,
543-
server_hostname,
525+
server_hostname: args.server_hostname,
544526
owner: RefCell::new(args.owner.as_ref().map(PyWeak::downgrade)),
545527
})
546528
}
@@ -576,7 +558,7 @@ struct PySslSocket {
576558
ctx: PyRef<PySslContext>,
577559
stream: RefCell<Option<ssl::SslStreamBuilder<PySocketRef>>>,
578560
socket_type: SslServerOrClient,
579-
server_hostname: Option<PyBytesRef>,
561+
server_hostname: Option<PyStringRef>,
580562
owner: RefCell<Option<PyWeak>>,
581563
}
582564

@@ -627,10 +609,28 @@ impl PySslSocket {
627609
self.ctx.clone()
628610
}
629611
#[pyproperty]
630-
fn server_hostname(&self) -> Option<PyBytesRef> {
612+
fn server_hostname(&self) -> Option<PyStringRef> {
631613
self.server_hostname.clone()
632614
}
633615

616+
#[pymethod]
617+
fn peer_certificate(
618+
&self,
619+
binary: OptionalArg<bool>,
620+
vm: &VirtualMachine,
621+
) -> PyResult<Option<PyObjectRef>> {
622+
let binary = binary.unwrap_or(false);
623+
let init_finished = unsafe { sys::SSL_is_init_finished(self.stream().ssl().as_ptr()) } != 0;
624+
if !init_finished {
625+
return Err(vm.new_value_error("handshake not done yet".to_owned()));
626+
}
627+
self.stream()
628+
.ssl()
629+
.peer_certificate()
630+
.map(|cert| cert_to_py(vm, &cert, binary))
631+
.transpose()
632+
}
633+
634634
#[pymethod]
635635
fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> {
636636
// Either a stream builder or a mid-handshake stream from WANT_READ or WANT_WRITE
@@ -708,6 +708,78 @@ fn convert_ssl_error(vm: &VirtualMachine, e: ssl::Error) -> PyBaseExceptionRef {
708708
}
709709
}
710710

711+
fn cert_to_py(vm: &VirtualMachine, cert: &X509Ref, binary: bool) -> PyResult {
712+
if binary {
713+
cert.to_der()
714+
.map(|b| vm.ctx.new_bytes(b))
715+
.map_err(|e| convert_openssl_error(vm, e))
716+
} else {
717+
let dict = vm.ctx.new_dict();
718+
719+
let name_to_py = |name: &x509::X509NameRef| {
720+
name.entries()
721+
.map(|entry| {
722+
let txt = match obj2txt(entry.object(), false) {
723+
Some(s) => vm.new_str(s),
724+
None => vm.get_none(),
725+
};
726+
let data = vm.new_str(entry.data().as_utf8()?.to_owned());
727+
Ok(vm.ctx.new_tuple(vec![vm.ctx.new_tuple(vec![txt, data])]))
728+
})
729+
.collect::<Result<_, _>>()
730+
.map(|list| vm.ctx.new_tuple(list))
731+
.map_err(|e| convert_openssl_error(vm, e))
732+
};
733+
734+
dict.set_item("subject", name_to_py(cert.subject_name())?, vm)?;
735+
dict.set_item("issuer", name_to_py(cert.issuer_name())?, vm)?;
736+
737+
let version = unsafe { sys::X509_get_version(cert.as_ptr()) };
738+
dict.set_item("version", vm.new_int(version), vm)?;
739+
740+
let serial_num = cert
741+
.serial_number()
742+
.to_bn()
743+
.and_then(|bn| bn.to_hex_str())
744+
.map_err(|e| convert_openssl_error(vm, e))?;
745+
dict.set_item("serialNumber", vm.new_str(serial_num.to_owned()), vm)?;
746+
747+
dict.set_item("notBefore", vm.new_str(cert.not_before().to_string()), vm)?;
748+
dict.set_item("notAfter", vm.new_str(cert.not_after().to_string()), vm)?;
749+
750+
if let Some(names) = cert.subject_alt_names() {
751+
let san = names
752+
.iter()
753+
.filter_map(|gen_name| {
754+
if let Some(email) = gen_name.email() {
755+
Some(vm.ctx.new_tuple(vec![
756+
vm.new_str("email".to_owned()),
757+
vm.new_str(email.to_owned()),
758+
]))
759+
} else if let Some(dnsname) = gen_name.dnsname() {
760+
Some(vm.ctx.new_tuple(vec![
761+
vm.new_str("DNS".to_owned()),
762+
vm.new_str(dnsname.to_owned()),
763+
]))
764+
} else if let Some(ip) = gen_name.ipaddress() {
765+
Some(vm.ctx.new_tuple(vec![
766+
vm.new_str("IP Address".to_owned()),
767+
vm.new_str(String::from_utf8_lossy(ip).into_owned()),
768+
]))
769+
} else {
770+
// TODO: convert every type of general name:
771+
// https://github.com/python/cpython/blob/3.6/Modules/_ssl.c#L1092-L1231
772+
None
773+
}
774+
})
775+
.collect();
776+
dict.set_item("subjectAltName", vm.ctx.new_tuple(san), vm)?;
777+
};
778+
779+
Ok(dict.into_object())
780+
}
781+
}
782+
711783
fn parse_version_info(mut n: i64) -> (u8, u8, u8, u8, u8) {
712784
let status = (n & 0xF) as u8;
713785
n >>= 4;

0 commit comments

Comments
 (0)