Skip to content

Commit a78ead6

Browse files
committed
Refactor ssl module
1 parent b8d8930 commit a78ead6

File tree

1 file changed

+88
-89
lines changed

1 file changed

+88
-89
lines changed

vm/src/stdlib/ssl.rs

+88-89
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,9 @@ fn nid2obj(nid: Nid) -> Option<Asn1Object> {
105105
unsafe { ptr2obj(sys::OBJ_nid2obj(nid.as_raw())) }
106106
}
107107
fn obj2txt(obj: &Asn1ObjectRef, no_name: bool) -> Option<String> {
108-
unsafe {
109-
let no_name = if no_name { 1 } else { 0 };
110-
let ptr = obj.as_ptr();
108+
let no_name = if no_name { 1 } else { 0 };
109+
let ptr = obj.as_ptr();
110+
let s = unsafe {
111111
let buflen = sys::OBJ_obj2txt(std::ptr::null_mut(), 0, ptr, no_name);
112112
assert!(buflen >= 0);
113113
if buflen == 0 {
@@ -116,10 +116,10 @@ fn obj2txt(obj: &Asn1ObjectRef, no_name: bool) -> Option<String> {
116116
let mut buf = vec![0u8; buflen as usize];
117117
let ret = sys::OBJ_obj2txt(buf.as_mut_ptr() as *mut libc::c_char, buflen, ptr, no_name);
118118
assert!(ret >= 0);
119-
let s = String::from_utf8(buf)
120-
.unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned());
121-
Some(s)
122-
}
119+
String::from_utf8(buf)
120+
.unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned())
121+
};
122+
Some(s)
123123
}
124124

125125
type PyNid = (libc::c_int, String, String, Option<String>);
@@ -232,9 +232,8 @@ fn _ssl_rand_bytes(n: i32, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
232232
return Err(vm.new_value_error("num must be positive".to_owned()));
233233
}
234234
let mut buf = vec![0; n as usize];
235-
openssl::rand::rand_bytes(&mut buf)
236-
.map(|()| buf)
237-
.map_err(|e| convert_openssl_error(vm, e))
235+
openssl::rand::rand_bytes(&mut buf).map_err(|e| convert_openssl_error(vm, e))?;
236+
Ok(buf)
238237
}
239238

