@@ -50,6 +50,13 @@ struct tls_decrypt_arg {
50
50
u8 tail ;
51
51
};
52
52
53
+ struct tls_decrypt_ctx {
54
+ u8 iv [MAX_IV_SIZE ];
55
+ u8 aad [TLS_MAX_AAD_SIZE ];
56
+ u8 tail ;
57
+ struct scatterlist sg [];
58
+ };
59
+
53
60
noinline void tls_err_abort (struct sock * sk , int err )
54
61
{
55
62
WARN_ON_ONCE (err >= 0 );
@@ -1414,17 +1421,18 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
1414
1421
struct tls_context * tls_ctx = tls_get_ctx (sk );
1415
1422
struct tls_sw_context_rx * ctx = tls_sw_ctx_rx (tls_ctx );
1416
1423
struct tls_prot_info * prot = & tls_ctx -> prot_info ;
1424
+ int n_sgin , n_sgout , aead_size , err , pages = 0 ;
1417
1425
struct strp_msg * rxm = strp_msg (skb );
1418
1426
struct tls_msg * tlm = tls_msg (skb );
1419
- int n_sgin , n_sgout , nsg , mem_size , aead_size , err , pages = 0 ;
1420
- u8 * aad , * iv , * tail , * mem = NULL ;
1421
1427
struct aead_request * aead_req ;
1422
1428
struct sk_buff * unused ;
1423
1429
struct scatterlist * sgin = NULL ;
1424
1430
struct scatterlist * sgout = NULL ;
1425
1431
const int data_len = rxm -> full_len - prot -> overhead_size ;
1426
1432
int tail_pages = !!prot -> tail_size ;
1433
+ struct tls_decrypt_ctx * dctx ;
1427
1434
int iv_offset = 0 ;
1435
+ u8 * mem ;
1428
1436
1429
1437
if (darg -> zc && (out_iov || out_sg )) {
1430
1438
if (out_iov )
@@ -1446,67 +1454,59 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
1446
1454
/* Increment to accommodate AAD */
1447
1455
n_sgin = n_sgin + 1 ;
1448
1456
1449
- nsg = n_sgin + n_sgout ;
1450
-
1451
- aead_size = sizeof (* aead_req ) + crypto_aead_reqsize (ctx -> aead_recv );
1452
- mem_size = aead_size + (nsg * sizeof (struct scatterlist ));
1453
- mem_size = mem_size + TLS_MAX_AAD_SIZE ;
1454
- mem_size = mem_size + MAX_IV_SIZE ;
1455
- mem_size = mem_size + prot -> tail_size ;
1456
-
1457
1457
/* Allocate a single block of memory which contains
1458
- * aead_req || sgin[] || sgout[] || aad || iv || tail .
1459
- * This order achieves correct alignment for aead_req, sgin, sgout .
1458
+ * aead_req || tls_decrypt_ctx .
1459
+ * Both structs are variable length .
1460
1460
*/
1461
- mem = kmalloc (mem_size , sk -> sk_allocation );
1461
+ aead_size = sizeof (* aead_req ) + crypto_aead_reqsize (ctx -> aead_recv );
1462
+ mem = kmalloc (aead_size + struct_size (dctx , sg , n_sgin + n_sgout ),
1463
+ sk -> sk_allocation );
1462
1464
if (!mem )
1463
1465
return - ENOMEM ;
1464
1466
1465
1467
/* Segment the allocated memory */
1466
1468
aead_req = (struct aead_request * )mem ;
1467
- sgin = (struct scatterlist * )(mem + aead_size );
1468
- sgout = sgin + n_sgin ;
1469
- aad = (u8 * )(sgout + n_sgout );
1470
- iv = aad + TLS_MAX_AAD_SIZE ;
1471
- tail = iv + MAX_IV_SIZE ;
1469
+ dctx = (struct tls_decrypt_ctx * )(mem + aead_size );
1470
+ sgin = & dctx -> sg [0 ];
1471
+ sgout = & dctx -> sg [n_sgin ];
1472
1472
1473
1473
/* For CCM based ciphers, first byte of nonce+iv is a constant */
1474
1474
switch (prot -> cipher_type ) {
1475
1475
case TLS_CIPHER_AES_CCM_128 :
1476
- iv [0 ] = TLS_AES_CCM_IV_B0_BYTE ;
1476
+ dctx -> iv [0 ] = TLS_AES_CCM_IV_B0_BYTE ;
1477
1477
iv_offset = 1 ;
1478
1478
break ;
1479
1479
case TLS_CIPHER_SM4_CCM :
1480
- iv [0 ] = TLS_SM4_CCM_IV_B0_BYTE ;
1480
+ dctx -> iv [0 ] = TLS_SM4_CCM_IV_B0_BYTE ;
1481
1481
iv_offset = 1 ;
1482
1482
break ;
1483
1483
}
1484
1484
1485
1485
/* Prepare IV */
1486
1486
if (prot -> version == TLS_1_3_VERSION ||
1487
1487
prot -> cipher_type == TLS_CIPHER_CHACHA20_POLY1305 ) {
1488
- memcpy (iv + iv_offset , tls_ctx -> rx .iv ,
1488
+ memcpy (& dctx -> iv [ iv_offset ] , tls_ctx -> rx .iv ,
1489
1489
prot -> iv_size + prot -> salt_size );
1490
1490
} else {
1491
1491
err = skb_copy_bits (skb , rxm -> offset + TLS_HEADER_SIZE ,
1492
- iv + iv_offset + prot -> salt_size ,
1492
+ & dctx -> iv [ iv_offset ] + prot -> salt_size ,
1493
1493
prot -> iv_size );
1494
1494
if (err < 0 ) {
1495
1495
kfree (mem );
1496
1496
return err ;
1497
1497
}
1498
- memcpy (iv + iv_offset , tls_ctx -> rx .iv , prot -> salt_size );
1498
+ memcpy (& dctx -> iv [ iv_offset ] , tls_ctx -> rx .iv , prot -> salt_size );
1499
1499
}
1500
- xor_iv_with_seq (prot , iv + iv_offset , tls_ctx -> rx .rec_seq );
1500
+ xor_iv_with_seq (prot , & dctx -> iv [ iv_offset ] , tls_ctx -> rx .rec_seq );
1501
1501
1502
1502
/* Prepare AAD */
1503
- tls_make_aad (aad , rxm -> full_len - prot -> overhead_size +
1503
+ tls_make_aad (dctx -> aad , rxm -> full_len - prot -> overhead_size +
1504
1504
prot -> tail_size ,
1505
1505
tls_ctx -> rx .rec_seq , tlm -> control , prot );
1506
1506
1507
1507
/* Prepare sgin */
1508
1508
sg_init_table (sgin , n_sgin );
1509
- sg_set_buf (& sgin [0 ], aad , prot -> aad_size );
1509
+ sg_set_buf (& sgin [0 ], dctx -> aad , prot -> aad_size );
1510
1510
err = skb_to_sgvec (skb , & sgin [1 ],
1511
1511
rxm -> offset + prot -> prepend_size ,
1512
1512
rxm -> full_len - prot -> prepend_size );
@@ -1518,7 +1518,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
1518
1518
if (n_sgout ) {
1519
1519
if (out_iov ) {
1520
1520
sg_init_table (sgout , n_sgout );
1521
- sg_set_buf (& sgout [0 ], aad , prot -> aad_size );
1521
+ sg_set_buf (& sgout [0 ], dctx -> aad , prot -> aad_size );
1522
1522
1523
1523
err = tls_setup_from_iter (out_iov , data_len ,
1524
1524
& pages , & sgout [1 ],
@@ -1528,7 +1528,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
1528
1528
1529
1529
if (prot -> tail_size ) {
1530
1530
sg_unmark_end (& sgout [pages ]);
1531
- sg_set_buf (& sgout [pages + 1 ], tail ,
1531
+ sg_set_buf (& sgout [pages + 1 ], & dctx -> tail ,
1532
1532
prot -> tail_size );
1533
1533
sg_mark_end (& sgout [pages + 1 ]);
1534
1534
}
@@ -1545,13 +1545,13 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
1545
1545
}
1546
1546
1547
1547
/* Prepare and submit AEAD request */
1548
- err = tls_do_decryption (sk , skb , sgin , sgout , iv ,
1548
+ err = tls_do_decryption (sk , skb , sgin , sgout , dctx -> iv ,
1549
1549
data_len + prot -> tail_size , aead_req , darg );
1550
1550
if (darg -> async )
1551
1551
return 0 ;
1552
1552
1553
1553
if (prot -> tail_size )
1554
- darg -> tail = * tail ;
1554
+ darg -> tail = dctx -> tail ;
1555
1555
1556
1556
/* Release the pages in case iov was mapped to pages */
1557
1557
for (; pages > 0 ; pages -- )
0 commit comments