@@ -51,12 +51,9 @@ enum {
51
51
TLSV6 ,
52
52
TLS_NUM_PROTS ,
53
53
};
54
-
55
54
enum {
56
55
TLS_BASE ,
57
- TLS_SW_TX ,
58
- TLS_SW_RX ,
59
- TLS_SW_RXTX ,
56
+ TLS_SW ,
60
57
TLS_HW_RECORD ,
61
58
TLS_NUM_CONFIG ,
62
59
};
@@ -65,14 +62,14 @@ static struct proto *saved_tcpv6_prot;
65
62
static DEFINE_MUTEX (tcpv6_prot_mutex );
66
63
static LIST_HEAD (device_list );
67
64
static DEFINE_MUTEX (device_mutex );
68
- static struct proto tls_prots [TLS_NUM_PROTS ][TLS_NUM_CONFIG ];
65
+ static struct proto tls_prots [TLS_NUM_PROTS ][TLS_NUM_CONFIG ][ TLS_NUM_CONFIG ] ;
69
66
static struct proto_ops tls_sw_proto_ops ;
70
67
71
- static inline void update_sk_prot (struct sock * sk , struct tls_context * ctx )
68
+ static void update_sk_prot (struct sock * sk , struct tls_context * ctx )
72
69
{
73
70
int ip_ver = sk -> sk_family == AF_INET6 ? TLSV6 : TLSV4 ;
74
71
75
- sk -> sk_prot = & tls_prots [ip_ver ][ctx -> conf ];
72
+ sk -> sk_prot = & tls_prots [ip_ver ][ctx -> tx_conf ][ ctx -> rx_conf ];
76
73
}
77
74
78
75
int wait_on_pending_writer (struct sock * sk , long * timeo )
@@ -245,10 +242,10 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
245
242
lock_sock (sk );
246
243
sk_proto_close = ctx -> sk_proto_close ;
247
244
248
- if (ctx -> conf == TLS_HW_RECORD )
245
+ if (ctx -> tx_conf == TLS_HW_RECORD && ctx -> rx_conf == TLS_HW_RECORD )
249
246
goto skip_tx_cleanup ;
250
247
251
- if (ctx -> conf == TLS_BASE ) {
248
+ if (ctx -> tx_conf == TLS_BASE && ctx -> rx_conf == TLS_BASE ) {
252
249
kfree (ctx );
253
250
ctx = NULL ;
254
251
goto skip_tx_cleanup ;
@@ -270,15 +267,17 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
270
267
}
271
268
}
272
269
273
- kfree (ctx -> tx .rec_seq );
274
- kfree (ctx -> tx .iv );
275
- kfree (ctx -> rx .rec_seq );
276
- kfree (ctx -> rx .iv );
270
+ /* We need these for tls_sw_fallback handling of other packets */
271
+ if (ctx -> tx_conf == TLS_SW ) {
272
+ kfree (ctx -> tx .rec_seq );
273
+ kfree (ctx -> tx .iv );
274
+ tls_sw_free_resources_tx (sk );
275
+ }
277
276
278
- if (ctx -> conf == TLS_SW_TX ||
279
- ctx -> conf == TLS_SW_RX ||
280
- ctx -> conf == TLS_SW_RXTX ) {
281
- tls_sw_free_resources (sk );
277
+ if (ctx -> rx_conf == TLS_SW ) {
278
+ kfree ( ctx -> rx . rec_seq );
279
+ kfree ( ctx -> rx . iv );
280
+ tls_sw_free_resources_rx (sk );
282
281
}
283
282
284
283
skip_tx_cleanup :
@@ -287,7 +286,8 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
287
286
/* free ctx for TLS_HW_RECORD, used by tcp_set_state
288
287
* for sk->sk_prot->unhash [tls_hw_unhash]
289
288
*/
290
- if (ctx && ctx -> conf == TLS_HW_RECORD )
289
+ if (ctx && ctx -> tx_conf == TLS_HW_RECORD &&
290
+ ctx -> rx_conf == TLS_HW_RECORD )
291
291
kfree (ctx );
292
292
}
293
293
@@ -441,25 +441,21 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
441
441
goto err_crypto_info ;
442
442
}
443
443
444
- /* currently SW is default, we will have ethtool in future */
445
444
if (tx ) {
446
445
rc = tls_set_sw_offload (sk , ctx , 1 );
447
- if (ctx -> conf == TLS_SW_RX )
448
- conf = TLS_SW_RXTX ;
449
- else
450
- conf = TLS_SW_TX ;
446
+ conf = TLS_SW ;
451
447
} else {
452
448
rc = tls_set_sw_offload (sk , ctx , 0 );
453
- if (ctx -> conf == TLS_SW_TX )
454
- conf = TLS_SW_RXTX ;
455
- else
456
- conf = TLS_SW_RX ;
449
+ conf = TLS_SW ;
457
450
}
458
451
459
452
if (rc )
460
453
goto err_crypto_info ;
461
454
462
- ctx -> conf = conf ;
455
+ if (tx )
456
+ ctx -> tx_conf = conf ;
457
+ else
458
+ ctx -> rx_conf = conf ;
463
459
update_sk_prot (sk , ctx );
464
460
if (tx ) {
465
461
ctx -> sk_write_space = sk -> sk_write_space ;
@@ -535,7 +531,8 @@ static int tls_hw_prot(struct sock *sk)
535
531
ctx -> hash = sk -> sk_prot -> hash ;
536
532
ctx -> unhash = sk -> sk_prot -> unhash ;
537
533
ctx -> sk_proto_close = sk -> sk_prot -> close ;
538
- ctx -> conf = TLS_HW_RECORD ;
534
+ ctx -> rx_conf = TLS_HW_RECORD ;
535
+ ctx -> tx_conf = TLS_HW_RECORD ;
539
536
update_sk_prot (sk , ctx );
540
537
rc = 1 ;
541
538
break ;
@@ -579,29 +576,30 @@ static int tls_hw_hash(struct sock *sk)
579
576
return err ;
580
577
}
581
578
582
- static void build_protos (struct proto * prot , struct proto * base )
579
+ static void build_protos (struct proto prot [TLS_NUM_CONFIG ][TLS_NUM_CONFIG ],
580
+ struct proto * base )
583
581
{
584
- prot [TLS_BASE ] = * base ;
585
- prot [TLS_BASE ].setsockopt = tls_setsockopt ;
586
- prot [TLS_BASE ].getsockopt = tls_getsockopt ;
587
- prot [TLS_BASE ].close = tls_sk_proto_close ;
588
-
589
- prot [TLS_SW_TX ] = prot [TLS_BASE ];
590
- prot [TLS_SW_TX ].sendmsg = tls_sw_sendmsg ;
591
- prot [TLS_SW_TX ] .sendpage = tls_sw_sendpage ;
592
-
593
- prot [TLS_SW_RX ] = prot [TLS_BASE ];
594
- prot [TLS_SW_RX ].recvmsg = tls_sw_recvmsg ;
595
- prot [TLS_SW_RX ].close = tls_sk_proto_close ;
596
-
597
- prot [TLS_SW_RXTX ] = prot [TLS_SW_TX ];
598
- prot [TLS_SW_RXTX ].recvmsg = tls_sw_recvmsg ;
599
- prot [TLS_SW_RXTX ] .close = tls_sk_proto_close ;
600
-
601
- prot [TLS_HW_RECORD ] = * base ;
602
- prot [TLS_HW_RECORD ].hash = tls_hw_hash ;
603
- prot [TLS_HW_RECORD ].unhash = tls_hw_unhash ;
604
- prot [TLS_HW_RECORD ].close = tls_sk_proto_close ;
582
+ prot [TLS_BASE ][ TLS_BASE ] = * base ;
583
+ prot [TLS_BASE ][ TLS_BASE ] .setsockopt = tls_setsockopt ;
584
+ prot [TLS_BASE ][ TLS_BASE ] .getsockopt = tls_getsockopt ;
585
+ prot [TLS_BASE ][ TLS_BASE ] .close = tls_sk_proto_close ;
586
+
587
+ prot [TLS_SW ][ TLS_BASE ] = prot [ TLS_BASE ] [TLS_BASE ];
588
+ prot [TLS_SW ][ TLS_BASE ].sendmsg = tls_sw_sendmsg ;
589
+ prot [TLS_SW ][ TLS_BASE ] .sendpage = tls_sw_sendpage ;
590
+
591
+ prot [TLS_BASE ][ TLS_SW ] = prot [ TLS_BASE ] [TLS_BASE ];
592
+ prot [TLS_BASE ][ TLS_SW ].recvmsg = tls_sw_recvmsg ;
593
+ prot [TLS_BASE ][ TLS_SW ].close = tls_sk_proto_close ;
594
+
595
+ prot [TLS_SW ][ TLS_SW ] = prot [TLS_SW ][ TLS_BASE ];
596
+ prot [TLS_SW ][ TLS_SW ].recvmsg = tls_sw_recvmsg ;
597
+ prot [TLS_SW ][ TLS_SW ] .close = tls_sk_proto_close ;
598
+
599
+ prot [TLS_HW_RECORD ][ TLS_HW_RECORD ] = * base ;
600
+ prot [TLS_HW_RECORD ][ TLS_HW_RECORD ] .hash = tls_hw_hash ;
601
+ prot [TLS_HW_RECORD ][ TLS_HW_RECORD ] .unhash = tls_hw_unhash ;
602
+ prot [TLS_HW_RECORD ][ TLS_HW_RECORD ] .close = tls_sk_proto_close ;
605
603
}
606
604
607
605
static int tls_init (struct sock * sk )
@@ -643,7 +641,8 @@ static int tls_init(struct sock *sk)
643
641
mutex_unlock (& tcpv6_prot_mutex );
644
642
}
645
643
646
- ctx -> conf = TLS_BASE ;
644
+ ctx -> tx_conf = TLS_BASE ;
645
+ ctx -> rx_conf = TLS_BASE ;
647
646
update_sk_prot (sk , ctx );
648
647
out :
649
648
return rc ;
0 commit comments