Skip to content

Commit 8a59f9d

Browse files
Cong WangAlexei Starovoitov
authored andcommitted
sock: Introduce sk->sk_prot->psock_update_sk_prot()
Currently sockmap calls into each protocol to update the struct proto and replace it. This certainly won't work when the protocol is implemented as a module, for example, AF_UNIX. Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each protocol can implement its own way to replace the struct proto. This also helps get rid of symbol dependencies on CONFIG_INET. Signed-off-by: Cong Wang <[email protected]> Signed-off-by: Alexei Starovoitov <[email protected]> Link: https://lore.kernel.org/bpf/[email protected]
1 parent a7ba455 commit 8a59f9d

File tree

12 files changed

+58
-45
lines changed

12 files changed

+58
-45
lines changed

include/linux/skmsg.h

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ struct sk_psock {
9999
void (*saved_close)(struct sock *sk, long timeout);
100100
void (*saved_write_space)(struct sock *sk);
101101
void (*saved_data_ready)(struct sock *sk);
102+
int (*psock_update_sk_prot)(struct sock *sk, bool restore);
102103
struct proto *sk_proto;
103104
struct mutex work_mutex;
104105
struct sk_psock_work_state work_state;
@@ -395,25 +396,12 @@ static inline void sk_psock_cork_free(struct sk_psock *psock)
395396
}
396397
}
397398

398-
static inline void sk_psock_update_proto(struct sock *sk,
399-
struct sk_psock *psock,
400-
struct proto *ops)
401-
{
402-
/* Pairs with lockless read in sk_clone_lock() */
403-
WRITE_ONCE(sk->sk_prot, ops);
404-
}
405-
406399
static inline void sk_psock_restore_proto(struct sock *sk,
407400
struct sk_psock *psock)
408401
{
409402
sk->sk_prot->unhash = psock->saved_unhash;
410-
if (inet_csk_has_ulp(sk)) {
411-
tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
412-
} else {
413-
sk->sk_write_space = psock->saved_write_space;
414-
/* Pairs with lockless read in sk_clone_lock() */
415-
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
416-
}
403+
if (psock->psock_update_sk_prot)
404+
psock->psock_update_sk_prot(sk, true);
417405
}
418406

