diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index 9d7a703e1f..edab574b38 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -3,13 +3,11 @@ use std::io; use std::io::Read; use std::io::Write; use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket}; -use std::ops::Deref; -use crate::function::PyFuncArgs; -use crate::obj::objbytes; -use crate::obj::objint; -use crate::obj::objsequence::get_elements; -use crate::obj::objstr; +use crate::obj::objbytes::PyBytesRef; +use crate::obj::objint::PyIntRef; +use crate::obj::objstr::PyStringRef; +use crate::obj::objtuple::PyTupleRef; use crate::pyobject::{PyObjectRef, PyRef, PyResult, PyValue, TryFromObject}; use crate::vm::VirtualMachine; @@ -161,283 +159,209 @@ impl Socket { } } -fn get_socket<'a>(obj: &'a PyObjectRef) -> impl Deref + 'a { - obj.payload::().unwrap() -} - type SocketRef = PyRef; -fn socket_new( - cls: PyClassRef, - family: AddressFamily, - kind: SocketKind, - vm: &VirtualMachine, -) -> PyResult { - Socket::new(family, kind).into_ref_with_type(vm, cls) -} +impl SocketRef { + fn new( + cls: PyClassRef, + family: AddressFamily, + kind: SocketKind, + vm: &VirtualMachine, + ) -> PyResult { + Socket::new(family, kind).into_ref_with_type(vm, cls) + } -fn socket_connect(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))] - ); - - let address_string = get_address_string(vm, address)?; - - let socket = get_socket(zelf); - - match socket.socket_kind { - SocketKind::Stream => match TcpStream::connect(address_string) { - Ok(stream) => { - socket - .con - .borrow_mut() - .replace(Connection::TcpStream(stream)); - Ok(vm.get_none()) - } - Err(s) => Err(vm.new_os_error(s.to_string())), - }, - SocketKind::Dgram => { - if let Some(Connection::UdpSocket(con)) = socket.con.borrow().as_ref() { - match con.connect(address_string) { - Ok(_) => Ok(vm.get_none()), - Err(s) => Err(vm.new_os_error(s.to_string())), + fn connect(self, address: Address, vm: &VirtualMachine) -> PyResult<()> { + let address_string = address.get_address_string(); + + match self.socket_kind { + SocketKind::Stream => match TcpStream::connect(address_string) { + Ok(stream) => { + self.con.borrow_mut().replace(Connection::TcpStream(stream)); + Ok(()) + } + Err(s) => Err(vm.new_os_error(s.to_string())), + }, + SocketKind::Dgram => { + if let Some(Connection::UdpSocket(con)) = self.con.borrow().as_ref() { + match con.connect(address_string) { + Ok(_) => Ok(()), + Err(s) => Err(vm.new_os_error(s.to_string())), + } + } else { + Err(vm.new_type_error("".to_string())) } - } else { - Err(vm.new_type_error("".to_string())) } } } -} -fn socket_bind(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))] - ); - - let address_string = get_address_string(vm, address)?; - - let socket = get_socket(zelf); - - match socket.socket_kind { - SocketKind::Stream => match TcpListener::bind(address_string) { - Ok(stream) => { - socket - .con - .borrow_mut() - .replace(Connection::TcpListener(stream)); - Ok(vm.get_none()) - } - Err(s) => Err(vm.new_os_error(s.to_string())), - }, - SocketKind::Dgram => match UdpSocket::bind(address_string) { - Ok(dgram) => { - socket - .con - .borrow_mut() - .replace(Connection::UdpSocket(dgram)); - Ok(vm.get_none()) - } - Err(s) => Err(vm.new_os_error(s.to_string())), - }, - } -} + fn bind(self, address: Address, vm: &VirtualMachine) -> PyResult<()> { + let address_string = address.get_address_string(); -fn get_address_string(vm: &VirtualMachine, address: &PyObjectRef) -> Result { - let args = PyFuncArgs { - args: get_elements(address).to_vec(), - kwargs: vec![], - }; - arg_check!( - vm, - args, - required = [ - (host, Some(vm.ctx.str_type())), - (port, Some(vm.ctx.int_type())) - ] - ); - - Ok(format!( - "{}:{}", - objstr::get_value(host), - objint::get_value(port).to_string() - )) -} + match self.socket_kind { + SocketKind::Stream => match TcpListener::bind(address_string) { + Ok(stream) => { + self.con + .borrow_mut() + .replace(Connection::TcpListener(stream)); + Ok(()) + } + Err(s) => Err(vm.new_os_error(s.to_string())), + }, + SocketKind::Dgram => match UdpSocket::bind(address_string) { + Ok(dgram) => { + self.con.borrow_mut().replace(Connection::UdpSocket(dgram)); + Ok(()) + } + Err(s) => Err(vm.new_os_error(s.to_string())), + }, + } + } -fn socket_listen(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(_zelf, None), (_num, Some(vm.ctx.int_type()))] - ); - Ok(vm.get_none()) -} + fn listen(self, _num: PyIntRef, _vm: &VirtualMachine) -> () {} -fn socket_accept(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(zelf, None)]); + fn accept(self, vm: &VirtualMachine) -> PyResult { + let ret = match self.con.borrow_mut().as_mut() { + Some(v) => v.accept(), + None => return Err(vm.new_type_error("".to_string())), + }; - let socket = get_socket(zelf); + let (tcp_stream, addr) = match ret { + Ok((socket, addr)) => (socket, addr), + Err(s) => return Err(vm.new_os_error(s.to_string())), + }; - let ret = match socket.con.borrow_mut().as_mut() { - Some(v) => v.accept(), - None => return Err(vm.new_type_error("".to_string())), - }; + let socket = Socket { + address_family: self.address_family, + socket_kind: self.socket_kind, + con: RefCell::new(Some(Connection::TcpStream(tcp_stream))), + } + .into_ref(vm); - let (tcp_stream, addr) = match ret { - Ok((socket, addr)) => (socket, addr), - Err(s) => return Err(vm.new_os_error(s.to_string())), - }; + let addr_tuple = get_addr_tuple(vm, addr)?; - let socket = Socket { - address_family: socket.address_family, - socket_kind: socket.socket_kind, - con: RefCell::new(Some(Connection::TcpStream(tcp_stream))), + Ok(vm.ctx.new_tuple(vec![socket.into_object(), addr_tuple])) } - .into_ref(vm); - let addr_tuple = get_addr_tuple(vm, addr)?; + fn recv(self, bufsize: PyIntRef, vm: &VirtualMachine) -> PyResult { + let mut buffer = vec![0u8; bufsize.as_bigint().to_usize().unwrap()]; + match self.con.borrow_mut().as_mut() { + Some(v) => match v.read_exact(&mut buffer) { + Ok(_) => (), + Err(s) => return Err(vm.new_os_error(s.to_string())), + }, + None => return Err(vm.new_type_error("".to_string())), + }; + Ok(vm.ctx.new_bytes(buffer)) + } - Ok(vm.ctx.new_tuple(vec![socket.into_object(), addr_tuple])) -} + fn recvfrom(self, bufsize: PyIntRef, vm: &VirtualMachine) -> PyResult { + let mut buffer = vec![0u8; bufsize.as_bigint().to_usize().unwrap()]; + let ret = match self.con.borrow().as_ref() { + Some(v) => v.recv_from(&mut buffer), + None => return Err(vm.new_type_error("".to_string())), + }; -fn socket_recv(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(zelf, None), (bufsize, Some(vm.ctx.int_type()))] - ); - let socket = get_socket(zelf); - - let mut buffer = vec![0u8; objint::get_value(bufsize).to_usize().unwrap()]; - match socket.con.borrow_mut().as_mut() { - Some(v) => match v.read_exact(&mut buffer) { - Ok(_) => (), + let addr = match ret { + Ok((_size, addr)) => addr, Err(s) => return Err(vm.new_os_error(s.to_string())), - }, - None => return Err(vm.new_type_error("".to_string())), - }; - Ok(vm.ctx.new_bytes(buffer)) -} - -fn socket_recvfrom(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(zelf, None), (bufsize, Some(vm.ctx.int_type()))] - ); - - let socket = get_socket(zelf); - - let mut buffer = vec![0u8; objint::get_value(bufsize).to_usize().unwrap()]; - let ret = match socket.con.borrow().as_ref() { - Some(v) => v.recv_from(&mut buffer), - None => return Err(vm.new_type_error("".to_string())), - }; + }; - let addr = match ret { - Ok((_size, addr)) => addr, - Err(s) => return Err(vm.new_os_error(s.to_string())), - }; + let addr_tuple = get_addr_tuple(vm, addr)?; - let addr_tuple = get_addr_tuple(vm, addr)?; + Ok(vm.ctx.new_tuple(vec![vm.ctx.new_bytes(buffer), addr_tuple])) + } - Ok(vm.ctx.new_tuple(vec![vm.ctx.new_bytes(buffer), addr_tuple])) -} + fn send(self, bytes: PyBytesRef, vm: &VirtualMachine) -> PyResult<()> { + match self.con.borrow_mut().as_mut() { + Some(v) => match v.write(&bytes) { + Ok(_) => (), + Err(s) => return Err(vm.new_os_error(s.to_string())), + }, + None => return Err(vm.new_type_error("".to_string())), + }; + Ok(()) + } -fn socket_send(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(zelf, None), (bytes, Some(vm.ctx.bytes_type()))] - ); - let socket = get_socket(zelf); - - match socket.con.borrow_mut().as_mut() { - Some(v) => match v.write(&objbytes::get_value(&bytes)) { - Ok(_) => (), - Err(s) => return Err(vm.new_os_error(s.to_string())), - }, - None => return Err(vm.new_type_error("".to_string())), - }; - Ok(vm.get_none()) -} + fn sendto(self, bytes: PyBytesRef, address: Address, vm: &VirtualMachine) -> PyResult<()> { + let address_string = address.get_address_string(); -fn socket_sendto(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [ - (zelf, None), - (bytes, Some(vm.ctx.bytes_type())), - (address, Some(vm.ctx.tuple_type())) - ] - ); - let address_string = get_address_string(vm, address)?; - - let socket = get_socket(zelf); - - match socket.socket_kind { - SocketKind::Dgram => { - if let Some(v) = socket.con.borrow().as_ref() { - return match v.send_to(&objbytes::get_value(&bytes), address_string) { - Ok(_) => Ok(vm.get_none()), - Err(s) => Err(vm.new_os_error(s.to_string())), - }; - } - // Doing implicit bind - match UdpSocket::bind("0.0.0.0:0") { - Ok(dgram) => match dgram.send_to(&objbytes::get_value(&bytes), address_string) { - Ok(_) => { - socket - .con - .borrow_mut() - .replace(Connection::UdpSocket(dgram)); - Ok(vm.get_none()) - } + match self.socket_kind { + SocketKind::Dgram => { + if let Some(v) = self.con.borrow().as_ref() { + return match v.send_to(&bytes, address_string) { + Ok(_) => Ok(()), + Err(s) => Err(vm.new_os_error(s.to_string())), + }; + } + // Doing implicit bind + match UdpSocket::bind("0.0.0.0:0") { + Ok(dgram) => match dgram.send_to(&bytes, address_string) { + Ok(_) => { + self.con.borrow_mut().replace(Connection::UdpSocket(dgram)); + Ok(()) + } + Err(s) => Err(vm.new_os_error(s.to_string())), + }, Err(s) => Err(vm.new_os_error(s.to_string())), - }, - Err(s) => Err(vm.new_os_error(s.to_string())), + } } + _ => Err(vm.new_not_implemented_error("".to_string())), } - _ => Err(vm.new_not_implemented_error("".to_string())), } -} -fn socket_close(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(zelf, None)]); - - let socket = get_socket(zelf); - socket.con.borrow_mut().take(); - Ok(vm.get_none()) -} + fn close(self, _vm: &VirtualMachine) -> () { + self.con.borrow_mut().take(); + } -fn socket_fileno(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(zelf, None)]); + fn fileno(self, vm: &VirtualMachine) -> PyResult { + let fileno = match self.con.borrow_mut().as_mut() { + Some(v) => v.fileno(), + None => return Err(vm.new_type_error("".to_string())), + }; + Ok(vm.ctx.new_int(fileno)) + } - let socket = get_socket(zelf); + fn getsockname(self, vm: &VirtualMachine) -> PyResult { + let addr = match self.con.borrow().as_ref() { + Some(v) => v.local_addr(), + None => return Err(vm.new_type_error("".to_string())), + }; - let fileno = match socket.con.borrow_mut().as_mut() { - Some(v) => v.fileno(), - None => return Err(vm.new_type_error("".to_string())), - }; - Ok(vm.ctx.new_int(fileno)) + match addr { + Ok(addr) => get_addr_tuple(vm, addr), + Err(s) => Err(vm.new_os_error(s.to_string())), + } + } } -fn socket_getsockname(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(zelf, None)]); - let socket = get_socket(zelf); +struct Address { + host: String, + port: usize, +} - let addr = match socket.con.borrow().as_ref() { - Some(v) => v.local_addr(), - None => return Err(vm.new_type_error("".to_string())), - }; +impl Address { + fn get_address_string(self) -> String { + format!("{}:{}", self.host, self.port.to_string()) + } +} - match addr { - Ok(addr) => get_addr_tuple(vm, addr), - Err(s) => Err(vm.new_os_error(s.to_string())), +impl TryFromObject for Address { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + let tuple = PyTupleRef::try_from_object(vm, obj)?; + if tuple.elements.borrow().len() != 2 { + Err(vm.new_type_error("Address tuple should have only 2 values".to_string())) + } else { + Ok(Address { + host: PyStringRef::try_from_object(vm, tuple.elements.borrow()[0].clone())? + .value + .to_string(), + port: PyIntRef::try_from_object(vm, tuple.elements.borrow()[1].clone())? + .as_bigint() + .to_usize() + .unwrap(), + }) + } } } @@ -452,18 +376,18 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; let socket = py_class!(ctx, "socket", ctx.object(), { - "__new__" => ctx.new_rustfunc(socket_new), - "connect" => ctx.new_rustfunc(socket_connect), - "recv" => ctx.new_rustfunc(socket_recv), - "send" => ctx.new_rustfunc(socket_send), - "bind" => ctx.new_rustfunc(socket_bind), - "accept" => ctx.new_rustfunc(socket_accept), - "listen" => ctx.new_rustfunc(socket_listen), - "close" => ctx.new_rustfunc(socket_close), - "getsockname" => ctx.new_rustfunc(socket_getsockname), - "sendto" => ctx.new_rustfunc(socket_sendto), - "recvfrom" => ctx.new_rustfunc(socket_recvfrom), - "fileno" => ctx.new_rustfunc(socket_fileno), + "__new__" => ctx.new_rustfunc(SocketRef::new), + "connect" => ctx.new_rustfunc(SocketRef::connect), + "recv" => ctx.new_rustfunc(SocketRef::recv), + "send" => ctx.new_rustfunc(SocketRef::send), + "bind" => ctx.new_rustfunc(SocketRef::bind), + "accept" => ctx.new_rustfunc(SocketRef::accept), + "listen" => ctx.new_rustfunc(SocketRef::listen), + "close" => ctx.new_rustfunc(SocketRef::close), + "getsockname" => ctx.new_rustfunc(SocketRef::getsockname), + "sendto" => ctx.new_rustfunc(SocketRef::sendto), + "recvfrom" => ctx.new_rustfunc(SocketRef::recvfrom), + "fileno" => ctx.new_rustfunc(SocketRef::fileno), }); py_module!(vm, "socket", {