Skip to content

Commit f014564

Browse files
committed
extmod/modtls_mbedtls: Test SSLSession reuse.
Signed-off-by: Daniël van de Giessen <daniel@dvdgiessen.nl>
1 parent 83188bb commit f014564

File tree

2 files changed

+119
-0
lines changed

2 files changed

+119
-0
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Test creating an SSL connection with certificates as bytes objects.
2+
3+
try:
4+
from io import IOBase
5+
import os
6+
import socket
7+
import ssl
8+
except ImportError:
9+
print("SKIP")
10+
raise SystemExit
11+
12+
if not hasattr(ssl, "SSLSession"):
13+
print("SKIP")
14+
raise SystemExit
15+
16+
PORT = 8000
17+
18+
# These are test certificates. See tests/README.md for details.
19+
certfile = "ec_cert.der"
20+
keyfile = "ec_key.der"
21+
22+
try:
23+
os.stat(certfile)
24+
os.stat(keyfile)
25+
except OSError:
26+
print("SKIP")
27+
raise SystemExit
28+
29+
with open(certfile, "rb") as cf:
30+
cert = cadata = cf.read()
31+
32+
with open(keyfile, "rb") as kf:
33+
key = kf.read()
34+
35+
# Helper class to count number of bytes going over a TCP socket
36+
class CountingStream(IOBase):
37+
def __init__(self, stream):
38+
self.stream = stream
39+
self.count = 0
40+
41+
def readinto(self, buf, nbytes=None):
42+
result = self.stream.readinto(buf) if nbytes is None else self.stream.readinto(buf, nbytes)
43+
self.count += result
44+
return result
45+
46+
def write(self, buf):
47+
self.count += len(buf)
48+
return self.stream.write(buf)
49+
50+
def ioctl(self, req, arg):
51+
if hasattr(self.stream, "ioctl"):
52+
return self.stream.ioctl(req, arg)
53+
return 0
54+
55+
56+
# Server
57+
def instance0():
58+
multitest.globals(IP=multitest.get_network_ip())
59+
s = socket.socket()
60+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
61+
s.bind(socket.getaddrinfo("0.0.0.0", PORT)[0][-1])
62+
s.listen(1)
63+
multitest.next()
64+
server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
65+
server_ctx.load_cert_chain(cert, key)
66+
for i in range(3):
67+
s2, _ = s.accept()
68+
s2 = server_ctx.wrap_socket(s2, server_side=True)
69+
print(s2.read(18))
70+
s2.write(b"server to client {}".format(i))
71+
s2.close()
72+
s.close()
73+
74+
75+
# Client
76+
def instance1():
77+
multitest.next()
78+
count_without_reuse = None
79+
session = None
80+
# Test three variations:
81+
# - 1st: no session reuse
82+
# - 2nd: direct session reuse
83+
# - 3rd: serialized session reuse
84+
for i in range(3):
85+
s = socket.socket()
86+
s.connect(socket.getaddrinfo(IP, PORT)[0][-1])
87+
s = CountingStream(s)
88+
client_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
89+
client_ctx.verify_mode = ssl.CERT_REQUIRED
90+
client_ctx.load_verify_locations(cadata=cadata)
91+
if i >= 2:
92+
# Serialize and unserialize session object
93+
# before using it in the last variation
94+
session = ssl.SSLSession(bytes(session))
95+
s2 = client_ctx.wrap_socket(s, server_hostname="micropython.local", session=session)
96+
s2.write(b"client to server {}".format(i))
97+
print(s2.read(18))
98+
session = s2.session
99+
print(type(session))
100+
s2.close()
101+
if count_without_reuse is None:
102+
# Save byte count without session reuse
103+
count_without_reuse = s.count
104+
else:
105+
# Assert session reuse should use less
106+
print(s.count < count_without_reuse)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
--- instance0 ---
2+
b'client to server 0'
3+
b'client to server 1'
4+
b'client to server 2'
5+
--- instance1 ---
6+
b'server to client 0'
7+
<class 'SSLSession'>
8+
b'server to client 1'
9+
<class 'SSLSession'>
10+
True
11+
b'server to client 2'
12+
<class 'SSLSession'>
13+
True

0 commit comments

Comments
 (0)