Skip to content

Commit 837bf59

Browse files
authored
Merge pull request #3147 from youknowone/ssl-module
Refactor ssl module
2 parents ec5b6b4 + b195b10 commit 837bf59

File tree

2 files changed

+223
-198
lines changed

2 files changed

+223
-198
lines changed

vm/src/stdlib/socket.rs

Lines changed: 112 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::{
55
exceptions::{IntoPyException, PyBaseExceptionRef},
66
function::{FuncArgs, OptionalArg, OptionalOption},
77
utils::{Either, ToCString},
8-
IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromBorrowedObject,
8+
IntoPyObject, PyClassImpl, PyObjectRef, PyResult, PyValue, TryFromBorrowedObject,
99
TryFromObject, TypeProtocol, VirtualMachine,
1010
};
1111
use crossbeam_utils::atomic::AtomicCell;
@@ -18,7 +18,10 @@ use std::convert::TryFrom;
1818
use std::mem::MaybeUninit;
1919
use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, ToSocketAddrs};
2020
use std::time::{Duration, Instant};
21-
use std::{ffi, io};
21+
use std::{
22+
ffi,
23+
io::{self, Read, Write},
24+
};
2225

2326
#[cfg(unix)]
2427
type RawSocket = std::os::unix::io::RawFd;
@@ -158,13 +161,26 @@ impl Default for PySocket {
158161
}
159162
}
160163

