Skip to content

Commit e0ba09d

Browse files
committed
Refactor ssl module
1 parent ec5b6b4 commit e0ba09d

File tree

2 files changed

+212
-194
lines changed

2 files changed

+212
-194
lines changed

vm/src/stdlib/socket.rs

Lines changed: 105 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::{
55
exceptions::{IntoPyException, PyBaseExceptionRef},
66
function::{FuncArgs, OptionalArg, OptionalOption},
77
utils::{Either, ToCString},
8-
IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromBorrowedObject,
8+
IntoPyObject, PyClassImpl, PyObjectRef, PyResult, PyValue, TryFromBorrowedObject,
99
TryFromObject, TypeProtocol, VirtualMachine,
1010
};
1111
use crossbeam_utils::atomic::AtomicCell;
@@ -18,7 +18,10 @@ use std::convert::TryFrom;
1818
use std::mem::MaybeUninit;
1919
use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, ToSocketAddrs};
2020
use std::time::{Duration, Instant};
21-
use std::{ffi, io};
21+
use std::{
22+
ffi,
23+
io::{self, Read, Write},
24+
};
2225

2326
#[cfg(unix)]
2427
type RawSocket = std::os::unix::io::RawFd;
@@ -158,13 +161,11 @@ impl Default for PySocket {
158161
}
159162
}
160163

161-
pub type PySocketRef = PyRef<PySocket>;
162-
163164
#[cfg(windows)]
164165
const CLOSED_ERR: i32 = c::WSAENOTSOCK;
165166
#[cfg(unix)]
166167
const CLOSED_ERR: i32 = c::EBADF;
167-
#[pyimpl(flags(BASETYPE))]
168+
168169
impl PySocket {
169170
pub fn sock_opt(&self) -> Option<PyMappedRwLockReadGuard<'_, Socket>> {
170171
PyRwLockReadGuard::try_map(self.sock.read(), |sock| sock.get()).ok()
@@ -179,92 +180,16 @@ impl PySocket {
179180
self.sock_io().map_err(|e| e.into_pyexception(vm))
180181
}
181182

182-
#[pyslot]
183-
fn tp_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
184-
Self::default().into_pyresult_with_type(vm, cls)
183+
pub fn read(&self, buf: &mut [u8]) -> std::io::Result<usize> {
184+
(&mut &*self.sock_io()?).read(buf)
185185
}
186186

187-
#[pymethod(magic)]
188-
fn init(
189-
&self,
190-
family: OptionalArg<i32>,
191-
socket_kind: OptionalArg<i32>,
192-
proto: OptionalArg<i32>,
193-
fileno: OptionalOption<PyObjectRef>,
194-
vm: &VirtualMachine,
195-
) -> PyResult<()> {
196-
let mut family = family.unwrap_or(-1);
197-
let mut socket_kind = socket_kind.unwrap_or(-1);
198-
let mut proto = proto.unwrap_or(-1);
199-
let fileno = fileno
200-
.flatten()
201-
.map(|obj| get_raw_sock(obj, vm))
202-
.transpose()?;
203-
let sock;
204-
if let Some(fileno) = fileno {
205-
sock = sock_from_raw(fileno, vm)?;
206-
match sock.local_addr() {
207-
Ok(addr) if family == -1 => family = addr.family() as i32,
208-
Err(e)
209-
if family == -1
210-
|| matches!(
211-
e.raw_os_error(),
212-
Some(errcode!(ENOTSOCK)) | Some(errcode!(EBADF))
213-
) =>
214-
{
215-
std::mem::forget(sock);
216-
return Err(e.into_pyexception(vm));
217-
}
218-
_ => {}
219-
}
220-
if socket_kind == -1 {
221-
// TODO: when socket2 cuts a new release, type will be available on all os
222-
// socket_kind = sock.r#type().map_err(|e| e.into_pyexception(vm))?.into();
223-
let res = unsafe {
224-
c::getsockopt(
225-
sock_fileno(&sock) as _,
226-
c::SOL_SOCKET,
227-
c::SO_TYPE,
228-
&mut socket_kind as *mut libc::c_int as *mut _,
229-
&mut (std::mem::size_of::<i32>() as _),
230-
)
231-
};
232-
if res < 0 {
233-
return Err(super::os::errno_err(vm));
234-
}
235-
}
236-
cfg_if::cfg_if! {
237-
if #[cfg(any(
238-
target_os = "android",
239-
target_os = "freebsd",
240-
target_os = "fuchsia",
241-
target_os = "linux",
242-
))] {
243-
if proto == -1 {
244-
proto = sock.protocol().map_err(|e| e.into_pyexception(vm))?.map_or(0, Into::into);
245-
}
246-
} else {
247-
proto = 0;
248-
}
249-
}
250-
} else {
251-
if family == -1 {
252-
family = c::AF_INET as i32
253-
}
254-
if socket_kind == -1 {
255-
socket_kind = c::SOCK_STREAM
256-
}
257-
if proto == -1 {
258-
proto = 0
259-
}
260-
sock = Socket::new(
261-
Domain::from(family),
262-
SocketType::from(socket_kind),
263-
Some(Protocol::from(proto)),
264-
)
265-
.map_err(|err| err.into_pyexception(vm))?;
266-
};
267-
self.init_inner(family, socket_kind, proto, sock, vm)
187+
pub fn write(&self, buf: &[u8]) -> std::io::Result<usize> {
188+
(&mut &*self.sock_io()?).write(buf)
189+
}
190+
191+
pub fn flush(&self) -> std::io::Result<()> {
192+
(&mut &*self.sock_io()?).flush()
268193
}
269194

