@@ -15,6 +15,7 @@
 #include <crypto/internal/hash.h>
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
 #include <linux/module.h>
 #include <linux/cpufeature.h>
 #include <crypto/xts.h>
@@ -31,6 +32,8 @@
 #define aes_ecb_decrypt		ce_aes_ecb_decrypt
 #define aes_cbc_encrypt		ce_aes_cbc_encrypt
 #define aes_cbc_decrypt		ce_aes_cbc_decrypt
+#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
+#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
 #define aes_ctr_encrypt		ce_aes_ctr_encrypt
 #define aes_xts_encrypt		ce_aes_xts_encrypt
 #define aes_xts_decrypt		ce_aes_xts_decrypt
@@ -45,6 +48,8 @@ MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
 #define aes_ecb_decrypt		neon_aes_ecb_decrypt
 #define aes_cbc_encrypt		neon_aes_cbc_encrypt
 #define aes_cbc_decrypt		neon_aes_cbc_decrypt
+#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
+#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
 #define aes_ctr_encrypt		neon_aes_ctr_encrypt
 #define aes_xts_encrypt		neon_aes_xts_encrypt
 #define aes_xts_decrypt		neon_aes_xts_decrypt
@@ -73,6 +78,11 @@ asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
 asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);

+asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
+				int rounds, int bytes, u8 const iv[]);
+asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
+				int rounds, int bytes, u8 const iv[]);
+
 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 ctr[]);

@@ -87,6 +97,12 @@ asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
			       int blocks, u8 dg[], int enc_before,
			       int enc_after);

+struct cts_cbc_req_ctx {
+	struct scatterlist sg_src[2];
+	struct scatterlist sg_dst[2];
+	struct skcipher_request subreq;
+};
+
 struct crypto_aes_xts_ctx {
	struct crypto_aes_ctx key1;
	struct crypto_aes_ctx __aligned(8) key2;
@@ -209,6 +225,136 @@ static int cbc_decrypt(struct skcipher_request *req)
	return err;
 }

+static int cts_cbc_init_tfm(struct crypto_skcipher *tfm)
+{
+	crypto_skcipher_set_reqsize(tfm, sizeof(struct cts_cbc_req_ctx));
+	return 0;
+}
+
+static int cts_cbc_encrypt(struct skcipher_request *req)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
+	int err, rounds = 6 + ctx->key_length / 4;
+	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+	struct scatterlist *src = req->src, *dst = req->dst;
+	struct skcipher_walk walk;
+
+	skcipher_request_set_tfm(&rctx->subreq, tfm);
+
+	if (req->cryptlen == AES_BLOCK_SIZE)
+		cbc_blocks = 1;
+
+	if (cbc_blocks > 0) {
+		unsigned int blocks;
+
+		skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
+					   cbc_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+
+		err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+
+		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+			kernel_neon_begin();
+			aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+					ctx->key_enc, rounds, blocks, walk.iv);
+			kernel_neon_end();
+			err = skcipher_walk_done(&walk,
+						 walk.nbytes % AES_BLOCK_SIZE);
+		}
+		if (err)
+			return err;
+
+		if (req->cryptlen == AES_BLOCK_SIZE)
+			return 0;
+
+		dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
+					     rctx->subreq.cryptlen);
+		if (req->dst != req->src)
+			dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
+					       rctx->subreq.cryptlen);
+	}
+
+	/* handle ciphertext stealing */
+	skcipher_request_set_crypt(&rctx->subreq, src, dst,
+				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+	if (err)
+		return err;
+
+	kernel_neon_begin();
+	aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			    ctx->key_enc, rounds, walk.nbytes, walk.iv);
+	kernel_neon_end();
+
+	return skcipher_walk_done(&walk, 0);
+}
+
+static int cts_cbc_decrypt(struct skcipher_request *req)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
+	int err, rounds = 6 + ctx->key_length / 4;
+	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+	struct scatterlist *src = req->src, *dst = req->dst;
+	struct skcipher_walk walk;
+
+	skcipher_request_set_tfm(&rctx->subreq, tfm);
+
+	if (req->cryptlen == AES_BLOCK_SIZE)
+		cbc_blocks = 1;
+
+	if (cbc_blocks > 0) {
+		unsigned int blocks;
+
+		skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
+					   cbc_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+
+		err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+
+		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+			kernel_neon_begin();
+			aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+					ctx->key_dec, rounds, blocks, walk.iv);
+			kernel_neon_end();
+			err = skcipher_walk_done(&walk,
+						 walk.nbytes % AES_BLOCK_SIZE);
+		}
+		if (err)
+			return err;
+
+		if (req->cryptlen == AES_BLOCK_SIZE)
+			return 0;
+
+		dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
+					     rctx->subreq.cryptlen);
+		if (req->dst != req->src)
+			dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
+					       rctx->subreq.cryptlen);
+	}
+
+	/* handle ciphertext stealing */
+	skcipher_request_set_crypt(&rctx->subreq, src, dst,
+				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+	if (err)
+		return err;
+
+	kernel_neon_begin();
+	aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			    ctx->key_dec, rounds, walk.nbytes, walk.iv);
+	kernel_neon_end();
+
+	return skcipher_walk_done(&walk, 0);
+}
+
 static int ctr_encrypt(struct skcipher_request *req)
 {
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
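
Note (not part of the patch): the split between the bulk CBC pass and the ciphertext-stealing tail above comes down to `cbc_blocks = DIV_ROUND_UP(cryptlen, AES_BLOCK_SIZE) - 2` — everything except the last two (possibly partial) blocks goes through the existing CBC asm, and only the tail is handed to the new aes_cbc_cts_encrypt/aes_cbc_cts_decrypt routines. A minimal standalone sketch of that arithmetic; the user-space scaffolding and helper name are illustrative only:

#include <stdio.h>

#define AES_BLOCK_SIZE		16
#define DIV_ROUND_UP(n, d)	(((n) + (d) - 1) / (d))

/*
 * Return the number of bytes handed to the CTS tail helper for a given
 * request length; everything before that is processed as ordinary CBC.
 * Lengths below one block are rejected by the skcipher API and are not
 * handled here.
 */
static unsigned int cts_tail_bytes(unsigned int cryptlen)
{
	int cbc_blocks = DIV_ROUND_UP(cryptlen, AES_BLOCK_SIZE) - 2;

	/* a single full block degenerates to plain CBC, no stealing needed */
	if (cryptlen == AES_BLOCK_SIZE)
		cbc_blocks = 1;

	if (cbc_blocks < 0)
		cbc_blocks = 0;

	return cryptlen - cbc_blocks * AES_BLOCK_SIZE;
}

int main(void)
{
	printf("%u\n", cts_tail_bytes(100));	/* 20: 80 bytes of bulk CBC, 1 full + 1 partial block of CTS */
	printf("%u\n", cts_tail_bytes(16));	/* 0: pure CBC of one block */
	printf("%u\n", cts_tail_bytes(32));	/* 32: the whole request is the CTS tail */
	return 0;
}
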
@@ -334,6 +480,25 @@ static struct skcipher_alg aes_algs[] = { {
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cbc_encrypt,
	.decrypt	= cbc_decrypt,
+}, {
+	.base = {
+		.cra_name		= "__cts(cbc(aes))",
+		.cra_driver_name	= "__cts-cbc-aes-" MODE,
+		.cra_priority		= PRIO,
+		.cra_flags		= CRYPTO_ALG_INTERNAL,
+		.cra_blocksize		= 1,
+		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
+		.cra_module		= THIS_MODULE,
+	},
+	.min_keysize	= AES_MIN_KEY_SIZE,
+	.max_keysize	= AES_MAX_KEY_SIZE,
+	.ivsize		= AES_BLOCK_SIZE,
+	.chunksize	= AES_BLOCK_SIZE,
+	.walksize	= 2 * AES_BLOCK_SIZE,
+	.setkey		= skcipher_aes_setkey,
+	.encrypt	= cts_cbc_encrypt,
+	.decrypt	= cts_cbc_decrypt,
+	.init		= cts_cbc_init_tfm,
 }, {
	.base = {
		.cra_name		= "__ctr(aes)",
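
Note (outside the diff): the new entry is registered as "__cts(cbc(aes))" with CRYPTO_ALG_INTERNAL, so callers would reach it under the plain "cts(cbc(aes))" name, assuming the glue code also registers the usual SIMD wrapper for it as it does for the other modes in this file. A hedged sketch of how a kernel caller might drive such an instance through the generic skcipher API — the function name, buffers, and error handling are placeholders, not part of this driver:

#include <crypto/skcipher.h>
#include <linux/crypto.h>
#include <linux/err.h>
#include <linux/scatterlist.h>

/* Encrypt one in-place buffer with cts(cbc(aes)); buf must not be on the stack. */
static int cts_cbc_example(u8 *buf, unsigned int len,
			   const u8 *key, unsigned int keylen, u8 *iv)
{
	struct crypto_skcipher *tfm;
	struct skcipher_request *req;
	struct scatterlist sg;
	DECLARE_CRYPTO_WAIT(wait);
	int err;

	tfm = crypto_alloc_skcipher("cts(cbc(aes))", 0, 0);
	if (IS_ERR(tfm))
		return PTR_ERR(tfm);

	err = crypto_skcipher_setkey(tfm, key, keylen);
	if (err)
		goto out_free_tfm;

	req = skcipher_request_alloc(tfm, GFP_KERNEL);
	if (!req) {
		err = -ENOMEM;
		goto out_free_tfm;
	}

	sg_init_one(&sg, buf, len);
	skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP,
				      crypto_req_done, &wait);
	skcipher_request_set_crypt(req, &sg, &sg, len, iv);

	/* len may be any value >= AES_BLOCK_SIZE; CTS needs no padding */
	err = crypto_wait_req(crypto_skcipher_encrypt(req), &wait);

	skcipher_request_free(req);
out_free_tfm:
	crypto_free_skcipher(tfm);
	return err;
}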