@@ -41,6 +41,8 @@
 #include <linux/mm.h>
 #include <net/strparser.h>
 #include <net/tcp.h>
+#include <linux/ptr_ring.h>
+#include <net/inet_common.h>
 
 #define SOCK_CREATE_FLAG_MASK \
         (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
@@ -82,6 +84,7 @@ struct smap_psock {
         int sg_size;
         int eval;
         struct sk_msg_buff *cork;
+        struct list_head ingress;
 
         struct strparser strp;
         struct bpf_prog *bpf_tx_msg;
@@ -103,6 +106,8 @@ struct smap_psock {
 };
 
 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
+static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
+                           int nonblock, int flags, int *addr_len);
 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
                             int offset, size_t size, int flags);
@@ -112,6 +117,21 @@ static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
         return rcu_dereference_sk_user_data(sk);
 }
 
+static bool bpf_tcp_stream_read(const struct sock *sk)
+{
+        struct smap_psock *psock;
+        bool empty = true;
+
+        rcu_read_lock();
+        psock = smap_psock_sk(sk);
+        if (unlikely(!psock))
+                goto out;
+        empty = list_empty(&psock->ingress);
+out:
+        rcu_read_unlock();
+        return !empty;
+}
+
 static struct proto tcp_bpf_proto;
 static int bpf_tcp_init(struct sock *sk)
 {
@@ -135,6 +155,8 @@ static int bpf_tcp_init(struct sock *sk)
         if (psock->bpf_tx_msg) {
                 tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg;
                 tcp_bpf_proto.sendpage = bpf_tcp_sendpage;
+                tcp_bpf_proto.recvmsg = bpf_tcp_recvmsg;
+                tcp_bpf_proto.stream_memory_read = bpf_tcp_stream_read;
         }
 
         sk->sk_prot = &tcp_bpf_proto;
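
With both hooks installed, data that a BPF program redirects onto psock->ingress is visible to ordinary readers: recvmsg() is served by bpf_tcp_recvmsg() (added further down), and stream_memory_read lets TCP report POLLIN even while sk_receive_queue is empty. A minimal user-space sketch of the observable effect, assuming fd is a TCP socket that has been added to a sockmap (the helper name is illustrative):

#include <poll.h>

/* Returns >0 when POLLIN is set, i.e. when either the regular receive
 * queue or the BPF ingress list has data pending for this socket. */
static int wait_readable(int fd, int timeout_ms)
{
        struct pollfd pfd = { .fd = fd, .events = POLLIN };

        return poll(&pfd, 1, timeout_ms);
}
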
@@ -170,6 +192,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
 {
         void (*close_fun)(struct sock *sk, long timeout);
         struct smap_psock_map_entry *e, *tmp;
+        struct sk_msg_buff *md, *mtmp;
         struct smap_psock *psock;
         struct sock *osk;
 
@@ -188,6 +211,12 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
         close_fun = psock->save_close;
 
         write_lock_bh(&sk->sk_callback_lock);
+        list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
+                list_del(&md->list);
+                free_start_sg(psock->sock, md);
+                kfree(md);
+        }
+
         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
                 osk = cmpxchg(e->entry, sk, NULL);
                 if (osk == sk) {
@@ -468,13 +497,80 @@ static unsigned int smap_do_tx_msg(struct sock *sk,
         return _rc;
 }
 
+static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
+                           struct smap_psock *psock,
+                           struct sk_msg_buff *md, int flags)
+{
+        bool apply = apply_bytes;
+        size_t size, copied = 0;
+        struct sk_msg_buff *r;
+        int err = 0, i;
+
+        r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
+        if (unlikely(!r))
+                return -ENOMEM;
+
+        lock_sock(sk);
+        r->sg_start = md->sg_start;
+        i = md->sg_start;
+
+        do {
+                r->sg_data[i] = md->sg_data[i];
+
+                size = (apply && apply_bytes < md->sg_data[i].length) ?
+                        apply_bytes : md->sg_data[i].length;
+
+                if (!sk_wmem_schedule(sk, size)) {
+                        if (!copied)
+                                err = -ENOMEM;
+                        break;
+                }
+
+                sk_mem_charge(sk, size);
+                r->sg_data[i].length = size;
+                md->sg_data[i].length -= size;
+                md->sg_data[i].offset += size;
+                copied += size;
+
+                if (md->sg_data[i].length) {
+                        get_page(sg_page(&r->sg_data[i]));
+                        r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
+                } else {
+                        i++;
+                        if (i == MAX_SKB_FRAGS)
+                                i = 0;
+                        r->sg_end = i;
+                }
+
+                if (apply) {
+                        apply_bytes -= size;
+                        if (!apply_bytes)
+                                break;
+                }
+        } while (i != md->sg_end);
+
+        md->sg_start = i;
+
+        if (!err) {
+                list_add_tail(&r->list, &psock->ingress);
+                sk->sk_data_ready(sk);
+        } else {
+                free_start_sg(sk, r);
+                kfree(r);
+        }
+
+        release_sock(sk);
+        return err;
+}
+
 static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
                                        struct sk_msg_buff *md,
                                        int flags)
 {
         struct smap_psock *psock;
         struct scatterlist *sg;
         int i, err, free = 0;
+        bool ingress = !!(md->flags & BPF_F_INGRESS);
 
         sg = md->sg_data;
 
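
bpf_tcp_ingress() clones the source msg's scatterlist ring into a fresh sk_msg_buff charged against the receiving socket, splitting the current element when apply_bytes caps how much may move. A standalone user-space sketch of that split arithmetic, using hypothetical plain structs in place of the kernel scatterlist:

#include <stdio.h>

struct elem { unsigned int offset, length; };

int main(void)
{
        struct elem src = { .offset = 0, .length = 4096 };
        unsigned int apply_bytes = 1000;
        unsigned int size = apply_bytes < src.length ? apply_bytes : src.length;
        struct elem ingress = { .offset = src.offset, .length = size };

        /* The remainder stays on the source msg, advanced past the
         * bytes that were handed to the ingress copy. */
        src.offset += size;
        src.length -= size;

        printf("ingress: off=%u len=%u  src: off=%u len=%u\n",
               ingress.offset, ingress.length, src.offset, src.length);
        return 0;
}
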
@@ -487,9 +583,14 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
                 goto out_rcu;
 
         rcu_read_unlock();
-        lock_sock(sk);
-        err = bpf_tcp_push(sk, send, md, flags, false);
-        release_sock(sk);
+
+        if (ingress) {
+                err = bpf_tcp_ingress(sk, send, psock, md, flags);
+        } else {
+                lock_sock(sk);
+                err = bpf_tcp_push(sk, send, md, flags, false);
+                release_sock(sk);
+        }
         smap_release_sock(psock, sk);
         if (unlikely(err))
                 goto out;
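
For context, a libbpf-style SK_MSG verdict program (illustrative, not part of this patch; the map layout and key are assumptions) that exercises the new branch: passing BPF_F_INGRESS to bpf_msg_redirect_map() is what makes bpf_tcp_sendmsg_do_redirect() take the bpf_tcp_ingress() path instead of bpf_tcp_push():

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

struct {
        __uint(type, BPF_MAP_TYPE_SOCKMAP);
        __uint(max_entries, 2);
        __type(key, int);
        __type(value, int);
} sock_map SEC(".maps");

SEC("sk_msg")
int redirect_to_ingress(struct sk_msg_md *msg)
{
        int key = 1;    /* illustrative slot holding the receiving socket */

        /* Queue the msg on the target socket's ingress list instead of
         * sending it out on the wire. */
        return bpf_msg_redirect_map(msg, &sock_map, key, BPF_F_INGRESS);
}

char _license[] SEC("license") = "GPL";
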
@@ -623,6 +724,89 @@ static int bpf_exec_tx_verdict(struct smap_psock *psock,
         return err;
 }
 
+static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
+                           int nonblock, int flags, int *addr_len)
+{
+        struct iov_iter *iter = &msg->msg_iter;
+        struct smap_psock *psock;
+        int copied = 0;
+
+        if (unlikely(flags & MSG_ERRQUEUE))
+                return inet_recv_error(sk, msg, len, addr_len);
+
+        rcu_read_lock();
+        psock = smap_psock_sk(sk);
+        if (unlikely(!psock))
+                goto out;
+
+        if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
+                goto out;
+        rcu_read_unlock();
+
+        if (!skb_queue_empty(&sk->sk_receive_queue))
+                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+
+        lock_sock(sk);
+        while (copied != len) {
+                struct scatterlist *sg;
+                struct sk_msg_buff *md;
+                int i;
+
+                md = list_first_entry_or_null(&psock->ingress,
+                                              struct sk_msg_buff, list);
+                if (unlikely(!md))
+                        break;
+                i = md->sg_start;
+                do {
+                        struct page *page;
+                        int n, copy;
+
+                        sg = &md->sg_data[i];
+                        copy = sg->length;
+                        page = sg_page(sg);
+
+                        if (copied + copy > len)
+                                copy = len - copied;
+
+                        n = copy_page_to_iter(page, sg->offset, copy, iter);
+                        if (n != copy) {
+                                md->sg_start = i;
+                                release_sock(sk);
+                                smap_release_sock(psock, sk);
+                                return -EFAULT;
+                        }
+
+                        copied += copy;
+                        sg->offset += copy;
+                        sg->length -= copy;
+                        sk_mem_uncharge(sk, copy);
+
+                        if (!sg->length) {
+                                i++;
+                                if (i == MAX_SKB_FRAGS)
+                                        i = 0;
+                                put_page(page);
+                        }
+                        if (copied == len)
+                                break;
+                } while (i != md->sg_end);
+                md->sg_start = i;
+
+                if (!sg->length && md->sg_start == md->sg_end) {
+                        list_del(&md->list);
+                        kfree(md);
+                }
+        }
+
+        release_sock(sk);
+        smap_release_sock(psock, sk);
+        return copied;
+out:
+        rcu_read_unlock();
+        return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+}
+
+
 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 {
         int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
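
Once msgs have been redirected to a socket with BPF_F_INGRESS, an ordinary recv() on that socket is handled by bpf_tcp_recvmsg() above: when sk_receive_queue holds skbs it defers to tcp_recvmsg(), otherwise it copies out of the psock->ingress list. A minimal user-space sketch, assuming fd is the receiving sockmap socket:

#include <sys/types.h>
#include <sys/socket.h>

/* Non-blocking drain of whatever is queued for this socket, whether it
 * arrived over the wire or via BPF_F_INGRESS redirection; returns -1
 * with errno == EAGAIN when nothing is pending. */
static ssize_t drain_redirected(int fd, char *buf, size_t len)
{
        return recv(fd, buf, len, MSG_DONTWAIT);
}
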
@@ -1107,6 +1291,7 @@ static void sock_map_remove_complete(struct bpf_stab *stab)
 static void smap_gc_work(struct work_struct *w)
 {
         struct smap_psock_map_entry *e, *tmp;
+        struct sk_msg_buff *md, *mtmp;
         struct smap_psock *psock;
 
         psock = container_of(w, struct smap_psock, gc_work);
@@ -1131,6 +1316,12 @@ static void smap_gc_work(struct work_struct *w)
                 kfree(psock->cork);
         }
 
+        list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
+                list_del(&md->list);
+                free_start_sg(psock->sock, md);
+                kfree(md);
+        }
+
         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
                 list_del(&e->list);
                 kfree(e);
@@ -1160,6 +1351,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock,
         INIT_WORK(&psock->tx_work, smap_tx_work);
         INIT_WORK(&psock->gc_work, smap_gc_work);
         INIT_LIST_HEAD(&psock->maps);
+        INIT_LIST_HEAD(&psock->ingress);
         refcount_set(&psock->refcnt, 1);
 
         rcu_assign_sk_user_data(sock, psock);