1
1
# test that socket.connect() on a non-blocking socket raises EINPROGRESS
2
2
# and that an immediate write/send/read/recv does the right thing
3
3
4
+ import unittest
4
5
import errno
5
6
import select
6
7
import socket
7
8
import ssl
8
9
9
10
# only mbedTLS supports non-blocking mode
10
- if not hasattr (ssl , "MBEDTLS_VERSION" ):
11
- print ("SKIP" )
12
- raise SystemExit
11
+ ssl_supports_nonblocking = hasattr (ssl , "MBEDTLS_VERSION" )
13
12
14
13
15
14
# get the name of an errno error code
@@ -24,34 +23,43 @@ def errno_name(er):
24
23
# do_connect establishes the socket and wraps it if tls is True.
25
24
# If handshake is true, the initial connect (and TLS handshake) is
26
25
# allowed to be performed before returning.
27
- def do_connect (peer_addr , tls , handshake ):
26
+ def do_connect (self , peer_addr , tls , handshake ):
28
27
s = socket .socket ()
29
28
s .setblocking (False )
30
29
try :
31
- # print("Connecting to", peer_addr)
30
+ print ("Connecting to" , peer_addr )
32
31
s .connect (peer_addr )
32
+ self .fail ()
33
33
except OSError as er :
34
34
print ("connect:" , errno_name (er .errno ))
35
+ self .assertEqual (er .errno , errno .EINPROGRESS )
36
+
35
37
# wrap with ssl/tls if desired
36
38
if tls :
39
+ print ("wrap socket" )
37
40
ssl_context = ssl .SSLContext (ssl .PROTOCOL_TLS_CLIENT )
38
- try :
39
- s = ssl_context .wrap_socket (s , do_handshake_on_connect = handshake )
40
- print ("wrap ok: True" )
41
- except Exception as e :
42
- print ("wrap er:" , e )
41
+ s = ssl_context .wrap_socket (s , do_handshake_on_connect = handshake )
42
+
43
43
return s
44
44
45
45
46
- # poll a socket and print out the result
47
- def poll (s ):
46
+ # poll a socket and check the result
47
+ def poll (self , s , expect_writable ):
48
48
poller = select .poll ()
49
49
poller .register (s )
50
- print ("poll: " , poller .poll (0 ))
50
+ result = poller .poll (0 )
51
+ print ("poll:" , result )
52
+ if expect_writable :
53
+ self .assertEqual (len (result ), 1 )
54
+ self .assertEqual (result [0 ][1 ], select .POLLOUT )
55
+ else :
56
+ self .assertEqual (result , [])
51
57
52
58
53
- # test runs the test against a specific peer address.
54
- def test (peer_addr , tls , handshake ):
59
+ # do_test runs the test against a specific peer address.
60
+ def do_test (self , peer_addr , tls , handshake ):
61
+ print ()
62
+
55
63
# MicroPython plain and TLS sockets have read/write
56
64
hasRW = True
57
65
@@ -62,54 +70,70 @@ def test(peer_addr, tls, handshake):
62
70
# connect + send
63
71
# non-blocking send should raise EAGAIN
64
72
if hasSR :
65
- s = do_connect (peer_addr , tls , handshake )
66
- poll (s )
73
+ s = do_connect (self , peer_addr , tls , handshake )
74
+ poll (self , s , False )
67
75
try :
68
76
ret = s .send (b"1234" )
69
- print ( "send ok:" , ret ) # shouldn't get here
77
+ self . fail ()
70
78
except OSError as er :
71
- print ("send er:" , errno_name (er .errno ))
79
+ print ("send error:" , errno_name (er .errno ))
80
+ self .assertEqual (er .errno , errno .EAGAIN )
72
81
s .close ()
73
82
74
83
# connect + write
75
84
# non-blocking write should return None
76
85
if hasRW :
77
- s = do_connect (peer_addr , tls , handshake )
78
- poll (s )
86
+ s = do_connect (self , peer_addr , tls , handshake )
87
+ poll (self , s , tls and handshake )
79
88
ret = s .write (b"1234" )
80
- print ("write: " , ret )
89
+ print ("write:" , ret )
90
+ if tls and handshake :
91
+ self .assertEqual (ret , 4 )
92
+ else :
93
+ self .assertIsNone (ret )
81
94
s .close ()
82
95
83
96
# connect + recv
84
97
# non-blocking recv should raise EAGAIN
85
98
if hasSR :
86
- s = do_connect (peer_addr , tls , handshake )
87
- poll (s )
99
+ s = do_connect (self , peer_addr , tls , handshake )
100
+ poll (self , s , False )
88
101
try :
89
102
ret = s .recv (10 )
90
- print ( "recv ok:" , ret ) # shouldn't get here
103
+ self . fail ()
91
104
except OSError as er :
92
- print ("recv er:" , errno_name (er .errno ))
105
+ print ("recv error:" , errno_name (er .errno ))
106
+ self .assertEqual (er .errno , errno .EAGAIN )
93
107
s .close ()
94
108
95
109
# connect + read
96
110
# non-blocking read should return None
97
111
if hasRW :
98
- s = do_connect (peer_addr , tls , handshake )
99
- poll (s )
112
+ s = do_connect (self , peer_addr , tls , handshake )
113
+ poll (self , s , tls and handshake )
100
114
ret = s .read (10 )
101
- print ("read: " , ret )
115
+ print ("read:" , ret )
116
+ self .assertIsNone (ret )
102
117
s .close ()
103
118
104
119
105
- if __name__ == "__main__" :
120
+ class Test ( unittest . TestCase ) :
106
121
# these tests use a non-existent test IP address, this way the connect takes forever and
107
122
# we can see EAGAIN/None (https://tools.ietf.org/html/rfc5737)
108
- print ("--- Plain sockets to nowhere ---" )
109
- test (socket .getaddrinfo ("192.0.2.1" , 80 )[0 ][- 1 ], False , False )
110
- print ("--- SSL sockets to nowhere ---" )
111
- test (socket .getaddrinfo ("192.0.2.1" , 443 )[0 ][- 1 ], True , False )
112
- print ("--- Plain sockets ---" )
113
- test (socket .getaddrinfo ("micropython.org" , 80 )[0 ][- 1 ], False , False )
114
- print ("--- SSL sockets ---" )
115
- test (socket .getaddrinfo ("micropython.org" , 443 )[0 ][- 1 ], True , True )
123
+ def test_plain_sockets_to_nowhere (self ):
124
+ do_test (self , socket .getaddrinfo ("192.0.2.1" , 80 )[0 ][- 1 ], False , False )
125
+
126
+ @unittest .skipIf (not ssl_supports_nonblocking , "SSL doesn't support non-blocking" )
127
+ def test_ssl_sockets_to_nowhere (self ):
128
+ do_test (self , socket .getaddrinfo ("192.0.2.1" , 443 )[0 ][- 1 ], True , False )
129
+
130
+ def test_plain_sockets (self ):
131
+ do_test (self , socket .getaddrinfo ("micropython.org" , 80 )[0 ][- 1 ], False , False )
132
+
133
+ @unittest .skipIf (not ssl_supports_nonblocking , "SSL doesn't support non-blocking" )
134
+ def test_ssl_sockets (self ):
135
+ do_test (self , socket .getaddrinfo ("micropython.org" , 443 )[0 ][- 1 ], True , True )
136
+
137
+
138
+ if __name__ == "__main__" :
139
+ unittest .main ()
0 commit comments