Skip to content

Commit 8934ce2

Browse files
jrfastab and borkmann
authored and committed
bpf: sockmap redirect ingress support
Add support for the BPF_F_INGRESS flag in sk_msg redirect helper. To do this add a scatterlist ring for receiving socks to check before calling into regular recvmsg call path. Additionally, because the poll wakeup logic only checked the skb recv queue we need to add a hook in TCP stack (similar to write side) so that we have a way to wake up polling socks when a scatterlist is redirected to that sock. After this all that is needed is for the redirect helper to push the scatterlist into the psock receive queue. Signed-off-by: John Fastabend <[email protected]> Signed-off-by: Daniel Borkmann <[email protected]>
1 parent 2252743 commit 8934ce2

File tree

5 files changed

+207
-5
lines changed

5 files changed

+207
-5
lines changed

include/linux/filter.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -521,6 +521,7 @@ struct sk_msg_buff {
521521
__u32 key;
522522
__u32 flags;
523523
struct bpf_map *map;
524+
struct list_head list;
524525
};
525526

526527
/* Compute the linear packet data range [data, data_end) which

include/net/sock.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1085,6 +1085,7 @@ struct proto {
10851085
#endif
10861086

10871087
bool (*stream_memory_free)(const struct sock *sk);
1088+
bool (*stream_memory_read)(const struct sock *sk);
10881089
/* Memory pressure */
10891090
void (*enter_memory_pressure)(struct sock *sk);
10901091
void (*leave_memory_pressure)(struct sock *sk);

kernel/bpf/sockmap.c

Lines changed: 195 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,8 @@
4141
#include <linux/mm.h>
4242
#include <net/strparser.h>
4343
#include <net/tcp.h>
44+
#include <linux/ptr_ring.h>
45+
#include <net/inet_common.h>
4446

4547
#define SOCK_CREATE_FLAG_MASK \
4648
(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
@@ -82,6 +84,7 @@ struct smap_psock {
8284
int sg_size;
8385
int eval;
8486
struct sk_msg_buff *cork;
87+
struct list_head ingress;
8588

8689
struct strparser strp;
8790
struct bpf_prog *bpf_tx_msg;
@@ -103,6 +106,8 @@ struct smap_psock {
103106
};
104107

105108
static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
109+
static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
110+
int nonblock, int flags, int *addr_len);
106111
static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
107112
static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
108113
int offset, size_t size, int flags);
@@ -112,6 +117,21 @@ static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
112117
return rcu_dereference_sk_user_data(sk);
113118
}
114119

120+
/* bpf_tcp_stream_read - report whether BPF-redirected ingress data is queued.
 *
 * Installed as the proto's stream_memory_read hook (see bpf_tcp_init in this
 * commit) so tcp_poll() can signal EPOLLIN for data that was redirected into
 * psock->ingress rather than the regular skb receive queue. Runs under RCU;
 * a socket with no psock attached (not in a sockmap) reads as "nothing".
 */
