diff --git a/Xlib/display.py b/Xlib/display.py index ead0185c..5b8826d3 100644 --- a/Xlib/display.py +++ b/Xlib/display.py @@ -61,7 +61,7 @@ 'fontable': ('font', 'gc') } -class _BaseDisplay(protocol_display.Display): +class BaseDisplay(protocol_display.Display): # Implement a cache of atom names, used by Window objects when # dealing with some ICCCM properties not defined in Xlib.Xatom @@ -86,7 +86,7 @@ def get_atom(self, atomname, only_if_exists=0): class Display(object): def __init__(self, display = None): - self.display = _BaseDisplay(display) + self.display = BaseDisplay(display) # Create the keymap cache self._keymap_codes = [()] * 256 diff --git a/Xlib/error.py b/Xlib/error.py index cb6d0d07..e5cc71f1 100644 --- a/Xlib/error.py +++ b/Xlib/error.py @@ -52,6 +52,13 @@ def __init__(self, whom): def __str__(self): return 'Display connection closed by %s' % self.whom +class ConnectionTimeoutError(OSError): + def __init__(self, whom): + self.whom = whom + + def __str__(self): + return 'Timeout reached in %s' % self.whom + class XauthError(Exception): pass class XNoAuthError(Exception): pass diff --git a/Xlib/protocol/display.py b/Xlib/protocol/display.py index 56623c35..31c16e6a 100644 --- a/Xlib/protocol/display.py +++ b/Xlib/protocol/display.py @@ -75,18 +75,30 @@ def bytesview(data, offset=0, size=None): return buffer(data, offset, size) + class TimeoutError(OSError): + """ Timeout expired. """ + pass + + class Display(object): extension_major_opcodes = {} error_classes = error.xerror_class.copy() event_classes = event.event_class.copy() - def __init__(self, display = None): + _READ_MASK = select.POLLIN | select.POLLPRI + _ERROR_MASK = select.POLLERR | select.POLLHUP + _WRITE_MASK = select.POLLOUT + + _READ_POLL_MASK = _READ_MASK | _ERROR_MASK + _READY_POLL_MASK = _READ_MASK | _ERROR_MASK | _WRITE_MASK + + def __init__(self, display = None, timeout = None): name, protocol, host, displayno, screenno = connect.get_display(display) self.display_name = name self.default_screen = screenno - self.socket = connect.get_socket(name, protocol, host, displayno) + self.socket = connect.get_socket(name, protocol, host, displayno, timeout) auth_name, auth_data = connect.get_auth(self.socket, name, protocol, host, displayno) @@ -99,6 +111,12 @@ def __init__(self, display = None): self.socket_error_lock = lock.allocate_lock() self.socket_error = None + # Initialize read and ready polls + self.read_poll = select.poll() + self.read_poll.register(self.socket, self._READ_POLL_MASK) + self.ready_poll = select.poll() + self.ready_poll.register(self.socket, self._READY_POLL_MASK) + # Event queue self.event_queue_read_lock = lock.allocate_lock() self.event_queue_write_lock = lock.allocate_lock() @@ -367,7 +385,7 @@ def send_request(self, request, wait_for_response): # if qlen > 10: # self.flush() - def close_internal(self, whom): + def close_internal(self, whom, socket_error = error.ConnectionClosedError): # Clear out data structures self.request_queue = None self.sent_requests = None @@ -375,12 +393,24 @@ def close_internal(self, whom): self.data_send = None self.data_recv = None + for poll in (self.read_poll, self.ready_poll): + try: + poll.unregister(self.socket) + except (KeyError, ValueError): + # KeyError is raised if somehow the socket was not registered + # ValueError is raised if the socket's file descriptor is negative. + # In either case, we can't do anything better than to remove the reference to the poller. + pass + self.read_poll = None + self.ready_poll = None + # Close the connection self.socket.close() + self.socket = None # Set a connection closed indicator self.socket_error_lock.acquire() - self.socket_error = error.ConnectionClosedError(whom) + self.socket_error = socket_error(whom) self.socket_error_lock.release() @@ -537,31 +567,21 @@ def send_and_recv(self, flush = None, event = None, request = None, recv = None) if flush and flush_bytes is None: flush_bytes = self.data_sent_bytes + len(self.data_send) + # We're only checking for the socket to be writable + # if we're the sending thread. We always check for it + # to become readable: either we are the receiving thread + # and should take care of the data, or the receiving thread + # might finish receiving after having read the data - try: - # We're only checking for the socket to be writable - # if we're the sending thread. We always check for it - # to become readable: either we are the receiving thread - # and should take care of the data, or the receiving thread - # might finish receiving after having read the data - - if sending: - writeset = [self.socket] - else: - writeset = [] - - # Timeout immediately if we're only checking for - # something to read or if we're flushing, otherwise block - - if recv or flush: - timeout = 0 - else: - timeout = None - - rs, ws, es = select.select([self.socket], writeset, [], timeout) + # Timeout immediately if we're only checking for + # something to read or if we're flushing, otherwise block + if recv or flush: + timeout = 0 + else: + timeout = self.socket.gettimeout() - # Ignore errors caused by a signal received while blocking. - # All other errors are re-raised. + try: + rs, ws = self._select(sending, timeout) except select.error as err: if isinstance(err, OSError): code = err.errno @@ -610,6 +630,9 @@ def send_and_recv(self, flush = None, event = None, request = None, recv = None) self.data_recv = bytes(self.data_recv) + bytes_recv gotreq = self.parse_response(request) + if request == -1 and gotreq == -1: + self.close_internal('Xlib: Not a valid X11 server') + raise self.socket_error # Otherwise return, allowing the calling thread to figure # out if it has got the data it needs @@ -646,6 +669,11 @@ def send_and_recv(self, flush = None, event = None, request = None, recv = None) if recv: break + # We got timeout + if not ws and not rs: + self.close_internal('server', error.ConnectionTimeoutError) + raise self.socket_error + # Else there's may still data which must be sent, or # we haven't got the data we waited for. Lock and loop @@ -673,6 +701,52 @@ def send_and_recv(self, flush = None, event = None, request = None, recv = None) self.send_recv_lock.release() + def _select(self, sending, timeout): + ws = rs = False + if sending: + try: + ws, events = self._is_ready_for_command(timeout) + rs = self._check_can_read(events) + except TimeoutError: + ws = False + else: + try: + rs, _ = self._can_read(timeout) + except TimeoutError: + rs = False + return rs, ws + + def _check_can_read(self, events): + return bool(events and events[0][1] & self._READ_MASK) + + def _can_read(self, timeout): + """ + Return True if data is ready to be read from the socket, + otherwise False. + This doesn't guarantee that the socket is still connected, just + that there is data to read. + """ + if timeout is not None: + timeout = timeout * 1000 # timeout in poll is in milliseconds + events = self.read_poll.poll(timeout) + if not events: + raise TimeoutError('Timeout in read poll') + return self._check_can_read(events), events + + def _check_is_ready_for_command(self, events): + return bool(events and events[0][1] & self._WRITE_MASK) + + def _is_ready_for_command(self, timeout): + """ + Return True if the socket is ready to send a command, + otherwise False + """ + if timeout is not None: + timeout = timeout * 1000 # timeout in poll is in milliseconds + events = self.ready_poll.poll(timeout) + if not events: + raise TimeoutError('Timeout in ready poll') + return self._check_is_ready_for_command(events), events def parse_response(self, request): """Internal method. @@ -973,6 +1047,8 @@ def parse_connection_setup(self): r._data, d = r._reply.parse_binary(self.data_recv[:8], self, rawdict = 1) self.data_recv = self.data_recv[8:] + if r._data['status'] not in [0, 1, 2]: + return -1 # Loop around to see if we have got the additional data # already diff --git a/Xlib/support/connect.py b/Xlib/support/connect.py index 4db4c2f4..b213dce2 100644 --- a/Xlib/support/connect.py +++ b/Xlib/support/connect.py @@ -73,7 +73,7 @@ def get_display(display): return mod.get_display(display) -def get_socket(dname, protocol, host, dno): +def get_socket(dname, protocol, host, dno, timeout): """socket = get_socket(dname, protocol, host, dno) Connect to the display specified by DNAME, PROTOCOL, HOST and DNO, which @@ -84,7 +84,7 @@ def get_socket(dname, protocol, host, dno): modname = _socket_mods.get(platform, _default_socket_mod) mod = _relative_import(modname) - return mod.get_socket(dname, protocol, host, dno) + return mod.get_socket(dname, protocol, host, dno, timeout) def get_auth(sock, dname, protocol, host, dno): diff --git a/Xlib/support/unix_connect.py b/Xlib/support/unix_connect.py index c2261dae..cb02b27d 100644 --- a/Xlib/support/unix_connect.py +++ b/Xlib/support/unix_connect.py @@ -88,26 +88,30 @@ def get_display(display): return display, protocol, host, dno, screen -def _get_tcp_socket(host, dno): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.connect((host, 6000 + dno)) - return s +def _get_tcp_socket(host, dno, timeout): + return socket.create_connection( + address=(host, 6000 + dno), + timeout=timeout + ) + -def _get_unix_socket(address): +def _get_unix_socket(address, timeout): s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + s.settimeout(timeout) s.connect(address) return s -def get_socket(dname, protocol, host, dno): + +def get_socket(dname, protocol, host, dno, timeout = None): assert protocol in SUPPORTED_PROTOCOLS try: # Darwin funky socket. if protocol == 'darwin': - s = _get_unix_socket(dname) + s = _get_unix_socket(dname, timeout) # TCP socket, note the special case: `unix:0.0` is equivalent to `:0.0`. elif (protocol is None or protocol != 'unix') and host and host != 'unix': - s = _get_tcp_socket(host, dno) + s = _get_tcp_socket(host, dno, timeout) # Unix socket. else: @@ -116,11 +120,11 @@ def get_socket(dname, protocol, host, dno): # Use abstract address. address = '\0' + address try: - s = _get_unix_socket(address) + s = _get_unix_socket(address, timeout) except socket.error: if not protocol and not host: # If no protocol/host was specified, fallback to TCP. - s = _get_tcp_socket(host, dno) + s = _get_tcp_socket(host, dno, timeout) else: raise except socket.error as val: diff --git a/Xlib/support/vms_connect.py b/Xlib/support/vms_connect.py index 3c53695f..bbe360e4 100644 --- a/Xlib/support/vms_connect.py +++ b/Xlib/support/vms_connect.py @@ -55,13 +55,15 @@ def get_display(display): return name, None, host, dno, screen -def get_socket(dname, protocol, host, dno): +def get_socket(dname, protocol, host, dno, timeout = None): try: # Always use TCP/IP sockets. Later it would be nice to # be able to use DECNET och LOCAL connections. - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.connect((host, 6000 + dno)) + s = socket.create_connection( + address=(host, 6000 + dno), + timeout=timeout + ) except socket.error as val: raise error.DisplayConnectionError(dname, str(val)) diff --git a/test/test_unix_connect.py b/test/test_unix_connect.py index 7680ba5b..a0432023 100644 --- a/test/test_unix_connect.py +++ b/test/test_unix_connect.py @@ -77,61 +77,61 @@ def fcntl(*args): for params, allow_unix, unix_addr_exists, allow_tcp, expect_connection_error, expected_calls in ( # Successful explicit TCP socket connection. (('tcp/host:6', None, 'host', 6), False, False, True, False, [ - ('_get_tcp_socket', 'host', 6), + ('_get_tcp_socket', 'host', 6, None), ]), # Failed explicit TCP socket connection. (('tcp/host:6', None, 'host', 6), False, False, False, True, [ - ('_get_tcp_socket', 'host', 6), + ('_get_tcp_socket', 'host', 6, None), ]), # Successful implicit TCP socket connection. (('host:5', None, 'host', 5), False, False, True, False, [ - ('_get_tcp_socket', 'host', 5), + ('_get_tcp_socket', 'host', 5, None), ]), # Failed implicit TCP socket connection. (('host:5', None, 'host', 5), False, False, False, True, [ - ('_get_tcp_socket', 'host', 5), + ('_get_tcp_socket', 'host', 5, None), ]), # Successful explicit Unix socket connection. (('unix/name:0', 'unix', 'name', 0), True, True, False, False, [ ('os.path.exists', '/tmp/.X11-unix/X0'), - ('_get_unix_socket', '/tmp/.X11-unix/X0'), + ('_get_unix_socket', '/tmp/.X11-unix/X0', None), ]), # Failed explicit Unix socket connection. (('unix/name:0', 'unix', 'name', 0), False, True, False, True, [ ('os.path.exists', '/tmp/.X11-unix/X0'), - ('_get_unix_socket', '/tmp/.X11-unix/X0'), + ('_get_unix_socket', '/tmp/.X11-unix/X0', None), ]), # Successful explicit Unix socket connection, variant. (('unix:0', None, 'unix', 0), True, True, False, False, [ ('os.path.exists', '/tmp/.X11-unix/X0'), - ('_get_unix_socket', '/tmp/.X11-unix/X0'), + ('_get_unix_socket', '/tmp/.X11-unix/X0', None), ]), # Failed explicit Unix socket connection, variant. (('unix:0', None, 'unix', 0), False, True, False, True, [ ('os.path.exists', '/tmp/.X11-unix/X0'), - ('_get_unix_socket', '/tmp/.X11-unix/X0'), + ('_get_unix_socket', '/tmp/.X11-unix/X0', None), ]), # Successful implicit Unix socket connection. ((':4', None, '', 4), True, True, False, False, [ ('os.path.exists', '/tmp/.X11-unix/X4'), - ('_get_unix_socket', '/tmp/.X11-unix/X4'), + ('_get_unix_socket', '/tmp/.X11-unix/X4', None), ]), # Successful implicit Unix socket connection, abstract address. ((':3', None, '', 3), True, False, False, False, [ ('os.path.exists', '/tmp/.X11-unix/X3'), - ('_get_unix_socket', '\0/tmp/.X11-unix/X3'), + ('_get_unix_socket', '\0/tmp/.X11-unix/X3', None), ]), # Failed implicit Unix socket connection, successful fallback on TCP. ((':2', None, '', 2), False, False, True, False, [ ('os.path.exists', '/tmp/.X11-unix/X2'), - ('_get_unix_socket', '\0/tmp/.X11-unix/X2'), - ('_get_tcp_socket', '', 2), + ('_get_unix_socket', '\0/tmp/.X11-unix/X2', None), + ('_get_tcp_socket', '', 2, None), ]), # Failed implicit Unix socket connection, failed fallback on TCP. ((':1', None, '', 1), False, False, False, True, [ ('os.path.exists', '/tmp/.X11-unix/X1'), - ('_get_unix_socket', '\0/tmp/.X11-unix/X1'), - ('_get_tcp_socket', '', 1), + ('_get_unix_socket', '\0/tmp/.X11-unix/X1', None), + ('_get_tcp_socket', '', 1, None), ]), ): with \