39
39
#include "mbedtls/include/mbedtls/platform.h"
40
40
#include "mbedtls/include/mbedtls/net.h"
41
41
#include "mbedtls/include/mbedtls/ssl.h"
42
+ #include "mbedtls/include/mbedtls/x509_crt.h"
43
+ #include "mbedtls/include/mbedtls/pk.h"
42
44
#include "mbedtls/include/mbedtls/entropy.h"
43
45
#include "mbedtls/include/mbedtls/ctr_drbg.h"
44
46
#include "mbedtls/include/mbedtls/debug.h"
@@ -51,8 +53,16 @@ typedef struct _mp_obj_ssl_socket_t {
51
53
mbedtls_ssl_context ssl ;
52
54
mbedtls_ssl_config conf ;
53
55
mbedtls_x509_crt cacert ;
56
+ mbedtls_x509_crt cert ;
57
+ mbedtls_pk_context pkey ;
54
58
} mp_obj_ssl_socket_t ;
55
59
60
+ struct ssl_args {
61
+ mp_arg_val_t key ;
62
+ mp_arg_val_t cert ;
63
+ mp_arg_val_t server_side ;
64
+ };
65
+
56
66
STATIC const mp_obj_type_t ussl_socket_type ;
57
67
58
68
static void mbedtls_debug (void * ctx , int level , const char * file , int line , const char * str ) {
@@ -94,14 +104,16 @@ int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
94
104
}
95
105
96
106
97
- STATIC mp_obj_ssl_socket_t * socket_new (mp_obj_t sock , bool server_side ) {
107
+ STATIC mp_obj_ssl_socket_t * socket_new (mp_obj_t sock , struct ssl_args * args ) {
98
108
mp_obj_ssl_socket_t * o = m_new_obj (mp_obj_ssl_socket_t );
99
109
o -> base .type = & ussl_socket_type ;
100
110
101
111
int ret ;
102
112
mbedtls_ssl_init (& o -> ssl );
103
113
mbedtls_ssl_config_init (& o -> conf );
104
114
mbedtls_x509_crt_init (& o -> cacert );
115
+ mbedtls_x509_crt_init (& o -> cert );
116
+ mbedtls_pk_init (& o -> pkey );
105
117
mbedtls_ctr_drbg_init (& o -> ctr_drbg );
106
118
// Debug level (0-4)
107
119
mbedtls_debug_set_threshold (0 );
@@ -140,7 +152,24 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, bool server_side) {
140
152
o -> sock = sock ;
141
153
mbedtls_ssl_set_bio (& o -> ssl , & o -> sock , _mbedtls_ssl_send , _mbedtls_ssl_recv , NULL );
142
154
143
- if (server_side ) {
155
+ if (args -> key .u_obj != MP_OBJ_NULL ) {
156
+ mp_uint_t key_len ;
157
+ const byte * key = (const byte * )mp_obj_str_get_data (args -> key .u_obj , & key_len );
158
+ // len should include terminating null
159
+ ret = mbedtls_pk_parse_key (& o -> pkey , key , key_len + 1 , NULL , 0 );
160
+ assert (ret == 0 );
161
+
162
+ mp_uint_t cert_len ;
163
+ const byte * cert = (const byte * )mp_obj_str_get_data (args -> cert .u_obj , & cert_len );
164
+ // len should include terminating null
165
+ ret = mbedtls_x509_crt_parse (& o -> cert , cert , cert_len + 1 );
166
+ assert (ret == 0 );
167
+
168
+ ret = mbedtls_ssl_conf_own_cert (& o -> conf , & o -> cert , & o -> pkey );
169
+ assert (ret == 0 );
170
+ }
171
+
172
+ if (args -> server_side .u_bool ) {
144
173
assert (0 );
145
174
} else {
146
175
while ((ret = mbedtls_ssl_handshake (& o -> ssl )) != 0 ) {
@@ -228,19 +257,19 @@ STATIC const mp_obj_type_t ussl_socket_type = {
228
257
STATIC mp_obj_t mod_ssl_wrap_socket (size_t n_args , const mp_obj_t * pos_args , mp_map_t * kw_args ) {
229
258
// TODO: Implement more args
230
259
static const mp_arg_t allowed_args [] = {
260
+ { MP_QSTR_key , MP_ARG_KW_ONLY | MP_ARG_OBJ , {.u_obj = MP_OBJ_NULL } },
261
+ { MP_QSTR_cert , MP_ARG_KW_ONLY | MP_ARG_OBJ , {.u_obj = MP_OBJ_NULL } },
231
262
{ MP_QSTR_server_side , MP_ARG_KW_ONLY | MP_ARG_BOOL , {.u_bool = false} },
232
263
};
233
264
234
265
// TODO: Check that sock implements stream protocol
235
266
mp_obj_t sock = pos_args [0 ];
236
267
237
- struct {
238
- mp_arg_val_t server_side ;
239
- } args ;
268
+ struct ssl_args args ;
240
269
mp_arg_parse_all (n_args - 1 , pos_args + 1 , kw_args ,
241
270
MP_ARRAY_SIZE (allowed_args ), allowed_args , (mp_arg_val_t * )& args );
242
271
243
- return MP_OBJ_FROM_PTR (socket_new (sock , args . server_side . u_bool ));
272
+ return MP_OBJ_FROM_PTR (socket_new (sock , & args ));
244
273
}
245
274
STATIC MP_DEFINE_CONST_FUN_OBJ_KW (mod_ssl_wrap_socket_obj , 1 , mod_ssl_wrap_socket );
246
275
0 commit comments