static bool bpf_tcp_stream_read(const struct sock *sk)
121+
{
122+
struct smap_psock *psock;
123+
bool empty = true;
124+
125+
rcu_read_lock();
126+
psock = smap_psock_sk(sk);
127+
if (unlikely(!psock))
128+
goto out;
129+
/* Non-empty ingress list means a redirected scatterlist is waiting. */
empty = list_empty(&psock->ingress);
130+
out:
131+
rcu_read_unlock();
132+
return !empty;
133+
}
134+
115135
static struct proto tcp_bpf_proto;
116136
static int bpf_tcp_init(struct sock *sk)
117137
{
@@ -135,6 +155,8 @@ static int bpf_tcp_init(struct sock *sk)
135155
if (psock->bpf_tx_msg) {
136156
tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg;
137157
tcp_bpf_proto.sendpage = bpf_tcp_sendpage;
158+
tcp_bpf_proto.recvmsg = bpf_tcp_recvmsg;
159+
tcp_bpf_proto.stream_memory_read = bpf_tcp_stream_read;
138160
}
139161

140162
sk->sk_prot = &tcp_bpf_proto;
@@ -170,6 +192,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
170192
{
171193
void (*close_fun)(struct sock *sk, long timeout);
172194
struct smap_psock_map_entry *e, *tmp;
195+
struct sk_msg_buff *md, *mtmp;
173196
struct smap_psock *psock;
174197
struct sock *osk;
175198

@@ -188,6 +211,12 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
188211
close_fun = psock->save_close;
189212

190213
write_lock_bh(&sk->sk_callback_lock);
214+
list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
215+
list_del(&md->list);
216+
free_start_sg(psock->sock, md);
217+
kfree(md);
218+
}
219+
191220
list_for_each_entry_safe(e, tmp, &psock->maps, list) {
192221
osk = cmpxchg(e->entry, sk, NULL);
193222
if (osk == sk) {
@@ -468,13 +497,80 @@ static unsigned int smap_do_tx_msg(struct sock *sk,
468497
return _rc;
469498
}
470499

500+
/* bpf_tcp_ingress - redirect a scatterlist msg into @sk's psock ingress list.
 *
 * Clones up to @apply_bytes of @md's scatterlist (all remaining bytes when
 * @apply_bytes is 0) into a freshly allocated sk_msg_buff and appends it to
 * @psock->ingress, then wakes the receiving socket via sk_data_ready().
 * Consumed source entries are advanced in place (@md sg_start/offset/length)
 * so the caller's view of the message stays consistent.
 *
 * Returns 0 on success. Returns -ENOMEM if the buffer allocation fails, or
 * if no memory could be scheduled before anything was copied; a partial copy
 * (copied > 0) is still queued and reported as success.
 */
static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
501+
struct smap_psock *psock,
502+
struct sk_msg_buff *md, int flags)
503+
{
504+
/* apply == false means "no byte limit" (apply_bytes of 0). */
bool apply = apply_bytes;
505+
size_t size, copied = 0;
506+
struct sk_msg_buff *r;
507+
int err = 0, i;
508+
509+
r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
510+
if (unlikely(!r))
511+
return -ENOMEM;
512+
513+
lock_sock(sk);
514+
r->sg_start = md->sg_start;
515+
i = md->sg_start;
516+
517+
do {
518+
/* Copy the sg entry; length is trimmed below if a limit applies. */
r->sg_data[i] = md->sg_data[i];
519+
520+
size = (apply && apply_bytes < md->sg_data[i].length) ?
521+
apply_bytes : md->sg_data[i].length;
522+
523+
/* NOTE(review): ingress data is charged via the wmem scheduler —
 * confirm this is the intended accounting side for received data.
 */
if (!sk_wmem_schedule(sk, size)) {
524+
/* Only an error if nothing was copied yet. */
if (!copied)
525+
err = -ENOMEM;
526+
break;
527+
}
528+
529+
sk_mem_charge(sk, size);
530+
r->sg_data[i].length = size;
531+
/* Advance the source entry past the bytes we took. */
md->sg_data[i].length -= size;
532+
md->sg_data[i].offset += size;
533+
copied += size;
534+
535+
if (md->sg_data[i].length) {
536+
/* Entry only partially consumed: both @md and @r now reference
 * this page, so take an extra reference for the clone.
 */
get_page(sg_page(&r->sg_data[i]));
537+
r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
538+
} else {
539+
/* Entry fully consumed: advance with ring wrap-around. */
i++;
540+
if (i == MAX_SKB_FRAGS)
541+
i = 0;
542+
r->sg_end = i;
543+
}
544+
545+
if (apply) {
546+
apply_bytes -= size;
547+
if (!apply_bytes)
548+
break;
549+
}
550+
} while (i != md->sg_end);
551+
552+
md->sg_start = i;
553+
554+
if (!err) {
555+
list_add_tail(&r->list, &psock->ingress);
556+
/* Wake any poller/reader blocked on this socket. */
sk->sk_data_ready(sk);
557+
} else {
558+
free_start_sg(sk, r);
559+
kfree(r);
560+
}
561+
562+
release_sock(sk);
563+
return err;
564+
}
565+
471566
static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
472567
struct sk_msg_buff *md,
473568
int flags)
474569
{
475570
struct smap_psock *psock;
476571
struct scatterlist *sg;
477572
int i, err, free = 0;
573+
bool ingress = !!(md->flags & BPF_F_INGRESS);
478574

479575
sg = md->sg_data;
480576

@@ -487,9 +583,14 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
487583
goto out_rcu;
488584

489585
rcu_read_unlock();
490-
lock_sock(sk);
491-
err = bpf_tcp_push(sk, send, md, flags, false);
492-
release_sock(sk);
586+
587+
if (ingress) {
588+
err = bpf_tcp_ingress(sk, send, psock, md, flags);
589+
} else {
590+
lock_sock(sk);
591+
err = bpf_tcp_push(sk, send, md, flags, false);
592+
release_sock(sk);
593+
}
493594
smap_release_sock(psock, sk);
494595
if (unlikely(err))
495596
goto out;
@@ -623,6 +724,89 @@ static int bpf_exec_tx_verdict(struct smap_psock *psock,
623724
return err;
624725
}
625726

727+
/* bpf_tcp_recvmsg - recvmsg hook that drains the psock ingress queue first.
 *
 * Falls back to tcp_recvmsg() when: the socket has no psock attached, the
 * psock refcount cannot be taken (being torn down), or ordinary skb data is
 * already queued on sk_receive_queue. Otherwise copies queued scatterlist
 * pages into the user iov, uncharging socket memory as bytes are consumed
 * and releasing pages/buffers once fully drained.
 *
 * Returns the number of bytes copied, or -EFAULT if copying to user space
 * fails (the partially-consumed position is saved back into the buffer).
 */
static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
728+
int nonblock, int flags, int *addr_len)
729+
{
730+
struct iov_iter *iter = &msg->msg_iter;
731+
struct smap_psock *psock;
732+
int copied = 0;
733+
734+
if (unlikely(flags & MSG_ERRQUEUE))
735+
return inet_recv_error(sk, msg, len, addr_len);
736+
737+
rcu_read_lock();
738+
psock = smap_psock_sk(sk);
739+
if (unlikely(!psock))
740+
goto out;
741+
742+
/* Hold the psock across the copy loop; bail to plain TCP if it is
 * already on its way down.
 */
if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
743+
goto out;
744+
rcu_read_unlock();
745+
746+
/* Regular skb data takes precedence over redirected ingress data. */
if (!skb_queue_empty(&sk->sk_receive_queue))
747+
return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
748+
749+
lock_sock(sk);
750+
while (copied != len) {
751+
struct scatterlist *sg;
752+
struct sk_msg_buff *md;
753+
int i;
754+
755+
md = list_first_entry_or_null(&psock->ingress,
756+
struct sk_msg_buff, list);
757+
/* Ingress queue drained: return whatever was copied so far. */
if (unlikely(!md))
758+
break;
759+
i = md->sg_start;
760+
do {
761+
struct page *page;
762+
int n, copy;
763+
764+
sg = &md->sg_data[i];
765+
copy = sg->length;
766+
page = sg_page(sg);
767+
768+
/* Clamp to the caller's remaining buffer space. */
if (copied + copy > len)
769+
copy = len - copied;
770+
771+
n = copy_page_to_iter(page, sg->offset, copy, iter);
772+
if (n != copy) {
773+
/* Short copy to user space: record position and fail. */
md->sg_start = i;
774+
release_sock(sk);
775+
smap_release_sock(psock, sk);
776+
return -EFAULT;
777+
}
778+
779+
copied += copy;
780+
sg->offset += copy;
781+
sg->length -= copy;
782+
/* Release the memory charge taken in bpf_tcp_ingress(). */
sk_mem_uncharge(sk, copy);
783+
784+
if (!sg->length) {
785+
/* Entry fully consumed: wrap the ring index and drop
 * the page reference.
 */
i++;
786+
if (i == MAX_SKB_FRAGS)
787+
i = 0;
788+
put_page(page);
789+
}
790+
if (copied == len)
791+
break;
792+
} while (i != md->sg_end);
793+
md->sg_start = i;
794+
795+
/* Free the buffer once every entry in it has been consumed. */
if (!sg->length && md->sg_start == md->sg_end) {
796+
list_del(&md->list);
797+
kfree(md);
798+
}
799+
}
800+
801+
release_sock(sk);
802+
smap_release_sock(psock, sk);
803+
return copied;
804+
out:
805+
rcu_read_unlock();
806+
/* No usable psock: defer entirely to the stock TCP receive path. */
return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
807+
}
808+
809+
626810
static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
627811
{
628812
int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
@@ -1107,6 +1291,7 @@ static void sock_map_remove_complete(struct bpf_stab *stab)
11071291
static void smap_gc_work(struct work_struct *w)
11081292
{
11091293
struct smap_psock_map_entry *e, *tmp;
1294+
struct sk_msg_buff *md, *mtmp;
11101295
struct smap_psock *psock;
11111296

11121297
psock = container_of(w, struct smap_psock, gc_work);
@@ -1131,6 +1316,12 @@ static void smap_gc_work(struct work_struct *w)
11311316
kfree(psock->cork);
11321317
}
11331318

1319+
list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
1320+
list_del(&md->list);
1321+
free_start_sg(psock->sock, md);
1322+
kfree(md);
1323+
}
1324+
11341325
list_for_each_entry_safe(e, tmp, &psock->maps, list) {
11351326
list_del(&e->list);
11361327
kfree(e);
@@ -1160,6 +1351,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock,
11601351
INIT_WORK(&psock->tx_work, smap_tx_work);
11611352
INIT_WORK(&psock->gc_work, smap_gc_work);
11621353
INIT_LIST_HEAD(&psock->maps);
1354+
INIT_LIST_HEAD(&psock->ingress);
11631355
refcount_set(&psock->refcnt, 1);
11641356

11651357
rcu_assign_sk_user_data(sock, psock);

net/core/filter.c

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1894,7 +1894,7 @@ BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg,
18941894
struct bpf_map *, map, u32, key, u64, flags)
18951895
{
18961896
/* If user passes invalid input drop the packet. */
1897-
if (unlikely(flags))
1897+
if (unlikely(flags & ~(BPF_F_INGRESS)))
18981898
return SK_DROP;
18991899

19001900
msg->key = key;

net/ipv4/tcp.c

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -485,6 +485,14 @@ static void tcp_tx_timestamp(struct sock *sk, u16 tsflags)
485485
}
486486
}
487487

488+
/* Readability test for tcp_poll(): true when at least @target bytes are
 * queued on the regular receive path (rcv_nxt - copied_seq), or when the
 * protocol's optional stream_memory_read hook reports readable data held
 * outside the skb queue (e.g. BPF sockmap ingress scatterlists).
 */
static inline bool tcp_stream_is_readable(const struct tcp_sock *tp,
489+
int target, struct sock *sk)
490+
{
491+
return (tp->rcv_nxt - tp->copied_seq >= target) ||
492+
(sk->sk_prot->stream_memory_read ?
493+
sk->sk_prot->stream_memory_read(sk) : false);
494+
}
495+
488496
/*
489497
* Wait for a TCP event.
490498
*
@@ -554,7 +562,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait)
554562
tp->urg_data)
555563
target++;
556564

557-
if (tp->rcv_nxt - tp->copied_seq >= target)
565+
if (tcp_stream_is_readable(tp, target, sk))
558566
mask |= EPOLLIN | EPOLLRDNORM;
559567

560568
if (!(sk->sk_shutdown & SEND_SHUTDOWN)) {

0 commit comments

Comments
 (0)