270195
fn init_inner(
@@ -523,6 +448,97 @@ impl PySocket {
523448
Err(err.into())
524449
}
525450
}
451+
}
452+
453+
#[pyimpl(flags(BASETYPE))]
454+
impl PySocket {
455+
#[pyslot]
456+
fn tp_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
457+
Self::default().into_pyresult_with_type(vm, cls)
458+
}
459+
460+
#[pymethod(magic)]
461+
fn init(
462+
&self,
463+
family: OptionalArg<i32>,
464+
socket_kind: OptionalArg<i32>,
465+
proto: OptionalArg<i32>,
466+
fileno: OptionalOption<PyObjectRef>,
467+
vm: &VirtualMachine,
468+
) -> PyResult<()> {
469+
let mut family = family.unwrap_or(-1);
470+
let mut socket_kind = socket_kind.unwrap_or(-1);
471+
let mut proto = proto.unwrap_or(-1);
472+
let fileno = fileno
473+
.flatten()
474+
.map(|obj| get_raw_sock(obj, vm))
475+
.transpose()?;
476+
let sock;
477+
if let Some(fileno) = fileno {
478+
sock = sock_from_raw(fileno, vm)?;
479+
match sock.local_addr() {
480+
Ok(addr) if family == -1 => family = addr.family() as i32,
481+
Err(e)
482+
if family == -1
483+
|| matches!(
484+
e.raw_os_error(),
485+
Some(errcode!(ENOTSOCK)) | Some(errcode!(EBADF))
486+
) =>
487+
{
488+
std::mem::forget(sock);
489+
return Err(e.into_pyexception(vm));
490+
}
491+
_ => {}
492+
}
493+
if socket_kind == -1 {
494+
// TODO: when socket2 cuts a new release, type will be available on all os
495+
// socket_kind = sock.r#type().map_err(|e| e.into_pyexception(vm))?.into();
496+
let res = unsafe {
497+
c::getsockopt(
498+
sock_fileno(&sock) as _,
499+
c::SOL_SOCKET,
500+
c::SO_TYPE,
501+
&mut socket_kind as *mut libc::c_int as *mut _,
502+
&mut (std::mem::size_of::<i32>() as _),
503+
)
504+
};
505+
if res < 0 {
506+
return Err(super::os::errno_err(vm));
507+
}
508+
}
509+
cfg_if::cfg_if! {
510+
if #[cfg(any(
511+
target_os = "android",
512+
target_os = "freebsd",
513+
target_os = "fuchsia",
514+
target_os = "linux",
515+
))] {
516+
if proto == -1 {
517+
proto = sock.protocol().map_err(|e| e.into_pyexception(vm))?.map_or(0, Into::into);
518+
}
519+
} else {
520+
proto = 0;
521+
}
522+
}
523+
} else {
524+
if family == -1 {
525+
family = c::AF_INET as i32
526+
}
527+
if socket_kind == -1 {
528+
socket_kind = c::SOCK_STREAM
529+
}
530+
if proto == -1 {
531+
proto = 0
532+
}
533+
sock = Socket::new(
534+
Domain::from(family),
535+
SocketType::from(socket_kind),
536+
Some(Protocol::from(proto)),
537+
)
538+
.map_err(|err| err.into_pyexception(vm))?;
539+
};
540+
self.init_inner(family, socket_kind, proto, sock, vm)
541+
}
526542

527543
#[pymethod]
528544
fn connect(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
@@ -919,20 +935,6 @@ impl PySocket {
919935
}
920936
}
921937

922-
impl io::Read for PySocketRef {
923-
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
924-
<&Socket as io::Read>::read(&mut &*self.sock_io()?, buf)
925-
}
926-
}
927-
impl io::Write for PySocketRef {
928-
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
929-
<&Socket as io::Write>::write(&mut &*self.sock_io()?, buf)
930-
}
931-
fn flush(&mut self) -> io::Result<()> {
932-
<&Socket as io::Write>::flush(&mut &*self.sock_io()?)
933-
}
934-
}
935-
936938
struct Address {
937939
host: PyStrRef,
938940
port: u16,

0 commit comments

Comments
 (0)