240239
fn _ssl_rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec<u8>, bool)> {
@@ -642,67 +641,70 @@ struct LoadCertChainArgs {
642641
password: Option<Either<PyStrRef, ArgCallable>>,
643642
}
644643

645-
struct SocketTimeout {
646-
// Err is true if the socket is blocking
647-
deadline: Result<Instant, bool>,
648-
}
649-
impl SocketTimeout {
650-
fn get(s: &SocketStream) -> Self {
651-
let deadline = s.0.get_timeout().map(|d| Instant::now() + d);
652-
Self { deadline }
653-
}
654-
}
644+
// Err is true if the socket is blocking
645+
type SocketDeadline = Result<Instant, bool>;
646+
655647
enum SelectRet {
656648
Nonblocking,
657649
TimedOut,
658650
IsBlocking,
659651
Closed,
660652
Ok,
661653
}
662-
fn ssl_select(sock: &SocketStream, needs: SslNeeds, timeout: &SocketTimeout) -> SelectRet {
663-
let sock = match sock.0.sock_opt() {
664-
Some(s) => s,
665-
None => return SelectRet::Closed,
666-
};
667-
let timeout = match &timeout.deadline {
668-
Ok(deadline) => match deadline.checked_duration_since(Instant::now()) {
669-
Some(timeout) => timeout,
670-
None => return SelectRet::TimedOut,
671-
},
672-
Err(true) => return SelectRet::IsBlocking,
673-
Err(false) => return SelectRet::Nonblocking,
674-
};
675-
let res = socket::sock_select(
676-
&sock,
677-
match needs {
678-
SslNeeds::Read => socket::SelectKind::Read,
679-
SslNeeds::Write => socket::SelectKind::Write,
680-
},
681-
Some(timeout),
682-
);
683-
match res {
684-
Ok(true) => SelectRet::TimedOut,
685-
_ => SelectRet::Ok,
686-
}
687-
}
654+
688655
#[derive(Clone, Copy)]
689656
enum SslNeeds {
690657
Read,
691658
Write,
692659
}
693660

694-
fn socket_needs(
695-
err: &ssl::Error,
696-
sock: &SocketStream,
697-
timeout: &SocketTimeout,
698-
) -> (Option<SslNeeds>, SelectRet) {
699-
let needs = match err.code() {
700-
ssl::ErrorCode::WANT_READ => Some(SslNeeds::Read),
701-
ssl::ErrorCode::WANT_WRITE => Some(SslNeeds::Write),
702-
_ => None,
703-
};
704-
let state = needs.map_or(SelectRet::Ok, |needs| ssl_select(sock, needs, timeout));
705-
(needs, state)
661+
struct SocketStream(PyRef<PySocket>);
662+
663+
impl SocketStream {
664+
fn timeout_deadline(&self) -> SocketDeadline {
665+
self.0.get_timeout().map(|d| Instant::now() + d)
666+
}
667+
668+
fn select(&self, needs: SslNeeds, deadline: &SocketDeadline) -> SelectRet {
669+
let sock = match self.0.sock_opt() {
670+
Some(s) => s,
671+
None => return SelectRet::Closed,
672+
};
673+
let deadline = match &deadline {
674+
Ok(deadline) => match deadline.checked_duration_since(Instant::now()) {
675+
Some(deadline) => deadline,
676+
None => return SelectRet::TimedOut,
677+
},
678+
Err(true) => return SelectRet::IsBlocking,
679+
Err(false) => return SelectRet::Nonblocking,
680+
};
681+
let res = socket::sock_select(
682+
&sock,
683+
match needs {
684+
SslNeeds::Read => socket::SelectKind::Read,
685+
SslNeeds::Write => socket::SelectKind::Write,
686+
},
687+
Some(deadline),
688+
);
689+
match res {
690+
Ok(true) => SelectRet::TimedOut,
691+
_ => SelectRet::Ok,
692+
}
693+
}
694+
695+
fn socket_needs(
696+
&self,
697+
err: &ssl::Error,
698+
deadline: &SocketDeadline,
699+
) -> (Option<SslNeeds>, SelectRet) {
700+
let needs = match err.code() {
701+
ssl::ErrorCode::WANT_READ => Some(SslNeeds::Read),
702+
ssl::ErrorCode::WANT_WRITE => Some(SslNeeds::Write),
703+
_ => None,
704+
};
705+
let state = needs.map_or(SelectRet::Ok, |needs| self.select(needs, deadline));
706+
(needs, state)
707+
}
706708
}
707709

708710
fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
@@ -788,38 +790,37 @@ impl PySslSocket {
788790
.map(cipher_to_tuple)
789791
}
790792

793+
#[cfg(osslconf = "OPENSSL_NO_COMP")]
791794
#[pymethod]
792795
fn compression(&self) -> Option<&'static str> {
793-
#[cfg(osslconf = "OPENSSL_NO_COMP")]
794-
{
795-
None
796+
None
797+
}
798+
#[cfg(not(osslconf = "OPENSSL_NO_COMP"))]
799+
#[pymethod]
800+
fn compression(&self) -> Option<&'static str> {
801+
let stream = self.stream.read();
802+
let comp_method = unsafe { sys::SSL_get_current_compression(stream.ssl().as_ptr()) };
803+
if comp_method.is_null() {
804+
return None;
796805
}
797-
#[cfg(not(osslconf = "OPENSSL_NO_COMP"))]
798-
{
799-
let stream = self.stream.read();
800-
let comp_method = unsafe { sys::SSL_get_current_compression(stream.ssl().as_ptr()) };
801-
if comp_method.is_null() {
802-
return None;
803-
}
804-
let typ = unsafe { sys::COMP_get_type(comp_method) };
805-
let nid = Nid::from_raw(typ);
806-
if nid == Nid::UNDEF {
807-
return None;
808-
}
809-
nid.short_name().ok()
806+
let typ = unsafe { sys::COMP_get_type(comp_method) };
807+
let nid = Nid::from_raw(typ);
808+
if nid == Nid::UNDEF {
809+
return None;
810810
}
811+
nid.short_name().ok()
811812
}
812813

813814
#[pymethod]
814815
fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> {
815816
let mut stream = self.stream.write();
816-
let timeout = SocketTimeout::get(stream.get_ref());
817+
let timeout = stream.get_ref().timeout_deadline();
817818
loop {
818819
let err = match stream.do_handshake() {
819820
Ok(()) => return Ok(()),
820821
Err(e) => e,
821822
};
822-
let (needs, state) = socket_needs(&err, &stream.get_ref(), &timeout);
823+
let (needs, state) = stream.get_ref().socket_needs(&err, &timeout);
823824
match state {
824825
SelectRet::TimedOut => {
825826
return Err(socket::timeout_error_msg(
@@ -844,8 +845,8 @@ impl PySslSocket {
844845
let mut stream = self.stream.write();
845846
let data = data.borrow_buf();
846847
let data = &*data;
847-
let timeout = SocketTimeout::get(stream.get_ref());
848-
let state = ssl_select(stream.get_ref(), SslNeeds::Write, &timeout);
848+
let timeout = stream.get_ref().timeout_deadline();
849+
let state = stream.get_ref().select(SslNeeds::Write, &timeout);
849850
match state {
850851
SelectRet::TimedOut => {
851852
return Err(socket::timeout_error_msg(
@@ -861,7 +862,7 @@ impl PySslSocket {
861862
Ok(len) => return Ok(len),
862863
Err(e) => e,
863864
};
864-
let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout);
865+
let (needs, state) = stream.get_ref().socket_needs(&err, &timeout);
865866
match state {
866867
SelectRet::TimedOut => {
867868
return Err(socket::timeout_error_msg(
@@ -902,7 +903,7 @@ impl PySslSocket {
902903
Some(b) => b,
903904
None => buf,
904905
};
905-
let timeout = SocketTimeout::get(stream.get_ref());
906+
let timeout = stream.get_ref().timeout_deadline();
906907
let count = loop {
907908
let err = match stream.ssl_read(buf) {
908909
Ok(count) => break count,
@@ -913,7 +914,7 @@ impl PySslSocket {
913914
{
914915
break 0;
915916
}
916-
let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout);
917+
let (needs, state) = stream.get_ref().socket_needs(&err, &timeout);
917918
match state {
918919
SelectRet::TimedOut => {
919920
return Err(socket::timeout_error_msg(
@@ -996,10 +997,9 @@ fn cipher_to_tuple(cipher: &ssl::SslCipherRef) -> CipherTuple {
996997
}
997998

998999
fn cert_to_py(vm: &VirtualMachine, cert: &X509Ref, binary: bool) -> PyResult {
999-
if binary {
1000-
cert.to_der()
1001-
.map(|b| vm.ctx.new_bytes(b))
1002-
.map_err(|e| convert_openssl_error(vm, e))
1000+
let r = if binary {
1001+
let b = cert.to_der().map_err(|e| convert_openssl_error(vm, e))?;
1002+
vm.ctx.new_bytes(b)
10031003
} else {
10041004
let dict = vm.ctx.new_dict();
10051005

@@ -1073,8 +1073,9 @@ fn cert_to_py(vm: &VirtualMachine, cert: &X509Ref, binary: bool) -> PyResult {
10731073
dict.set_item("subjectAltName", vm.ctx.new_tuple(san), vm)?;
10741074
};
10751075

1076-
Ok(dict.into_object())
1077-
}
1076+
dict.into_object()
1077+
};
1078+
Ok(r)
10781079
}
10791080

10801081
#[allow(non_snake_case)]
@@ -1238,8 +1239,6 @@ fn extend_module_platform_specific(module: &PyObjectRef, vm: &VirtualMachine) {
12381239
#[cfg(not(windows))]
12391240
fn extend_module_platform_specific(_module: &PyObjectRef, _vm: &VirtualMachine) {}
12401241

1241-
struct SocketStream(PyRef<PySocket>);
1242-
12431242
impl std::io::Read for SocketStream {
12441243
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
12451244
<&socket2::Socket as std::io::Read>::read(&mut &*self.0.sock_io()?, buf)

0 commit comments

Comments
 (0)