Skip to content

Commit feae3a7

Browse files
committed
extmod/modtls_mbedtls: Test SSLSession reuse.
Signed-off-by: Daniël van de Giessen <daniel@dvdgiessen.nl>
1 parent 4ea1d63 commit feae3a7

File tree

2 files changed

+160
-0
lines changed

2 files changed

+160
-0
lines changed
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Test creating an SSL connection with certificates as bytes objects.
2+
3+
try:
4+
from io import IOBase
5+
import os
6+
import socket
7+
import ssl
8+
except ImportError:
9+
print("SKIP")
10+
raise SystemExit
11+
12+
if not hasattr(ssl, "SSLSession"):
13+
print("SKIP")
14+
raise SystemExit
15+
16+
PORT = 8000
17+
18+
# These are test certificates. See tests/README.md for details.
19+
certfile = "ec_cert.der"
20+
keyfile = "ec_key.der"
21+
22+
try:
23+
os.stat(certfile)
24+
os.stat(keyfile)
25+
except OSError:
26+
print("SKIP")
27+
raise SystemExit
28+
29+
with open(certfile, "rb") as cf:
30+
cert = cadata = cf.read()
31+
32+
with open(keyfile, "rb") as kf:
33+
key = kf.read()
34+
35+
36+
# Helper class to count number of bytes going over a TCP socket
37+
class CountingStream(IOBase):
38+
def __init__(self, stream):
39+
self.stream = stream
40+
self.count = 0
41+
42+
def readinto(self, buf, nbytes=None):
43+
result = self.stream.readinto(buf) if nbytes is None else self.stream.readinto(buf, nbytes)
44+
self.count += result
45+
return result
46+
47+
def write(self, buf):
48+
self.count += len(buf)
49+
return self.stream.write(buf)
50+
51+
def ioctl(self, req, arg):
52+
if hasattr(self.stream, "ioctl"):
53+
return self.stream.ioctl(req, arg)
54+
return 0
55+
56+
57+
# Server
58+
def instance0():
59+
multitest.globals(IP=multitest.get_network_ip())
60+
s = socket.socket()
61+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
62+
s.bind(socket.getaddrinfo("0.0.0.0", PORT)[0][-1])
63+
s.listen(1)
64+
multitest.next()
65+
server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
66+
server_ctx.load_cert_chain(cert, key)
67+
for i in range(7):
68+
s2, _ = s.accept()
69+
s2 = server_ctx.wrap_socket(s2, server_side=True)
70+
print(s2.read(18))
71+
s2.write(b"server to client {}".format(i))
72+
s2.close()
73+
s.close()
74+
75+
76+
# Client
77+
def instance1():
78+
multitest.next()
79+
80+
def connect_and_count(i, session, set_method="wrap_socket"):
81+
s = socket.socket()
82+
s.connect(socket.getaddrinfo(IP, PORT)[0][-1])
83+
s = CountingStream(s)
84+
client_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
85+
client_ctx.verify_mode = ssl.CERT_REQUIRED
86+
client_ctx.load_verify_locations(cadata=cadata)
87+
wrap_socket_kwargs = {}
88+
if set_method == "wrap_socket":
89+
wrap_socket_kwargs = {"session": session}
90+
elif set_method == "socket_attr":
91+
wrap_socket_kwargs = {"do_handshake_on_connect": False}
92+
s2 = client_ctx.wrap_socket(s, server_hostname="micropython.local", **wrap_socket_kwargs)
93+
if set_method == "socket_attr" and session is not None:
94+
s2.session = session
95+
s2.write(b"client to server {}".format(i))
96+
print(s2.read(18))
97+
session = s2.session
98+
print(type(session))
99+
s2.close()
100+
return session, s.count
101+
102+
# No session reuse
103+
session, count_without_reuse = connect_and_count(0, None)
104+
105+
# Direct session reuse
106+
session, count = connect_and_count(1, session, "wrap_socket")
107+
print(count < count_without_reuse)
108+
109+
# Serialized session reuse
110+
session = ssl.SSLSession(session.serialize())
111+
session, count = connect_and_count(2, session, "wrap_socket")
112+
print(count < count_without_reuse)
113+
114+
# Serialized session reuse (using buffer protocol)
115+
session = ssl.SSLSession(bytes(session))
116+
session, count = connect_and_count(3, session, "wrap_socket")
117+
print(count < count_without_reuse)
118+
119+
# Direct session reuse
120+
session, count = connect_and_count(4, session, "socket_attr")
121+
print(count < count_without_reuse)
122+
123+
# Serialized session reuse
124+
session = ssl.SSLSession(session.serialize())
125+
session, count = connect_and_count(5, session, "socket_attr")
126+
print(count < count_without_reuse)
127+
128+
# Serialized session reuse (using buffer protocol)
129+
session = ssl.SSLSession(bytes(session))
130+
session, count = connect_and_count(6, session, "socket_attr")
131+
print(count < count_without_reuse)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
--- instance0 ---
2+
b'client to server 0'
3+
b'client to server 1'
4+
b'client to server 2'
5+
b'client to server 3'
6+
b'client to server 4'
7+
b'client to server 5'
8+
b'client to server 6'
9+
--- instance1 ---
10+
b'server to client 0'
11+
<class 'SSLSession'>
12+
b'server to client 1'
13+
<class 'SSLSession'>
14+
True
15+
b'server to client 2'
16+
<class 'SSLSession'>
17+
True
18+
b'server to client 3'
19+
<class 'SSLSession'>
20+
True
21+
b'server to client 4'
22+
<class 'SSLSession'>
23+
True
24+
b'server to client 5'
25+
<class 'SSLSession'>
26+
True
27+
b'server to client 6'
28+
<class 'SSLSession'>
29+
True

0 commit comments

Comments
 (0)