diff --git a/tests/snippets/stdlib_socket.py b/tests/snippets/stdlib_socket.py index 79efaa3639..d7afe1bd35 100644 --- a/tests/snippets/stdlib_socket.py +++ b/tests/snippets/stdlib_socket.py @@ -1,4 +1,5 @@ import socket +import os from testutils import assertRaises MESSAGE_A = b'aaaa' @@ -21,6 +22,18 @@ recv_b = connector.recv(len(MESSAGE_B)) assert recv_a == MESSAGE_A assert recv_b == MESSAGE_B + +# fileno +if os.name == "posix": + connector_fd = connector.fileno() + connection_fd = connection.fileno() + os.write(connector_fd, MESSAGE_A) + connection.send(MESSAGE_B) + recv_a = connection.recv(len(MESSAGE_A)) + recv_b = os.read(connector_fd, (len(MESSAGE_B))) + assert recv_a == MESSAGE_A + assert recv_b == MESSAGE_B + connection.close() connector.close() listener.close() diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index e893e0dab7..9d7a703e1f 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -86,6 +86,33 @@ impl Connection { _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), } } + + #[cfg(unix)] + fn fileno(&self) -> i64 { + use std::os::unix::io::AsRawFd; + let raw_fd = match self { + Connection::TcpListener(con) => con.as_raw_fd(), + Connection::UdpSocket(con) => con.as_raw_fd(), + Connection::TcpStream(con) => con.as_raw_fd(), + }; + raw_fd as i64 + } + + #[cfg(windows)] + fn fileno(&self) -> i64 { + use std::os::windows::io::AsRawSocket; + let raw_fd = match self { + Connection::TcpListener(con) => con.as_raw_socket(), + Connection::UdpSocket(con) => con.as_raw_socket(), + Connection::TcpStream(con) => con.as_raw_socket(), + }; + raw_fd as i64 + } + + #[cfg(all(not(unix), not(windows)))] + fn fileno(&self) -> i64 { + unimplemented!(); + } } impl Read for Connection { @@ -387,6 +414,18 @@ fn socket_close(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.get_none()) } +fn socket_fileno(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zelf, None)]); + + let socket = get_socket(zelf); + + let fileno = match socket.con.borrow_mut().as_mut() { + Some(v) => v.fileno(), + None => return Err(vm.new_type_error("".to_string())), + }; + Ok(vm.ctx.new_int(fileno)) +} + fn socket_getsockname(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, None)]); let socket = get_socket(zelf); @@ -424,6 +463,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "getsockname" => ctx.new_rustfunc(socket_getsockname), "sendto" => ctx.new_rustfunc(socket_sendto), "recvfrom" => ctx.new_rustfunc(socket_recvfrom), + "fileno" => ctx.new_rustfunc(socket_fileno), }); py_module!(vm, "socket", {