419407
static inline void sk_psock_set_state(struct sk_psock *psock,

include/net/sock.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,9 @@ struct proto {
11841184
void (*unhash)(struct sock *sk);
11851185
void (*rehash)(struct sock *sk);
11861186
int (*get_port)(struct sock *sk, unsigned short snum);
1187+
#ifdef CONFIG_BPF_SYSCALL
1188+
int (*psock_update_sk_prot)(struct sock *sk, bool restore);
1189+
#endif
11871190

11881191
/* Keeping track of sockets in use */
11891192
#ifdef CONFIG_PROC_FS

include/net/tcp.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2203,6 +2203,7 @@ struct sk_psock;
22032203

22042204
#ifdef CONFIG_BPF_SYSCALL
22052205
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
2206+
int tcp_bpf_update_proto(struct sock *sk, bool restore);
22062207
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
22072208
#endif /* CONFIG_BPF_SYSCALL */
22082209

include/net/udp.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ static inline struct sk_buff *udp_rcv_segment(struct sock *sk,
518518
#ifdef CONFIG_BPF_SYSCALL
519519
struct sk_psock;
520520
struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
521+
int udp_bpf_update_proto(struct sock *sk, bool restore);
521522
#endif
522523

523524
#endif /* _UDP_H */

net/core/skmsg.c

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -562,11 +562,6 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
562562

563563
write_lock_bh(&sk->sk_callback_lock);
564564

565-
if (inet_csk_has_ulp(sk)) {
566-
psock = ERR_PTR(-EINVAL);
567-
goto out;
568-
}
569-
570565
if (sk->sk_user_data) {
571566
psock = ERR_PTR(-EBUSY);
572567
goto out;

net/core/sock_map.c

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -185,26 +185,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
185185

186186
static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
187187
{
188-
struct proto *prot;
189-
190-
switch (sk->sk_type) {
191-
case SOCK_STREAM:
192-
prot = tcp_bpf_get_proto(sk, psock);
193-
break;
194-
195-
case SOCK_DGRAM:
196-
prot = udp_bpf_get_proto(sk, psock);
197-
break;
198-
199-
default:
188+
if (!sk->sk_prot->psock_update_sk_prot)
200189
return -EINVAL;
201-
}
202-
203-
if (IS_ERR(prot))
204-
return PTR_ERR(prot);
205-
206-
sk_psock_update_proto(sk, psock, prot);
207-
return 0;
190+
psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
191+
return sk->sk_prot->psock_update_sk_prot(sk, false);
208192
}
209193

210194
static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
@@ -556,7 +540,7 @@ static bool sock_map_redirect_allowed(const struct sock *sk)
556540

557541
static bool sock_map_sk_is_suitable(const struct sock *sk)
558542
{
559-
return sk_is_tcp(sk) || sk_is_udp(sk);
543+
return !!sk->sk_prot->psock_update_sk_prot;
560544
}
561545

562546
static bool sock_map_sk_state_allowed(const struct sock *sk)

net/ipv4/tcp_bpf.c

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -595,20 +595,38 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
595595
ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
596596
}
597597

598-
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
598+
int tcp_bpf_update_proto(struct sock *sk, bool restore)
599599
{
600+
struct sk_psock *psock = sk_psock(sk);
600601
int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
601602
int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
602603

604+
if (restore) {
605+
if (inet_csk_has_ulp(sk)) {
606+
tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
607+
} else {
608+
sk->sk_write_space = psock->saved_write_space;
609+
/* Pairs with lockless read in sk_clone_lock() */
610+
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
611+
}
612+
return 0;
613+
}
614+
615+
if (inet_csk_has_ulp(sk))
616+
return -EINVAL;
617+
603618
if (sk->sk_family == AF_INET6) {
604619
if (tcp_bpf_assert_proto_ops(psock->sk_proto))
605-
return ERR_PTR(-EINVAL);
620+
return -EINVAL;
606621

607622
tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
608623
}
609624

610-
return &tcp_bpf_prots[family][config];
625+
/* Pairs with lockless read in sk_clone_lock() */
626+
WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
627+
return 0;
611628
}
629+
EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);
612630

613631
/* If a child got cloned from a listening socket that had tcp_bpf
614632
* protocol callbacks installed, we need to restore the callbacks to

net/ipv4/tcp_ipv4.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2806,6 +2806,9 @@ struct proto tcp_prot = {
28062806
.hash = inet_hash,
28072807
.unhash = inet_unhash,
28082808
.get_port = inet_csk_get_port,
2809+
#ifdef CONFIG_BPF_SYSCALL
2810+
.psock_update_sk_prot = tcp_bpf_update_proto,
2811+
#endif
28092812
.enter_memory_pressure = tcp_enter_memory_pressure,
28102813
.leave_memory_pressure = tcp_leave_memory_pressure,
28112814
.stream_memory_free = tcp_stream_memory_free,

net/ipv4/udp.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2849,6 +2849,9 @@ struct proto udp_prot = {
28492849
.unhash = udp_lib_unhash,
28502850
.rehash = udp_v4_rehash,
28512851
.get_port = udp_v4_get_port,
2852+
#ifdef CONFIG_BPF_SYSCALL
2853+
.psock_update_sk_prot = udp_bpf_update_proto,
2854+
#endif
28522855
.memory_allocated = &udp_memory_allocated,
28532856
.sysctl_mem = sysctl_udp_mem,
28542857
.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),

net/ipv4/udp_bpf.c

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,23 @@ static int __init udp_bpf_v4_build_proto(void)
4141
}
4242
core_initcall(udp_bpf_v4_build_proto);
4343

44-
struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
44+
int udp_bpf_update_proto(struct sock *sk, bool restore)
4545
{
4646
int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
47+
struct sk_psock *psock = sk_psock(sk);
48+
49+
if (restore) {
50+
sk->sk_write_space = psock->saved_write_space;
51+
/* Pairs with lockless read in sk_clone_lock() */
52+
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
53+
return 0;
54+
}
4755

4856
if (sk->sk_family == AF_INET6)
4957
udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
5058

51-
return &udp_bpf_prots[family];
59+
/* Pairs with lockless read in sk_clone_lock() */
60+
WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
61+
return 0;
5262
}
63+
EXPORT_SYMBOL_GPL(udp_bpf_update_proto);

net/ipv6/tcp_ipv6.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2139,6 +2139,9 @@ struct proto tcpv6_prot = {
21392139
.hash = inet6_hash,
21402140
.unhash = inet_unhash,
21412141
.get_port = inet_csk_get_port,
2142+
#ifdef CONFIG_BPF_SYSCALL
2143+
.psock_update_sk_prot = tcp_bpf_update_proto,
2144+
#endif
21422145
.enter_memory_pressure = tcp_enter_memory_pressure,
21432146
.leave_memory_pressure = tcp_leave_memory_pressure,
21442147
.stream_memory_free = tcp_stream_memory_free,

net/ipv6/udp.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1713,6 +1713,9 @@ struct proto udpv6_prot = {
17131713
.unhash = udp_lib_unhash,
17141714
.rehash = udp_v6_rehash,
17151715
.get_port = udp_v6_get_port,
1716+
#ifdef CONFIG_BPF_SYSCALL
1717+
.psock_update_sk_prot = udp_bpf_update_proto,
1718+
#endif
17161719
.memory_allocated = &udp_memory_allocated,
17171720
.sysctl_mem = sysctl_udp_mem,
17181721
.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),

0 commit comments

Comments
 (0)