161-
pub type PySocketRef = PyRef<PySocket>;
162-
163164
#[cfg(windows)]
164165
const CLOSED_ERR: i32 = c::WSAENOTSOCK;
165166
#[cfg(unix)]
166167
const CLOSED_ERR: i32 = c::EBADF;
167-
#[pyimpl(flags(BASETYPE))]
168+
169+
impl Read for &PySocket {
170+
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
171+
(&mut &*self.sock_io()?).read(buf)
172+
}
173+
}
174+
impl Write for &PySocket {
175+
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
176+
(&mut &*self.sock_io()?).write(buf)
177+
}
178+
179+
fn flush(&mut self) -> std::io::Result<()> {
180+
(&mut &*self.sock_io()?).flush()
181+
}
182+
}
183+
168184
impl PySocket {
169185
pub fn sock_opt(&self) -> Option<PyMappedRwLockReadGuard<'_, Socket>> {
170186
PyRwLockReadGuard::try_map(self.sock.read(), |sock| sock.get()).ok()
@@ -179,94 +195,6 @@ impl PySocket {
179195
self.sock_io().map_err(|e| e.into_pyexception(vm))
180196
}
181197

182-
#[pyslot]
183-
fn tp_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
184-
Self::default().into_pyresult_with_type(vm, cls)
185-
}
186-
187-
#[pymethod(magic)]
188-
fn init(
189-
&self,
190-
family: OptionalArg<i32>,
191-
socket_kind: OptionalArg<i32>,
192-
proto: OptionalArg<i32>,
193-
fileno: OptionalOption<PyObjectRef>,
194-
vm: &VirtualMachine,
195-
) -> PyResult<()> {
196-
let mut family = family.unwrap_or(-1);
197-
let mut socket_kind = socket_kind.unwrap_or(-1);
198-
let mut proto = proto.unwrap_or(-1);
199-
let fileno = fileno
200-
.flatten()
201-
.map(|obj| get_raw_sock(obj, vm))
202-
.transpose()?;
203-
let sock;
204-
if let Some(fileno) = fileno {
205-
sock = sock_from_raw(fileno, vm)?;
206-
match sock.local_addr() {
207-
Ok(addr) if family == -1 => family = addr.family() as i32,
208-
Err(e)
209-
if family == -1
210-
|| matches!(
211-
e.raw_os_error(),
212-
Some(errcode!(ENOTSOCK)) | Some(errcode!(EBADF))
213-
) =>
214-
{
215-
std::mem::forget(sock);
216-
return Err(e.into_pyexception(vm));
217-
}
218-
_ => {}
219-
}
220-
if socket_kind == -1 {
221-
// TODO: when socket2 cuts a new release, type will be available on all os
222-
// socket_kind = sock.r#type().map_err(|e| e.into_pyexception(vm))?.into();
223-
let res = unsafe {
224-
c::getsockopt(
225-
sock_fileno(&sock) as _,
226-
c::SOL_SOCKET,
227-
c::SO_TYPE,
228-
&mut socket_kind as *mut libc::c_int as *mut _,
229-
&mut (std::mem::size_of::<i32>() as _),
230-
)
231-
};
232-
if res < 0 {
233-
return Err(super::os::errno_err(vm));
234-
}
235-
}
236-
cfg_if::cfg_if! {
237-
if #[cfg(any(
238-
target_os = "android",
239-
target_os = "freebsd",
240-
target_os = "fuchsia",
241-
target_os = "linux",
242-
))] {
243-
if proto == -1 {
244-
proto = sock.protocol().map_err(|e| e.into_pyexception(vm))?.map_or(0, Into::into);
245-
}
246-
} else {
247-
proto = 0;
248-
}
249-
}
250-
} else {
251-
if family == -1 {
252-
family = c::AF_INET as i32
253-
}
254-
if socket_kind == -1 {
255-
socket_kind = c::SOCK_STREAM
256-
}
257-
if proto == -1 {
258-
proto = 0
259-
}
260-
sock = Socket::new(
261-
Domain::from(family),
262-
SocketType::from(socket_kind),
263-
Some(Protocol::from(proto)),
264-
)
265-
.map_err(|err| err.into_pyexception(vm))?;
266-
};
267-
self.init_inner(family, socket_kind, proto, sock, vm)
268-
}
269-
270198
fn init_inner(
271199
&self,
272200
family: i32,
@@ -523,6 +451,97 @@ impl PySocket {
523451
Err(err.into())
524452
}
525453
}
454+
}
455+
456+
#[pyimpl(flags(BASETYPE))]
457+
impl PySocket {
458+
#[pyslot]
459+
fn tp_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
460+
Self::default().into_pyresult_with_type(vm, cls)
461+
}
462+
463+
#[pymethod(magic)]
464+
fn init(
465+
&self,
466+
family: OptionalArg<i32>,
467+
socket_kind: OptionalArg<i32>,
468+
proto: OptionalArg<i32>,
469+
fileno: OptionalOption<PyObjectRef>,
470+
vm: &VirtualMachine,
471+
) -> PyResult<()> {
472+
let mut family = family.unwrap_or(-1);
473+
let mut socket_kind = socket_kind.unwrap_or(-1);
474+
let mut proto = proto.unwrap_or(-1);
475+
let fileno = fileno
476+
.flatten()
477+
.map(|obj| get_raw_sock(obj, vm))
478+
.transpose()?;
479+
let sock;
480+
if let Some(fileno) = fileno {
481+
sock = sock_from_raw(fileno, vm)?;
482+
match sock.local_addr() {
483+
Ok(addr) if family == -1 => family = addr.family() as i32,
484+
Err(e)
485+
if family == -1
486+
|| matches!(
487+
e.raw_os_error(),
488+
Some(errcode!(ENOTSOCK)) | Some(errcode!(EBADF))
489+
) =>
490+
{
491+
std::mem::forget(sock);
492+
return Err(e.into_pyexception(vm));
493+
}
494+
_ => {}
495+
}
496+
if socket_kind == -1 {
497+
// TODO: when socket2 cuts a new release, type will be available on all os
498+
// socket_kind = sock.r#type().map_err(|e| e.into_pyexception(vm))?.into();
499+
let res = unsafe {
500+
c::getsockopt(
501+
sock_fileno(&sock) as _,
502+
c::SOL_SOCKET,
503+
c::SO_TYPE,
504+
&mut socket_kind as *mut libc::c_int as *mut _,
505+
&mut (std::mem::size_of::<i32>() as _),
506+
)
507+
};
508+
if res < 0 {
509+
return Err(super::os::errno_err(vm));
510+
}
511+
}
512+
cfg_if::cfg_if! {
513+
if #[cfg(any(
514+
target_os = "android",
515+
target_os = "freebsd",
516+
target_os = "fuchsia",
517+
target_os = "linux",
518+
))] {
519+
if proto == -1 {
520+
proto = sock.protocol().map_err(|e| e.into_pyexception(vm))?.map_or(0, Into::into);
521+
}
522+
} else {
523+
proto = 0;
524+
}
525+
}
526+
} else {
527+
if family == -1 {
528+
family = c::AF_INET as i32
529+
}
530+
if socket_kind == -1 {
531+
socket_kind = c::SOCK_STREAM
532+
}
533+
if proto == -1 {
534+
proto = 0
535+
}
536+
sock = Socket::new(
537+
Domain::from(family),
538+
SocketType::from(socket_kind),
539+
Some(Protocol::from(proto)),
540+
)
541+
.map_err(|err| err.into_pyexception(vm))?;
542+
};
543+
self.init_inner(family, socket_kind, proto, sock, vm)
544+
}
526545

527546
#[pymethod]
528547
fn connect(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
@@ -919,20 +938,6 @@ impl PySocket {
919938
}
920939
}
921940

922-
impl io::Read for PySocketRef {
923-
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
924-
<&Socket as io::Read>::read(&mut &*self.sock_io()?, buf)
925-
}
926-
}
927-
impl io::Write for PySocketRef {
928-
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
929-
<&Socket as io::Write>::write(&mut &*self.sock_io()?, buf)
930-
}
931-
fn flush(&mut self) -> io::Result<()> {
932-
<&Socket as io::Write>::flush(&mut &*self.sock_io()?)
933-
}
934-
}
935-
936941
struct Address {
937942
host: PyStrRef,
938943
port: u16,

0 commit comments

Comments
 (0)